black formatted files before changes

This commit is contained in:
Jan Kowalczyk
2024-06-28 11:36:46 +02:00
parent d33c6b1e16
commit 71f9662022
40 changed files with 2938 additions and 1260 deletions

View File

@@ -1,10 +1,22 @@
from .main import build_network, build_autoencoder
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Decoder, MNIST_LeNet_Autoencoder
from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Decoder, FashionMNIST_LeNet_Autoencoder
from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Decoder, CIFAR10_LeNet_Autoencoder
from .fmnist_LeNet import (
FashionMNIST_LeNet,
FashionMNIST_LeNet_Decoder,
FashionMNIST_LeNet_Autoencoder,
)
from .cifar10_LeNet import (
CIFAR10_LeNet,
CIFAR10_LeNet_Decoder,
CIFAR10_LeNet_Autoencoder,
)
from .mlp import MLP, MLP_Decoder, MLP_Autoencoder
from .layers.stochastic import GaussianSample
from .layers.standard import Standardize
from .inference.distributions import log_standard_gaussian, log_gaussian, log_standard_categorical
from .inference.distributions import (
log_standard_gaussian,
log_gaussian,
log_standard_categorical,
)
from .vae import VariationalAutoencoder, Encoder, Decoder
from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel

View File

@@ -41,17 +41,27 @@ class CIFAR10_LeNet_Decoder(BaseNet):
self.rep_dim = rep_dim
self.deconv1 = nn.ConvTranspose2d(int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2)
nn.init.xavier_uniform_(self.deconv1.weight, gain=nn.init.calculate_gain('leaky_relu'))
self.deconv1 = nn.ConvTranspose2d(
int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2
)
nn.init.xavier_uniform_(
self.deconv1.weight, gain=nn.init.calculate_gain("leaky_relu")
)
self.bn2d4 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
self.deconv2 = nn.ConvTranspose2d(128, 64, 5, bias=False, padding=2)
nn.init.xavier_uniform_(self.deconv2.weight, gain=nn.init.calculate_gain('leaky_relu'))
nn.init.xavier_uniform_(
self.deconv2.weight, gain=nn.init.calculate_gain("leaky_relu")
)
self.bn2d5 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
self.deconv3 = nn.ConvTranspose2d(64, 32, 5, bias=False, padding=2)
nn.init.xavier_uniform_(self.deconv3.weight, gain=nn.init.calculate_gain('leaky_relu'))
nn.init.xavier_uniform_(
self.deconv3.weight, gain=nn.init.calculate_gain("leaky_relu")
)
self.bn2d6 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
self.deconv4 = nn.ConvTranspose2d(32, 3, 5, bias=False, padding=2)
nn.init.xavier_uniform_(self.deconv4.weight, gain=nn.init.calculate_gain('leaky_relu'))
nn.init.xavier_uniform_(
self.deconv4.weight, gain=nn.init.calculate_gain("leaky_relu")
)
def forward(self, x):
x = x.view(int(x.size(0)), int(self.rep_dim / (4 * 4)), 4, 4)

View File

@@ -97,7 +97,9 @@ class StackedDeepGenerativeModel(DeepGenerativeModel):
:param features: a pre-trained M1 model of class 'VariationalAutoencoder' trained on the same dataset.
"""
[x_dim, y_dim, z_dim, h_dim] = dims
super(StackedDeepGenerativeModel, self).__init__([features.z_dim, y_dim, z_dim, h_dim])
super(StackedDeepGenerativeModel, self).__init__(
[features.z_dim, y_dim, z_dim, h_dim]
)
# Be sure to reconstruct with the same dimensions
in_features = self.decoder.reconstruction.in_features

View File

@@ -11,7 +11,7 @@ def log_standard_gaussian(x):
:param x: point to evaluate
:return: log N(x|0,I)
"""
return torch.sum(-0.5 * math.log(2 * math.pi) - x ** 2 / 2, dim=-1)
return torch.sum(-0.5 * math.log(2 * math.pi) - x**2 / 2, dim=-1)
def log_gaussian(x, mu, log_var):
@@ -23,7 +23,11 @@ def log_gaussian(x, mu, log_var):
:param log_var: log variance
:return: log N(x|µ,σI)
"""
log_pdf = -0.5 * math.log(2 * math.pi) - log_var / 2 - (x - mu)**2 / (2 * torch.exp(log_var))
log_pdf = (
-0.5 * math.log(2 * math.pi)
- log_var / 2
- (x - mu) ** 2 / (2 * torch.exp(log_var))
)
return torch.sum(log_pdf, dim=-1)

View File

@@ -21,7 +21,8 @@ class Standardize(Module):
mu: the learnable translation parameter μ.
std: the learnable scale parameter σ.
"""
__constants__ = ['mu']
__constants__ = ["mu"]
def __init__(self, in_features, bias=True, eps=1e-6):
super(Standardize, self).__init__()
@@ -32,7 +33,7 @@ class Standardize(Module):
if bias:
self.mu = Parameter(torch.Tensor(in_features))
else:
self.register_parameter('mu', None)
self.register_parameter("mu", None)
self.reset_parameters()
def reset_parameters(self):
@@ -47,6 +48,6 @@ class Standardize(Module):
return x
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
return "in_features={}, out_features={}, bias={}".format(
self.in_features, self.out_features, self.mu is not None
)

View File

@@ -9,78 +9,106 @@ from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
def build_network(net_name, ae_net=None):
"""Builds the neural network."""
implemented_networks = ('mnist_LeNet', 'mnist_DGM_M2', 'mnist_DGM_M1M2',
'fmnist_LeNet', 'fmnist_DGM_M2', 'fmnist_DGM_M1M2',
'cifar10_LeNet', 'cifar10_DGM_M2', 'cifar10_DGM_M1M2',
'arrhythmia_mlp', 'cardio_mlp', 'satellite_mlp', 'satimage-2_mlp', 'shuttle_mlp',
'thyroid_mlp',
'arrhythmia_DGM_M2', 'cardio_DGM_M2', 'satellite_DGM_M2', 'satimage-2_DGM_M2',
'shuttle_DGM_M2', 'thyroid_DGM_M2')
implemented_networks = (
"mnist_LeNet",
"mnist_DGM_M2",
"mnist_DGM_M1M2",
"fmnist_LeNet",
"fmnist_DGM_M2",
"fmnist_DGM_M1M2",
"cifar10_LeNet",
"cifar10_DGM_M2",
"cifar10_DGM_M1M2",
"arrhythmia_mlp",
"cardio_mlp",
"satellite_mlp",
"satimage-2_mlp",
"shuttle_mlp",
"thyroid_mlp",
"arrhythmia_DGM_M2",
"cardio_DGM_M2",
"satellite_DGM_M2",
"satimage-2_DGM_M2",
"shuttle_DGM_M2",
"thyroid_DGM_M2",
)
assert net_name in implemented_networks
net = None
if net_name == 'mnist_LeNet':
if net_name == "mnist_LeNet":
net = MNIST_LeNet()
if net_name == 'mnist_DGM_M2':
net = DeepGenerativeModel([1*28*28, 2, 32, [128, 64]], classifier_net=MNIST_LeNet)
if net_name == "mnist_DGM_M2":
net = DeepGenerativeModel(
[1 * 28 * 28, 2, 32, [128, 64]], classifier_net=MNIST_LeNet
)
if net_name == 'mnist_DGM_M1M2':
net = StackedDeepGenerativeModel([1*28*28, 2, 32, [128, 64]], features=ae_net)
if net_name == "mnist_DGM_M1M2":
net = StackedDeepGenerativeModel(
[1 * 28 * 28, 2, 32, [128, 64]], features=ae_net
)
if net_name == 'fmnist_LeNet':
if net_name == "fmnist_LeNet":
net = FashionMNIST_LeNet()
if net_name == 'fmnist_DGM_M2':
net = DeepGenerativeModel([1*28*28, 2, 64, [256, 128]], classifier_net=FashionMNIST_LeNet)
if net_name == "fmnist_DGM_M2":
net = DeepGenerativeModel(
[1 * 28 * 28, 2, 64, [256, 128]], classifier_net=FashionMNIST_LeNet
)
if net_name == 'fmnist_DGM_M1M2':
net = StackedDeepGenerativeModel([1*28*28, 2, 64, [256, 128]], features=ae_net)
if net_name == "fmnist_DGM_M1M2":
net = StackedDeepGenerativeModel(
[1 * 28 * 28, 2, 64, [256, 128]], features=ae_net
)
if net_name == 'cifar10_LeNet':
if net_name == "cifar10_LeNet":
net = CIFAR10_LeNet()
if net_name == 'cifar10_DGM_M2':
net = DeepGenerativeModel([3*32*32, 2, 128, [512, 256]], classifier_net=CIFAR10_LeNet)
if net_name == "cifar10_DGM_M2":
net = DeepGenerativeModel(
[3 * 32 * 32, 2, 128, [512, 256]], classifier_net=CIFAR10_LeNet
)
if net_name == 'cifar10_DGM_M1M2':
net = StackedDeepGenerativeModel([3*32*32, 2, 128, [512, 256]], features=ae_net)
if net_name == "cifar10_DGM_M1M2":
net = StackedDeepGenerativeModel(
[3 * 32 * 32, 2, 128, [512, 256]], features=ae_net
)
if net_name == 'arrhythmia_mlp':
if net_name == "arrhythmia_mlp":
net = MLP(x_dim=274, h_dims=[128, 64], rep_dim=32, bias=False)
if net_name == 'cardio_mlp':
if net_name == "cardio_mlp":
net = MLP(x_dim=21, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'satellite_mlp':
if net_name == "satellite_mlp":
net = MLP(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'satimage-2_mlp':
if net_name == "satimage-2_mlp":
net = MLP(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'shuttle_mlp':
if net_name == "shuttle_mlp":
net = MLP(x_dim=9, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'thyroid_mlp':
if net_name == "thyroid_mlp":
net = MLP(x_dim=6, h_dims=[32, 16], rep_dim=4, bias=False)
if net_name == 'arrhythmia_DGM_M2':
if net_name == "arrhythmia_DGM_M2":
net = DeepGenerativeModel([274, 2, 32, [128, 64]])
if net_name == 'cardio_DGM_M2':
if net_name == "cardio_DGM_M2":
net = DeepGenerativeModel([21, 2, 8, [32, 16]])
if net_name == 'satellite_DGM_M2':
if net_name == "satellite_DGM_M2":
net = DeepGenerativeModel([36, 2, 8, [32, 16]])
if net_name == 'satimage-2_DGM_M2':
if net_name == "satimage-2_DGM_M2":
net = DeepGenerativeModel([36, 2, 8, [32, 16]])
if net_name == 'shuttle_DGM_M2':
if net_name == "shuttle_DGM_M2":
net = DeepGenerativeModel([9, 2, 8, [32, 16]])
if net_name == 'thyroid_DGM_M2':
if net_name == "thyroid_DGM_M2":
net = DeepGenerativeModel([6, 2, 4, [32, 16]])
return net
@@ -89,50 +117,59 @@ def build_network(net_name, ae_net=None):
def build_autoencoder(net_name):
"""Builds the corresponding autoencoder network."""
implemented_networks = ('mnist_LeNet', 'mnist_DGM_M1M2',
'fmnist_LeNet', 'fmnist_DGM_M1M2',
'cifar10_LeNet', 'cifar10_DGM_M1M2',
'arrhythmia_mlp', 'cardio_mlp', 'satellite_mlp', 'satimage-2_mlp', 'shuttle_mlp',
'thyroid_mlp')
implemented_networks = (
"mnist_LeNet",
"mnist_DGM_M1M2",
"fmnist_LeNet",
"fmnist_DGM_M1M2",
"cifar10_LeNet",
"cifar10_DGM_M1M2",
"arrhythmia_mlp",
"cardio_mlp",
"satellite_mlp",
"satimage-2_mlp",
"shuttle_mlp",
"thyroid_mlp",
)
assert net_name in implemented_networks
ae_net = None
if net_name == 'mnist_LeNet':
if net_name == "mnist_LeNet":
ae_net = MNIST_LeNet_Autoencoder()
if net_name == 'mnist_DGM_M1M2':
ae_net = VariationalAutoencoder([1*28*28, 32, [128, 64]])
if net_name == "mnist_DGM_M1M2":
ae_net = VariationalAutoencoder([1 * 28 * 28, 32, [128, 64]])
if net_name == 'fmnist_LeNet':
if net_name == "fmnist_LeNet":
ae_net = FashionMNIST_LeNet_Autoencoder()
if net_name == 'fmnist_DGM_M1M2':
ae_net = VariationalAutoencoder([1*28*28, 64, [256, 128]])
if net_name == "fmnist_DGM_M1M2":
ae_net = VariationalAutoencoder([1 * 28 * 28, 64, [256, 128]])
if net_name == 'cifar10_LeNet':
if net_name == "cifar10_LeNet":
ae_net = CIFAR10_LeNet_Autoencoder()
if net_name == 'cifar10_DGM_M1M2':
ae_net = VariationalAutoencoder([3*32*32, 128, [512, 256]])
if net_name == "cifar10_DGM_M1M2":
ae_net = VariationalAutoencoder([3 * 32 * 32, 128, [512, 256]])
if net_name == 'arrhythmia_mlp':
if net_name == "arrhythmia_mlp":
ae_net = MLP_Autoencoder(x_dim=274, h_dims=[128, 64], rep_dim=32, bias=False)
if net_name == 'cardio_mlp':
if net_name == "cardio_mlp":
ae_net = MLP_Autoencoder(x_dim=21, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'satellite_mlp':
if net_name == "satellite_mlp":
ae_net = MLP_Autoencoder(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'satimage-2_mlp':
if net_name == "satimage-2_mlp":
ae_net = MLP_Autoencoder(x_dim=36, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'shuttle_mlp':
if net_name == "shuttle_mlp":
ae_net = MLP_Autoencoder(x_dim=9, h_dims=[32, 16], rep_dim=8, bias=False)
if net_name == 'thyroid_mlp':
if net_name == "thyroid_mlp":
ae_net = MLP_Autoencoder(x_dim=6, h_dims=[32, 16], rep_dim=4, bias=False)
return ae_net

View File

@@ -12,7 +12,10 @@ class MLP(BaseNet):
self.rep_dim = rep_dim
neurons = [x_dim, *h_dims]
layers = [Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias) for i in range(1, len(neurons))]
layers = [
Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias)
for i in range(1, len(neurons))
]
self.hidden = nn.ModuleList(layers)
self.code = nn.Linear(h_dims[-1], rep_dim, bias=bias)
@@ -32,7 +35,10 @@ class MLP_Decoder(BaseNet):
self.rep_dim = rep_dim
neurons = [rep_dim, *h_dims]
layers = [Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias) for i in range(1, len(neurons))]
layers = [
Linear_BN_leakyReLU(neurons[i - 1], neurons[i], bias=bias)
for i in range(1, len(neurons))
]
self.hidden = nn.ModuleList(layers)
self.reconstruction = nn.Linear(h_dims[-1], x_dim, bias=bias)

View File

@@ -22,7 +22,9 @@ class Encoder(nn.Module):
[x_dim, h_dim, z_dim] = dims
neurons = [x_dim, *h_dim]
linear_layers = [nn.Linear(neurons[i-1], neurons[i]) for i in range(1, len(neurons))]
linear_layers = [
nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))
]
self.hidden = nn.ModuleList(linear_layers)
self.sample = sample_layer(h_dim[-1], z_dim)
@@ -48,7 +50,9 @@ class Decoder(nn.Module):
[z_dim, h_dim, x_dim] = dims
neurons = [z_dim, *h_dim]
linear_layers = [nn.Linear(neurons[i-1], neurons[i]) for i in range(1, len(neurons))]
linear_layers = [
nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))
]
self.hidden = nn.ModuleList(linear_layers)
self.reconstruction = nn.Linear(h_dim[-1], x_dim)