Files
mt/Deep-SAD-PyTorch/src/base/base_net.py

27 lines
797 B
Python
Raw Normal View History

2024-06-28 07:42:12 +02:00
import logging
import torch.nn as nn
import numpy as np
class BaseNet(nn.Module):
    """Base class for all neural networks.

    Subclasses implement :meth:`forward`. This class provides a logger named
    after the concrete subclass, a ``rep_dim`` attribute, and :meth:`summary`.
    """

    def __init__(self):
        super().__init__()
        # Logger named after the concrete subclass, so log lines identify the network.
        self.logger = logging.getLogger(self.__class__.__name__)
        # Representation dimensionality, i.e. dim of the code layer or last layer.
        # Left as None here; presumably assigned by subclasses -- confirm per network.
        self.rep_dim = None

    def forward(self, *inputs):
        """Forward pass logic.

        Must be overridden by subclasses.

        :return: Network output
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def summary(self) -> int:
        """Log the number of trainable parameters and the module structure.

        :return: Total number of elements across all parameters with
            ``requires_grad=True``, as a Python ``int`` (0 if there are none).
        """
        # Tensor.numel() is an int for every shape, including 0-dim (scalar)
        # parameters; np.prod(()) would yield the float 1.0 and turn the whole
        # count into a float.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.logger.info("Trainable parameters: %d", params)
        self.logger.info(self)
        return params