black formatted files before changes

@@ -1,10 +1,22 @@
 from .main import build_network, build_autoencoder
 from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Decoder, MNIST_LeNet_Autoencoder
-from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Decoder, FashionMNIST_LeNet_Autoencoder
-from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Decoder, CIFAR10_LeNet_Autoencoder
+from .fmnist_LeNet import (
+    FashionMNIST_LeNet,
+    FashionMNIST_LeNet_Decoder,
+    FashionMNIST_LeNet_Autoencoder,
+)
+from .cifar10_LeNet import (
+    CIFAR10_LeNet,
+    CIFAR10_LeNet_Decoder,
+    CIFAR10_LeNet_Autoencoder,
+)
 from .mlp import MLP, MLP_Decoder, MLP_Autoencoder
 from .layers.stochastic import GaussianSample
 from .layers.standard import Standardize
-from .inference.distributions import log_standard_gaussian, log_gaussian, log_standard_categorical
+from .inference.distributions import (
+    log_standard_gaussian,
+    log_gaussian,
+    log_standard_categorical,
+)
 from .vae import VariationalAutoencoder, Encoder, Decoder
 from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel

@@ -41,17 +41,27 @@ class CIFAR10_LeNet_Decoder(BaseNet):

         self.rep_dim = rep_dim

-        self.deconv1 = nn.ConvTranspose2d(int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2)
-        nn.init.xavier_uniform_(self.deconv1.weight, gain=nn.init.calculate_gain('leaky_relu'))
+        self.deconv1 = nn.ConvTranspose2d(
+            int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2
+        )
+        nn.init.xavier_uniform_(
+            self.deconv1.weight, gain=nn.init.calculate_gain("leaky_relu")
+        )
         self.bn2d4 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
         self.deconv2 = nn.ConvTranspose2d(128, 64, 5, bias=False, padding=2)
-        nn.init.xavier_uniform_(self.deconv2.weight, gain=nn.init.calculate_gain('leaky_relu'))
+        nn.init.xavier_uniform_(
+            self.deconv2.weight, gain=nn.init.calculate_gain("leaky_relu")
+        )
         self.bn2d5 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
         self.deconv3 = nn.ConvTranspose2d(64, 32, 5, bias=False, padding=2)
-        nn.init.xavier_uniform_(self.deconv3.weight, gain=nn.init.calculate_gain('leaky_relu'))
+        nn.init.xavier_uniform_(
+            self.deconv3.weight, gain=nn.init.calculate_gain("leaky_relu")
+        )
         self.bn2d6 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
         self.deconv4 = nn.ConvTranspose2d(32, 3, 5, bias=False, padding=2)
-        nn.init.xavier_uniform_(self.deconv4.weight, gain=nn.init.calculate_gain('leaky_relu'))
+        nn.init.xavier_uniform_(
+            self.deconv4.weight, gain=nn.init.calculate_gain("leaky_relu")
+        )

     def forward(self, x):
         x = x.view(int(x.size(0)), int(self.rep_dim / (4 * 4)), 4, 4)
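
Aside from quote style and line wrapping, the decoder itself is unchanged: the latent code of size rep_dim is reshaped into a (rep_dim / 16)-channel 4x4 feature map before the transposed convolutions. A standalone sketch of that reshape, assuming the default rep_dim=128 (illustration only, not part of the commit):

    import torch

    rep_dim = 128                 # assumed default representation size
    z = torch.randn(16, rep_dim)  # a batch of 16 latent codes

    # Same view as in CIFAR10_LeNet_Decoder.forward: 128 / (4 * 4) = 8 channels
    x = z.view(int(z.size(0)), int(rep_dim / (4 * 4)), 4, 4)
    print(x.shape)  # torch.Size([16, 8, 4, 4])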

@@ -97,7 +97,9 @@ class StackedDeepGenerativeModel(DeepGenerativeModel):
         :param features: a pre-trained M1 model of class 'VariationalAutoencoder' trained on the same dataset.
         """
         [x_dim, y_dim, z_dim, h_dim] = dims
-        super(StackedDeepGenerativeModel, self).__init__([features.z_dim, y_dim, z_dim, h_dim])
+        super(StackedDeepGenerativeModel, self).__init__(
+            [features.z_dim, y_dim, z_dim, h_dim]
+        )

         # Be sure to reconstruct with the same dimensions
         in_features = self.decoder.reconstruction.in_features

@@ -11,7 +11,7 @@ def log_standard_gaussian(x):
     :param x: point to evaluate
     :return: log N(x|0,I)
     """
-    return torch.sum(-0.5 * math.log(2 * math.pi) - x ** 2 / 2, dim=-1)
+    return torch.sum(-0.5 * math.log(2 * math.pi) - x**2 / 2, dim=-1)


 def log_gaussian(x, mu, log_var):

@@ -23,7 +23,11 @@ def log_gaussian(x, mu, log_var):
     :param log_var: log variance
     :return: log N(x|µ,σI)
     """
-    log_pdf = -0.5 * math.log(2 * math.pi) - log_var / 2 - (x - mu)**2 / (2 * torch.exp(log_var))
+    log_pdf = (
+        -0.5 * math.log(2 * math.pi)
+        - log_var / 2
+        - (x - mu) ** 2 / (2 * torch.exp(log_var))
+    )
     return torch.sum(log_pdf, dim=-1)
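
The reformatted expression is still the element-wise log density of a diagonal Gaussian, log N(x | mu, exp(log_var)), summed over the last dimension. A quick sanity check one could run against torch.distributions (the function is re-declared locally here; this snippet is not part of the commit):

    import math
    import torch

    def log_gaussian(x, mu, log_var):
        # Same formula as in the hunk above
        log_pdf = (
            -0.5 * math.log(2 * math.pi)
            - log_var / 2
            - (x - mu) ** 2 / (2 * torch.exp(log_var))
        )
        return torch.sum(log_pdf, dim=-1)

    x, mu, log_var = torch.randn(4, 3), torch.randn(4, 3), torch.randn(4, 3)
    reference = torch.distributions.Normal(mu, torch.exp(0.5 * log_var)).log_prob(x).sum(-1)
    print(torch.allclose(log_gaussian(x, mu, log_var), reference, atol=1e-6))  # True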

@@ -21,7 +21,8 @@ class Standardize(Module):
         mu: the learnable translation parameter μ.
         std: the learnable scale parameter σ.
     """
-    __constants__ = ['mu']
+
+    __constants__ = ["mu"]

     def __init__(self, in_features, bias=True, eps=1e-6):
         super(Standardize, self).__init__()

@@ -32,7 +33,7 @@ class Standardize(Module):
         if bias:
             self.mu = Parameter(torch.Tensor(in_features))
         else:
-            self.register_parameter('mu', None)
+            self.register_parameter("mu", None)
         self.reset_parameters()

     def reset_parameters(self):

@@ -47,6 +48,6 @@ class Standardize(Module):
         return x

     def extra_repr(self):
-        return 'in_features={}, out_features={}, bias={}'.format(
+        return "in_features={}, out_features={}, bias={}".format(
             self.in_features, self.out_features, self.mu is not None
         )

@@ -9,78 +9,106 @@ from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
 def build_network(net_name, ae_net=None):
     """Builds the neural network."""

-    implemented_networks = ('mnist_LeNet', 'mnist_DGM_M2', 'mnist_DGM_M1M2',
-                            'fmnist_LeNet', 'fmnist_DGM_M2', 'fmnist_DGM_M1M2',
-                            'cifar10_LeNet', 'cifar10_DGM_M2', 'cifar10_DGM_M1M2',
-                            'arrhythmia_mlp', 'cardio_mlp', 'satellite_mlp', 'satimage-2_mlp', 'shuttle_mlp',
-                            'thyroid_mlp',
-                            'arrhythmia_DGM_M2', 'cardio_DGM_M2', 'satellite_DGM_M2', 'satimage-2_DGM_M2',
-                            'shuttle_DGM_M2', 'thyroid_DGM_M2')
+    implemented_networks = (
+        "mnist_LeNet",
+        "mnist_DGM_M2",
+        "mnist_DGM_M1M2",
+        "fmnist_LeNet",
+        "fmnist_DGM_M2",
+        "fmnist_DGM_M1M2",
+        "cifar10_LeNet",
+        "cifar10_DGM_M2",
+        "cifar10_DGM_M1M2",
+        "arrhythmia_mlp",
+        "cardio_mlp",
+        "satellite_mlp",
+        "satimage-2_mlp",
+        "shuttle_mlp",
+        "thyroid_mlp",
+        "arrhythmia_DGM_M2",
+        "cardio_DGM_M2",
+        "satellite_DGM_M2",
+        "satimage-2_DGM_M2",
+        "shuttle_DGM_M2",
+        "thyroid_DGM_M2",
+    )
     assert net_name in implemented_networks

     net = None

-    if net_name == 'mnist_LeNet':
+    if net_name == "mnist_LeNet":
         net = MNIST_LeNet()

-    if net_name == 'mnist_DGM_M2':
-        net = DeepGenerativeModel([1*28*28, 2, 32, [128, 64]], classifier_net=MNIST_LeNet)
+    if net_name == "mnist_DGM_M2":
+        net = DeepGenerativeModel(
+            [1 * 28 * 28, 2, 32, [128, 64]], classifier_net=MNIST_LeNet
+        )

-    if net_name == 'mnist_DGM_M1M2':
-        net = StackedDeepGenerativeModel([1*28*28, 2, 32, [128, 64]], features=ae_net)
+    if net_name == "mnist_DGM_M1M2":
+        net = StackedDeepGenerativeModel(
+            [1 * 28 * 28, 2, 32, [128, 64]], features=ae_net
+        )

-    if net_name == 'fmnist_LeNet':
+    if net_name == "fmnist_LeNet":
         net = FashionMNIST_LeNet()

-    if net_name == 'fmnist_DGM_M2':
-        net = DeepGenerativeModel([1*28*28, 2, 64, [256, 128]], classifier_net=FashionMNIST_LeNet)
+    if net_name == "fmnist_DGM_M2":
+        net = DeepGenerativeModel(
+            [1 * 28 * 28, 2, 64, [256, 128]], classifier_net=FashionMNIST_LeNet
+        )

-    if net_name == 'fmnist_DGM_M1M2':
-        net = StackedDeepGenerativeModel([1*28*28, 2, 64, [256, 128]], features=ae_net)
+    if net_name == "fmnist_DGM_M1M2":
+        net = StackedDeepGenerativeModel(
+            [1 * 28 * 28, 2, 64, [256, 128]], features=ae_net
+        )

-    if net_name == 'cifar10_LeNet':
+    if net_name == "cifar10_LeNet":
         net = CIFAR10_LeNet()

-    if net_name == 'cifar10_DGM_M2':
-        net = DeepGenerativeModel([3*32*32, 2, 128, [512, 256]], classifier_net=CIFAR10_LeNet)
+    if net_name == "cifar10_DGM_M2":
+        net = DeepGenerativeModel(
+            [3 * 32 * 32, 2, 128, [512, 256]], classifier_net=CIFAR10_LeNet
+        )

-    if net_name == 'cifar10_DGM_M1M2':
-        net = StackedDeepGenerativeModel([3*32*32, 2, 128, [512, 256]], features=ae_net)
+    if net_name == "cifar10_DGM_M1M2":
+        net = StackedDeepGenerativeModel(
+            [3 * 32 * 32, 2, 128, [512, 256]], features=ae_net
+        )

-    if net_name == 'arrhythmia_mlp':
+    if net_name == "arrhythmia_mlp":
         net = MLP(x_dim=274, h_dims=[128, 64], rep_dim=32, bias=False)

-    if net_name == 'cardio_mlp':
+    if net_name == "cardio_mlp":
         net = MLP(x_dim=21, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'satellite_mlp':
+    if net_name == "satellite_mlp":
         net = MLP(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'satimage-2_mlp':
+    if net_name == "satimage-2_mlp":
         net = MLP(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'shuttle_mlp':
+    if net_name == "shuttle_mlp":
         net = MLP(x_dim=9, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'thyroid_mlp':
+    if net_name == "thyroid_mlp":
         net = MLP(x_dim=6, h_dims=[32, 16], rep_dim=4, bias=False)

-    if net_name == 'arrhythmia_DGM_M2':
+    if net_name == "arrhythmia_DGM_M2":
         net = DeepGenerativeModel([274, 2, 32, [128, 64]])

-    if net_name == 'cardio_DGM_M2':
+    if net_name == "cardio_DGM_M2":
         net = DeepGenerativeModel([21, 2, 8, [32, 16]])

-    if net_name == 'satellite_DGM_M2':
+    if net_name == "satellite_DGM_M2":
         net = DeepGenerativeModel([36, 2, 8, [32, 16]])

-    if net_name == 'satimage-2_DGM_M2':
+    if net_name == "satimage-2_DGM_M2":
         net = DeepGenerativeModel([36, 2, 8, [32, 16]])

-    if net_name == 'shuttle_DGM_M2':
+    if net_name == "shuttle_DGM_M2":
         net = DeepGenerativeModel([9, 2, 8, [32, 16]])

-    if net_name == 'thyroid_DGM_M2':
+    if net_name == "thyroid_DGM_M2":
         net = DeepGenerativeModel([6, 2, 4, [32, 16]])

     return net

@@ -89,50 +117,59 @@ def build_network(net_name, ae_net=None):
 def build_autoencoder(net_name):
     """Builds the corresponding autoencoder network."""

-    implemented_networks = ('mnist_LeNet', 'mnist_DGM_M1M2',
-                            'fmnist_LeNet', 'fmnist_DGM_M1M2',
-                            'cifar10_LeNet', 'cifar10_DGM_M1M2',
-                            'arrhythmia_mlp', 'cardio_mlp', 'satellite_mlp', 'satimage-2_mlp', 'shuttle_mlp',
-                            'thyroid_mlp')
+    implemented_networks = (
+        "mnist_LeNet",
+        "mnist_DGM_M1M2",
+        "fmnist_LeNet",
+        "fmnist_DGM_M1M2",
+        "cifar10_LeNet",
+        "cifar10_DGM_M1M2",
+        "arrhythmia_mlp",
+        "cardio_mlp",
+        "satellite_mlp",
+        "satimage-2_mlp",
+        "shuttle_mlp",
+        "thyroid_mlp",
+    )

     assert net_name in implemented_networks

     ae_net = None

-    if net_name == 'mnist_LeNet':
+    if net_name == "mnist_LeNet":
         ae_net = MNIST_LeNet_Autoencoder()

-    if net_name == 'mnist_DGM_M1M2':
-        ae_net = VariationalAutoencoder([1*28*28, 32, [128, 64]])
+    if net_name == "mnist_DGM_M1M2":
+        ae_net = VariationalAutoencoder([1 * 28 * 28, 32, [128, 64]])

-    if net_name == 'fmnist_LeNet':
+    if net_name == "fmnist_LeNet":
         ae_net = FashionMNIST_LeNet_Autoencoder()

-    if net_name == 'fmnist_DGM_M1M2':
-        ae_net = VariationalAutoencoder([1*28*28, 64, [256, 128]])
+    if net_name == "fmnist_DGM_M1M2":
+        ae_net = VariationalAutoencoder([1 * 28 * 28, 64, [256, 128]])

-    if net_name == 'cifar10_LeNet':
+    if net_name == "cifar10_LeNet":
         ae_net = CIFAR10_LeNet_Autoencoder()

-    if net_name == 'cifar10_DGM_M1M2':
-        ae_net = VariationalAutoencoder([3*32*32, 128, [512, 256]])
+    if net_name == "cifar10_DGM_M1M2":
+        ae_net = VariationalAutoencoder([3 * 32 * 32, 128, [512, 256]])

-    if net_name == 'arrhythmia_mlp':
+    if net_name == "arrhythmia_mlp":
         ae_net = MLP_Autoencoder(x_dim=274, h_dims=[128, 64], rep_dim=32, bias=False)

-    if net_name == 'cardio_mlp':
+    if net_name == "cardio_mlp":
         ae_net = MLP_Autoencoder(x_dim=21, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'satellite_mlp':
+    if net_name == "satellite_mlp":
         ae_net = MLP_Autoencoder(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'satimage-2_mlp':
+    if net_name == "satimage-2_mlp":
         ae_net = MLP_Autoencoder(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'shuttle_mlp':
+    if net_name == "shuttle_mlp":
         ae_net = MLP_Autoencoder(x_dim=9, h_dims=[32, 16], rep_dim=8, bias=False)

-    if net_name == 'thyroid_mlp':
+    if net_name == "thyroid_mlp":
         ae_net = MLP_Autoencoder(x_dim=6, h_dims=[32, 16], rep_dim=4, bias=False)

     return ae_net
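
For context, both factory functions are keyed on the names in implemented_networks, and build_autoencoder produces the M1 VAE that build_network consumes for the stacked M1+M2 models. A minimal usage sketch, assuming the package is importable as networks (as suggested by the __init__.py hunk at the top; this snippet is not part of the commit):

    from networks import build_network, build_autoencoder

    ae_net = build_autoencoder("mnist_DGM_M1M2")           # M1: VariationalAutoencoder
    net = build_network("mnist_DGM_M1M2", ae_net=ae_net)   # M1+M2: StackedDeepGenerativeModel

    print(type(ae_net).__name__, type(net).__name__)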

@@ -12,7 +12,10 @@ class MLP(BaseNet):
         self.rep_dim = rep_dim

         neurons = [x_dim, *h_dims]
-        layers = [Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias) for i in range(1, len(neurons))]
+        layers = [
+            Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias)
+            for i in range(1, len(neurons))
+        ]

         self.hidden = nn.ModuleList(layers)
         self.code = nn.Linear(h_dims[-1], rep_dim, bias=bias)

@@ -32,7 +35,10 @@ class MLP_Decoder(BaseNet):
         self.rep_dim = rep_dim

         neurons = [rep_dim, *h_dims]
-        layers = [Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias) for i in range(1, len(neurons))]
+        layers = [
+            Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias)
+            for i in range(1, len(neurons))
+        ]

         self.hidden = nn.ModuleList(layers)
         self.reconstruction = nn.Linear(h_dims[-1], x_dim, bias=bias)

@@ -22,7 +22,9 @@ class Encoder(nn.Module):

         [x_dim, h_dim, z_dim] = dims
         neurons = [x_dim, *h_dim]
-        linear_layers = [nn.Linear(neurons[i-1], neurons[i]) for i in range(1, len(neurons))]
+        linear_layers = [
+            nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))
+        ]

         self.hidden = nn.ModuleList(linear_layers)
         self.sample = sample_layer(h_dim[-1], z_dim)

@@ -48,7 +50,9 @@ class Decoder(nn.Module):

         [z_dim, h_dim, x_dim] = dims
         neurons = [z_dim, *h_dim]
-        linear_layers = [nn.Linear(neurons[i-1], neurons[i]) for i in range(1, len(neurons))]
+        linear_layers = [
+            nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))
+        ]

         self.hidden = nn.ModuleList(linear_layers)
         self.reconstruction = nn.Linear(h_dim[-1], x_dim)
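
The wrapped list comprehensions in MLP, MLP_Decoder, Encoder, and Decoder all follow the same pattern: consecutive entries of neurons give the in/out sizes of each hidden layer. A small standalone illustration with hypothetical dimensions (not taken from the diff):

    import torch.nn as nn

    x_dim, h_dim = 784, [256, 128]
    neurons = [x_dim, *h_dim]

    linear_layers = [
        nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))
    ]
    hidden = nn.ModuleList(linear_layers)

    print([(layer.in_features, layer.out_features) for layer in hidden])
    # [(784, 256), (256, 128)]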