Files
mt/Deep-SAD-PyTorch/src/networks/layers/standard.py

53 lines
1.6 KiB
Python
Raw Normal View History

2024-06-28 07:42:12 +02:00
import torch
from torch.nn import Module
from torch.nn import init
from torch.nn.parameter import Parameter
# Acknowledgements: https://github.com/wohlert/semi-supervised-pytorch
class Standardize(Module):
    """
    Element-wise standardization with a trainable translation μ and scale σ.

    Computes ``(x - μ) / (σ + eps)``, where subtraction and division are
    applied element-wise (with broadcasting) against parameters of shape
    ``(in_features,)``.

    Args:
        in_features: number of features; the length of the ``mu`` and ``std``
            parameter vectors. The output has the same number of features.
        bias: If set to False, the layer will not learn a translation
            parameter μ and no shift is applied. Default: ``True``
        eps: constant added to σ before dividing. Default: ``1e-6``

    Attributes:
        mu: the learnable translation parameter μ (``None`` if ``bias=False``).
        std: the learnable scale parameter σ.
    """

    # NOTE(review): kept exactly as in the original; presumably mirrors the
    # way older torch.nn.Linear listed its optional 'bias' -- confirm before
    # changing, since this affects TorchScript handling.
    __constants__ = ['mu']

    def __init__(self, in_features, bias=True, eps=1e-6):
        super().__init__()
        self.in_features = in_features
        # Standardization never changes the feature count.
        self.out_features = in_features
        self.eps = eps
        # Allocated uninitialised; reset_parameters() fills in real values.
        self.std = Parameter(torch.empty(in_features))
        if bias:
            self.mu = Parameter(torch.empty(in_features))
        else:
            # Register the name so `self.mu` exists and evaluates to None.
            self.register_parameter('mu', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset σ to 1 and, when present, μ to 0."""
        init.constant_(self.std, 1)
        if self.mu is not None:
            init.constant_(self.mu, 0)

    def forward(self, x):
        """
        Standardize ``x`` and return the result as a new tensor.

        Args:
            x: input tensor, broadcastable against parameters of shape
                ``(in_features,)``.

        Returns:
            ``(x - μ) / (σ + eps)``, or ``x / (σ + eps)`` when the layer was
            built with ``bias=False``. The input tensor is left unmodified.
        """
        if self.mu is not None:
            # Out-of-place on purpose: `x -= self.mu` would overwrite the
            # caller's tensor and fails outright under autograd when `x` is a
            # leaf tensor that requires grad.
            x = x - self.mu
        return torch.div(x, self.std + self.eps)

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.mu is not None
        )