add torchscan for summary and receptive field (wip)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torchscan
|
||||
|
||||
|
||||
class BaseNet(nn.Module):
|
||||
@@ -10,6 +12,7 @@ class BaseNet(nn.Module):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.rep_dim = None # representation dimensionality, i.e. dim of the code layer or last layer
|
||||
self.input_dim = None # input dimensionality, i.e. dim of the input layer
|
||||
|
||||
def forward(self, *input):
|
||||
"""
|
||||
@@ -18,9 +21,17 @@ class BaseNet(nn.Module):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def summary(self):
|
||||
def summary(self, receptive_field: bool = False):
|
||||
"""Network summary."""
|
||||
net_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
params = sum([np.prod(p.size()) for p in net_parameters])
|
||||
self.logger.info("Trainable parameters: {}".format(params))
|
||||
self.logger.info(self)
|
||||
# net_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
# params = sum([np.prod(p.size()) for p in net_parameters])
|
||||
# self.logger.info("Trainable parameters: {}".format(params))
|
||||
# self.logger.info(self)
|
||||
if not self.input_dim:
|
||||
self.logger.warning(
|
||||
"Input dimension is not set. Please set input_dim before calling summary."
|
||||
)
|
||||
return
|
||||
self.logger.info(
|
||||
torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
|
||||
)
|
||||
|
||||
@@ -6,10 +6,10 @@ from base.base_net import BaseNet
|
||||
|
||||
|
||||
class SubTer_LeNet(BaseNet):
|
||||
|
||||
def __init__(self, rep_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = (1, 32, 2048) # Input dimension for the network
|
||||
self.rep_dim = rep_dim
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
|
||||
@@ -31,7 +31,6 @@ class SubTer_LeNet(BaseNet):
|
||||
|
||||
|
||||
class SubTer_LeNet_Decoder(BaseNet):
|
||||
|
||||
def __init__(self, rep_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
@@ -56,10 +55,10 @@ class SubTer_LeNet_Decoder(BaseNet):
|
||||
|
||||
|
||||
class SubTer_LeNet_Autoencoder(BaseNet):
|
||||
|
||||
def __init__(self, rep_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = (1, 32, 2048) # Input dimension for the network
|
||||
self.rep_dim = rep_dim
|
||||
self.encoder = SubTer_LeNet(rep_dim=rep_dim)
|
||||
self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim)
|
||||
|
||||
@@ -212,6 +212,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
start_time = time.time()
|
||||
idx_label_score = []
|
||||
net.eval()
|
||||
net.summary(receptive_field=True)
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
inputs, labels, semi_targets, idx, _ = data
|
||||
@@ -267,6 +268,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
c = torch.zeros(net.rep_dim, device=self.device)
|
||||
|
||||
net.eval()
|
||||
net.summary(receptive_field=True)
|
||||
with torch.no_grad():
|
||||
for data in train_loader:
|
||||
# get the inputs of the batch
|
||||
|
||||
Reference in New Issue
Block a user