Files
mt/Deep-SAD-PyTorch/src/base/base_net.py
2025-06-17 07:26:03 +02:00

36 lines
1.2 KiB
Python

import contextlib
import io
import logging

import torch.nn as nn
import torchscan
class BaseNet(nn.Module):
    """Base class for all neural networks.

    Subclasses override :meth:`forward` (the base implementation raises
    ``NotImplementedError``). ``input_dim`` must be set before
    :meth:`summary` can produce anything.
    """

    def __init__(self):
        super().__init__()
        # Logger is named after the concrete (sub)class, not "BaseNet".
        self.logger = logging.getLogger(self.__class__.__name__)
        self.rep_dim = None  # representation dimensionality, i.e. dim of the code layer or last layer
        self.input_dim = None  # input dimensionality, i.e. dim of the input layer

    def forward(self, *inputs):
        """
        Forward pass logic; must be overridden by subclasses.

        :return: Network output
        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def summary(self, receptive_field: bool = False):
        """Log a torchscan summary of the network.

        ``torchscan.summary`` writes its table to stdout; this method captures
        that output and emits it through ``self.logger`` together with its
        header, so it reaches the same handlers as the rest of the log.

        :param receptive_field: forwarded to ``torchscan.summary``; whether to
            include receptive-field information in the table.
        :return: None. If ``input_dim`` is not set, a warning is logged and
            nothing else happens.
        """
        if not self.input_dim:
            self.logger.warning(
                "Input dimension is not set. Please set input_dim before calling summary."
            )
            return
        # NOTE(review): input_dim is passed as torchscan's input_shape, which
        # presumably must be the per-sample shape (no batch dim) — confirm.
        captured = io.StringIO()
        # Redirect torchscan's print() so the table does not bypass logging;
        # the context manager restores stdout even if torchscan raises.
        with contextlib.redirect_stdout(captured):
            torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
        self.logger.info("torchscan:\n%s", captured.getvalue().rstrip())