import logging
from typing import Optional, Tuple

import torch.nn as nn
import torchscan


class BaseNet(nn.Module):
    """Base class for all neural networks.

    Subclasses must override :meth:`forward` and should set ``input_dim``
    so that :meth:`summary` can run.
    """

    def __init__(self):
        super().__init__()
        # Logger named after the concrete subclass, e.g. "MyNet".
        self.logger = logging.getLogger(self.__class__.__name__)
        # Representation dimensionality, i.e. dim of the code layer or last layer.
        self.rep_dim: Optional[int] = None
        # Input dimensionality, i.e. dim of the input layer. It is passed verbatim
        # to torchscan.summary as its input shape, so it should be a tuple of ints
        # (presumably without the batch dimension -- confirm against torchscan).
        self.input_dim: Optional[Tuple[int, ...]] = None

    def forward(self, *inputs):
        """
        Forward pass logic; must be overridden by subclasses.

        :return: Network output
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def summary(self, receptive_field: bool = False) -> None:
        """Produce a network summary with torchscan.

        If ``input_dim`` is unset (or otherwise falsy), log a warning and
        return without calling torchscan.

        :param receptive_field: forwarded to ``torchscan.summary`` to request
            receptive-field information in the summary.
        """
        # A falsy input_dim (None or an empty shape) cannot drive torchscan's
        # dummy forward pass, so bail out early with a warning.
        if not self.input_dim:
            self.logger.warning(
                "Input dimension is not set. Please set input_dim before calling summary."
            )
            return
        self.logger.info("torchscan:\n")
        torchscan.summary(self, self.input_dim, receptive_field=receptive_field)