import logging

import numpy as np
import torch.nn as nn

class BaseNet(nn.Module):
    """Base class for all neural networks.

    Provides a per-class logger, a placeholder for the representation
    dimensionality, and a ``summary`` helper that logs the trainable
    parameter count and the module structure.
    """

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        # Representation dimensionality, i.e. dim of the code layer or last layer.
        # Subclasses are expected to set this to an int.
        self.rep_dim = None

    def forward(self, *inputs):
        """
        Forward pass logic — must be implemented by subclasses.

        :param inputs: Positional network inputs.
        :return: Network output.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def summary(self):
        """Log the trainable parameter count and the network structure."""
        # Count only parameters that require gradients; p.numel() returns a
        # plain int, so the sum is an exact integer (no NumPy round-trip).
        n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Lazy %-formatting: the message is only built if the level is enabled.
        self.logger.info("Trainable parameters: %d", n_params)
        self.logger.info(self)