Add torchscan for network summary and receptive-field computation (WIP)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torchscan
|
||||
|
||||
|
||||
class BaseNet(nn.Module):
|
||||
@@ -10,6 +12,7 @@ class BaseNet(nn.Module):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.rep_dim = None # representation dimensionality, i.e. dim of the code layer or last layer
|
||||
self.input_dim = None # input dimensionality, i.e. dim of the input layer
|
||||
|
||||
def forward(self, *input):
    """Abstract forward pass.

    Subclasses must override this with the actual network computation.

    :raises NotImplementedError: always, unless overridden by a subclass.
    """
    raise NotImplementedError
|
||||
|
||||
def summary(self, receptive_field: bool = False):
    """Log a summary of the network.

    Logs the number of trainable parameters and the module structure.
    When ``self.input_dim`` is set, additionally logs a torchscan
    layer-by-layer summary (optionally including receptive-field info).

    :param receptive_field: if True, include receptive-field information
        in the torchscan summary. Defaults to False (backward compatible
        with the previous zero-argument signature).
    """
    # Count only parameters the optimizer will actually update.
    trainable = (p for p in self.parameters() if p.requires_grad)
    n_params = sum(int(np.prod(p.size())) for p in trainable)
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    self.logger.info("Trainable parameters: %d", n_params)
    self.logger.info(self)
    if not self.input_dim:
        # torchscan needs the input shape to trace the network; bail out
        # with a warning instead of raising so summary stays best-effort.
        self.logger.warning(
            "Input dimension is not set. Please set input_dim before calling summary."
        )
        return
    self.logger.info(
        torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
    )
|
||||
|
||||
Reference in New Issue
Block a user