added deepsad base code
This commit is contained in:
26
Deep-SAD-PyTorch/src/base/base_net.py
Normal file
26
Deep-SAD-PyTorch/src/base/base_net.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseNet(nn.Module):
|
||||
"""Base class for all neural networks."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.rep_dim = None # representation dimensionality, i.e. dim of the code layer or last layer
|
||||
|
||||
def forward(self, *input):
|
||||
"""
|
||||
Forward pass logic
|
||||
:return: Network output
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def summary(self):
|
||||
"""Network summary."""
|
||||
net_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
params = sum([np.prod(p.size()) for p in net_parameters])
|
||||
self.logger.info('Trainable parameters: {}'.format(params))
|
||||
self.logger.info(self)
|
||||
Reference in New Issue
Block a user