import torch
from torch.nn import Module
from torch.nn import init
from torch.nn.parameter import Parameter


# Acknowledgements: https://github.com/wohlert/semi-supervised-pytorch
class Standardize(Module):
    """
    Applies element-wise standardization with a trainable translation
    parameter μ and a trainable scale parameter σ, i.e. computes
    (x - μ) / σ, where '-' and '/' are applied element-wise.

    Args:
        in_features: size of each input sample (the output has the same size)
        bias: if set to ``False``, the layer will not learn a translation
            parameter μ. Default: ``True``
        eps: small constant added to σ before dividing, for numerical
            stability. Default: ``1e-6``

    Attributes:
        mu: the learnable translation parameter μ
        std: the learnable scale parameter σ
    """
    __constants__ = ['in_features', 'out_features', 'eps']

    def __init__(self, in_features, bias=True, eps=1e-6):
        super().__init__()
        self.in_features = in_features
        self.out_features = in_features  # output size equals input size
        self.eps = eps
        # Scale parameter σ, one value per feature.
        self.std = Parameter(torch.empty(in_features))
        if bias:
            # Translation parameter μ, one value per feature.
            self.mu = Parameter(torch.empty(in_features))
        else:
            self.register_parameter('mu', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Identity transform by default: σ = 1 and μ = 0.
        init.constant_(self.std, 1)
        if self.mu is not None:
            init.constant_(self.mu, 0)

    def forward(self, x):
        if self.mu is not None:
            # Out-of-place subtraction: the in-place `x -= self.mu` would
            # mutate the caller's tensor and can break autograd.
            x = x - self.mu
        # eps keeps the division stable when entries of σ are near zero.
        return x / (self.std + self.eps)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.mu is not None
        )
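

# A minimal usage sketch (not part of the original file): the shapes, the
# learning rate, and the dummy loss below are illustrative assumptions only.
if __name__ == '__main__':
    layer = Standardize(in_features=4)
    x = torch.randn(8, 4)  # batch of 8 samples with 4 features each
    y = layer(x)           # (x - μ) / (σ + eps), element-wise per feature
    print(layer)           # Standardize(in_features=4, out_features=4, bias=True)
    print(y.shape)         # torch.Size([8, 4])

    # μ and σ are registered Parameters, so an optimizer can train them.
    optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
    y.pow(2).mean().backward()  # dummy loss, just to show gradients flow
    optimizer.step()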