split subter implementation (training + inference)
@@ -10,40 +10,51 @@
    };
  };

  outputs = { self, nixpkgs, flake-utils, poetry2nix }:
    flake-utils.lib.eachDefaultSystem (system:
  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      poetry2nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        # see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
        pkgs = import nixpkgs{
          inherit system;
          config.allowUnfree = true;
          config.cudaSupport = true;
        };
        pkgs = import nixpkgs {
          inherit system;
          config.allowUnfree = true;
          config.cudaSupport = true;
        };
        inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication;
      in
      {
        packages = {
          deepsad = mkPoetryApplication {
            projectDir = self;
            preferWheels = true;
            python = pkgs.python311;
          deepsad = mkPoetryApplication {
            projectDir = self;
            preferWheels = true;
            python = pkgs.python311;
          };
          default = self.packages.${system}.deepsad;
        };

        devShells.default = pkgs.mkShell {
          inputsFrom = [ self.packages.${system}.deepsad ];
          buildInputs = with pkgs.python311Packages; [
            torch-bin
            torchvision-bin
          ];
          #LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
          #pkgs.stdenv.cc.cc
          #];
          buildInputs = with pkgs.python311Packages; [
            torch-bin
            torchvision-bin
          ];
          #LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
          #pkgs.stdenv.cc.cc
          #];
        };

        devShells.poetry = pkgs.mkShell {
          packages = [ pkgs.poetry pkgs.python311 ];
          packages = [
            pkgs.poetry
            pkgs.python311
          ];
        };
      });
      }
    );
}

@@ -1,9 +1,10 @@
from .mnist import MNIST_Dataset
from .elpv import ELPV_Dataset
from .subter import SubTer_Dataset
from .fmnist import FashionMNIST_Dataset
from .cifar10 import CIFAR10_Dataset
from .elpv import ELPV_Dataset
from .fmnist import FashionMNIST_Dataset
from .mnist import MNIST_Dataset
from .odds import ODDSADDataset
from .subter import SubTer_Dataset
from .subtersplit import SubTerSplit_Dataset


def load_dataset(
@@ -24,6 +25,7 @@ def load_dataset(
        "mnist",
        "elpv",
        "subter",
        "subtersplit",
        "fmnist",
        "cifar10",
        "arrhythmia",
@@ -46,6 +48,15 @@ def load_dataset(
            inference=inference,
        )

    if dataset_name == "subtersplit":
        dataset = SubTerSplit_Dataset(
            root=data_path,
            ratio_known_normal=ratio_known_normal,
            ratio_known_outlier=ratio_known_outlier,
            ratio_pollution=ratio_pollution,
            inference=inference,
        )

    if dataset_name == "elpv":
        dataset = ELPV_Dataset(
            root=data_path,

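The hunk above only registers the new dataset name in the loader; a minimal usage sketch, not part of the commit, assuming the parameter names match those referenced in the hunk (dataset_name, data_path, the three ratios, inference) and that the .npy experiments live under an assumed data/subter directory:

# Sketch: loading the new split dataset through the dispatcher above.
from datasets.main import load_dataset

dataset = load_dataset(
    dataset_name="subtersplit",
    data_path="data/subter",
    ratio_known_normal=0.01,
    ratio_known_outlier=0.01,
    ratio_pollution=0.1,
)
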
@@ -1,22 +1,21 @@
from torch.utils.data import Subset
from PIL import Image
from torch.utils.data.dataset import ConcatDataset
from torchvision.datasets import VisionDataset
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
import logging
import random
from pathlib import Path
from typing import Callable, Optional

import logging
import numpy as np
import torch
import torchvision.transforms as transforms
import random
import numpy as np
from base.torchvision_dataset import TorchvisionDataset
from PIL import Image
from torch.utils.data import Subset
from torch.utils.data.dataset import ConcatDataset
from torchvision.datasets import VisionDataset

from pathlib import Path
from .preprocessing import create_semisupervised_setting


class SubTer_Dataset(TorchvisionDataset):

    def __init__(
        self,
        root: str,
@@ -31,6 +30,7 @@ class SubTer_Dataset(TorchvisionDataset):
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([0])
        self.outlier_classes = tuple([1])
        self.inference_set = None

        # MNIST preprocessing: feature scaling to [0, 1]
        # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
@@ -78,7 +78,6 @@ class SubTer_Dataset(TorchvisionDataset):


class SubTerTraining(VisionDataset):

    def __init__(
        self,
        root: str,
@@ -95,10 +94,18 @@ class SubTerTraining(VisionDataset):

        experiments_data = []
        experiments_targets = []
        validation_files = []
        experiment_files = []

        for experiment_file in Path(root).iterdir():
            if experiment_file.is_dir() and experiment_file.name == "validation":
                for validation_file in experiment_file.iterdir():
                    if validation_file.suffix != ".npy":
                        continue
                    validation_files.append(validation_file)
            if experiment_file.suffix != ".npy":
                continue
            experiment_files.append(experiment_file)
            experiment_data = np.load(experiment_file)
            # experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
            experiment_targets = (
@@ -109,6 +116,26 @@ class SubTerTraining(VisionDataset):
            experiments_data.append(experiment_data)
            experiments_targets.append(experiment_targets)

        filtered_validation_files = []
        for validation_file in validation_files:
            validation_file_name = validation_file.name
            file_exists_in_experiments = any(
                experiment_file.name == validation_file_name
                for experiment_file in experiment_files
            )
            if not file_exists_in_experiments:
                filtered_validation_files.append(validation_file)
        validation_files = filtered_validation_files

        logger = logging.getLogger()

        logger.info(
            f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
        )
        logger.info(
            f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
        )

        lidar_projections = np.concatenate(experiments_data)
        smoke_presence = np.concatenate(experiments_targets)

@@ -166,7 +193,6 @@ class SubTerTraining(VisionDataset):


class SubTerInference(VisionDataset):

    def __init__(
        self,
        root: str,

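SubTerTraining derives everything from file names, so a sketch of the directory layout it expects may help (names are illustrative; only the .npy suffix, the validation subdirectory and the substring "smoke" are significant in the code):

# Assumed on-disk layout for SubTerTraining:
#
#   <root>/
#       experiment_01_smoke.npy      # file name contains "smoke" -> targets are 1
#       experiment_02_clear.npy      # no "smoke" in the name     -> targets are 0
#       validation/
#           experiment_03_smoke.npy  # enumerated and logged only, not loaded here
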
Deep-SAD-PyTorch/src/datasets/subtersplit.py (new file, 263 lines)
@@ -0,0 +1,263 @@
import logging
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import torch
import torchvision.transforms as transforms
from base.torchvision_dataset import TorchvisionDataset
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset

from .preprocessing import create_semisupervised_setting


class SubTerSplit_Dataset(TorchvisionDataset):
    def __init__(
        self,
        root: str,
        ratio_known_normal: float = 0.0,
        ratio_known_outlier: float = 0.0,
        ratio_pollution: float = 0.0,
        inference: bool = False,
    ):
        super().__init__(root)

        # Define normal and outlier classes
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([0])
        self.outlier_classes = tuple([1])
        self.inference_set = None

        # MNIST preprocessing: feature scaling to [0, 1]
        # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
        transform = transforms.ToTensor()
        target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))

        if inference:
            self.inference_set = SubTerSplitInference(
                root=self.root,
                transform=transform,
            )
        else:
            # Get train set
            train_set = SubTerSplitTraining(
                root=self.root,
                transform=transform,
                target_transform=target_transform,
                train=True,
            )

            # Create semi-supervised setting
            idx, _, semi_targets = create_semisupervised_setting(
                train_set.targets.cpu().data.numpy(),
                self.normal_classes,
                self.outlier_classes,
                self.outlier_classes,
                ratio_known_normal,
                ratio_known_outlier,
                ratio_pollution,
            )
            train_set.semi_targets[idx] = torch.tensor(
                np.array(semi_targets, dtype=np.int8)
            )  # set respective semi-supervised labels

            # Subset train_set to semi-supervised setup
            self.train_set = Subset(train_set, idx)

            # Get test set
            self.test_set = SubTerSplitTraining(
                root=self.root,
                train=False,
                transform=transform,
                target_transform=target_transform,
            )


def split_array_into_subarrays(array, split_height, split_width):
    original_shape = array.shape
    height, width = original_shape[-2], original_shape[-1]
    assert height % split_height == 0, "The height is not divisible by the split_height"
    assert width % split_width == 0, "The width is not divisible by the split_width"
    num_splits_height = height // split_height
    num_splits_width = width // split_width
    reshaped_array = array.reshape(
        -1, num_splits_height, split_height, num_splits_width, split_width
    )
    transposed_array = reshaped_array.transpose(0, 1, 3, 2, 4)
    final_array = transposed_array.reshape(-1, split_height, split_width)
    return final_array


class SubTerSplitTraining(VisionDataset):
    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train=False,
        split=0.7,
        seed=0,
        height=16,
        width=256,
    ):
        super(SubTerSplitTraining, self).__init__(
            root, transforms, transform, target_transform
        )

        experiments_data = []
        experiments_targets = []
        validation_files = []
        experiment_files = []

        logger = logging.getLogger()

        for experiment_file in Path(root).iterdir():
            if experiment_file.is_dir() and experiment_file.name == "validation":
                for validation_file in experiment_file.iterdir():
                    if validation_file.suffix != ".npy":
                        continue
                    validation_files.append(validation_file)
            if experiment_file.suffix != ".npy":
                continue
            experiment_files.append(experiment_file)
            experiment_data = np.load(experiment_file)

            if (
                experiment_data.shape[1] % height != 0
                or experiment_data.shape[2] % width != 0
            ):
                logger.error(
                    f"Experiment {experiment_file.name} has shape {experiment_data.shape} which is not divisible by {height}x{width}"
                )
            experiment_data = split_array_into_subarrays(experiment_data, height, width)

            # experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
            experiment_targets = (
                np.ones(experiment_data.shape[0], dtype=np.int8)
                if "smoke" in experiment_file.name
                else np.zeros(experiment_data.shape[0], dtype=np.int8)
            )
            experiments_data.append(experiment_data)
            experiments_targets.append(experiment_targets)

        filtered_validation_files = []
        for validation_file in validation_files:
            validation_file_name = validation_file.name
            file_exists_in_experiments = any(
                experiment_file.name == validation_file_name
                for experiment_file in experiment_files
            )
            if not file_exists_in_experiments:
                filtered_validation_files.append(validation_file)
        validation_files = filtered_validation_files

        logger.info(
            f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
        )
        logger.info(
            f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
        )

        lidar_projections = np.concatenate(experiments_data)
        smoke_presence = np.concatenate(experiments_targets)

        np.random.seed(seed)

        shuffled_indices = np.random.permutation(lidar_projections.shape[0])
        shuffled_lidar_projections = lidar_projections[shuffled_indices]
        shuffled_smoke_presence = smoke_presence[shuffled_indices]

        split_idx = int(split * shuffled_lidar_projections.shape[0])

        if train:
            self.data = shuffled_lidar_projections[:split_idx]
            self.targets = shuffled_smoke_presence[:split_idx]

        else:
            self.data = shuffled_lidar_projections[split_idx:]
            self.targets = shuffled_smoke_presence[split_idx:]

        self.data = np.nan_to_num(self.data)

        self.data = torch.tensor(self.data)
        self.targets = torch.tensor(self.targets, dtype=torch.int8)

        self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
"""Override the original method of the MNIST class.
|
||||
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, semi_target, index)
        """
        img, target, semi_target = (
            self.data[index],
            int(self.targets[index]),
            int(self.semi_targets[index]),
        )

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="F")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, semi_target, index


class SubTerSplitInference(VisionDataset):
    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
    ):
        super(SubTerSplitInference, self).__init__(root, transforms, transform)
        logger = logging.getLogger()

        self.experiment_file_path = Path(root)

        if not self.experiment_file_path.is_file():
            logger.error(
                "For inference the data path has to be a single experiment file!"
            )
            raise Exception("Inference data is not a loadable file!")

        self.data = np.load(self.experiment_file_path)
        self.data = split_array_into_subarrays(self.data, 16, 256)
        self.data = np.nan_to_num(self.data)
        self.data = torch.tensor(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
"""Override the original method of the MNIST class.
|
||||
        Args:
            index (int): Index

        Returns:
            tuple: (image, index)
        """
        img = self.data[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="F")

        if self.transform is not None:
            img = self.transform(img)

        return img, index

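A quick sanity check of the tiling helper above (a sketch, not part of the commit; the 32x512 input shape is illustrative and the import assumes the module is reachable as datasets.subtersplit):

# Sketch: 100 full-resolution 32x512 projections are tiled into 16x256
# patches, giving 4 patches per frame and 400 patches overall.
import numpy as np

from datasets.subtersplit import split_array_into_subarrays

frames = np.random.rand(100, 32, 512).astype(np.float32)
patches = split_array_into_subarrays(frames, split_height=16, split_width=256)
print(patches.shape)  # (400, 16, 256)
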
@@ -1,14 +1,14 @@
import click
import torch
import logging
import random
import numpy as np
from pathlib import Path

import click
import numpy as np
import torch
from datasets.main import load_dataset
from DeepSAD import DeepSAD
from utils.config import Config
from utils.visualization.plot_images_grid import plot_images_grid
from DeepSAD import DeepSAD
from datasets.main import load_dataset


################################################################################
@@ -31,6 +31,7 @@ from datasets.main import load_dataset
        "mnist",
        "elpv",
        "subter",
        "subtersplit",
        "fmnist",
        "cifar10",
        "arrhythmia",
@@ -49,6 +50,7 @@ from datasets.main import load_dataset
        "mnist_LeNet",
        "elpv_LeNet",
        "subter_LeNet",
        "subter_LeNet_Split",
        "fmnist_LeNet",
        "cifar10_LeNet",
        "arrhythmia_mlp",
@@ -315,7 +317,6 @@ def main(
    logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)

    if action == "train":

        # Load data
        dataset = load_dataset(
            dataset_name,
@@ -413,7 +414,6 @@ def main(
        ]  # from lowest to highest score

        if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):

        if dataset_name in ("mnist", "fmnist", "elpv"):
            X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
            X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(

@@ -1,11 +1,12 @@
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Autoencoder
from .mlp import MLP, MLP_Autoencoder
from .vae import VariationalAutoencoder
from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
from .mlp import MLP, MLP_Autoencoder
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
from .vae import VariationalAutoencoder


def build_network(net_name, ae_net=None):
@@ -15,6 +16,7 @@ def build_network(net_name, ae_net=None):
        "mnist_LeNet",
        "elpv_LeNet",
        "subter_LeNet",
        "subter_LeNet_Split",
        "mnist_DGM_M2",
        "mnist_DGM_M1M2",
        "fmnist_LeNet",
@@ -46,6 +48,9 @@ def build_network(net_name, ae_net=None):
    if net_name == "subter_LeNet":
        net = SubTer_LeNet()

    if net_name == "subter_LeNet_Split":
        net = SubTer_LeNet_Split()

    if net_name == "elpv_LeNet":
        net = ELPV_LeNet()

@@ -130,6 +135,7 @@ def build_autoencoder(net_name):
    implemented_networks = (
        "elpv_LeNet",
        "subter_LeNet",
        "subter_LeNet_Split",
        "mnist_LeNet",
        "mnist_DGM_M1M2",
        "fmnist_LeNet",
@@ -154,6 +160,9 @@ def build_autoencoder(net_name):
    if net_name == "subter_LeNet":
        ae_net = SubTer_LeNet_Autoencoder()

    if net_name == "subter_LeNet_Split":
        ae_net = SubTer_LeNet_Split_Autoencoder()

    if net_name == "elpv_LeNet":
        ae_net = ELPV_LeNet_Autoencoder()

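A minimal sketch of how the new network name is resolved by the builders above (assumes the code runs from Deep-SAD-PyTorch/src so that the networks package is importable):

# Sketch: resolving the new net_name to the split encoder and its autoencoder.
from networks.main import build_autoencoder, build_network

net = build_network("subter_LeNet_Split")         # encoder with rep_dim=256
ae_net = build_autoencoder("subter_LeNet_Split")  # autoencoder counterpart
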
Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from base.base_net import BaseNet


class SubTer_LeNet_Split(BaseNet):
    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(4 * 64 * 4, self.rep_dim, bias=False)

    def forward(self, x):
        x = x.view(-1, 1, 16, 256)
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2(x)))
        x = x.view(int(x.size(0)), -1)
        x = self.fc1(x)
        return x


class SubTer_LeNet_Split_Decoder(BaseNet):
    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim

        # Decoder network
        self.fc3 = nn.Linear(self.rep_dim, 4 * 64 * 4, bias=False)
        self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.deconv1 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=2)
        self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.deconv2 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)

    def forward(self, x):
        x = self.fc3(x)
        x = x.view(int(x.size(0)), 4, 4, 64)
        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)
        x = self.deconv1(x)
        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)
        x = self.deconv2(x)
        x = torch.sigmoid(x)
        return x


class SubTer_LeNet_Split_Autoencoder(BaseNet):
    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim
        self.encoder = SubTer_LeNet_Split(rep_dim=rep_dim)
        self.decoder = SubTer_LeNet_Split_Decoder(rep_dim=rep_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

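For orientation, a minimal shape check of the new network (a sketch, not part of the commit; it assumes single-channel 16x256 patches, matching the x.view(-1, 1, 16, 256) call in forward, and that the classes are importable as networks.subter_LeNet_Split):

# Sketch: verify encoder and autoencoder output shapes for a batch of 8 patches.
import torch

from networks.subter_LeNet_Split import (
    SubTer_LeNet_Split,
    SubTer_LeNet_Split_Autoencoder,
)

x = torch.randn(8, 1, 16, 256)
print(SubTer_LeNet_Split()(x).shape)              # torch.Size([8, 256])
print(SubTer_LeNet_Split_Autoencoder()(x).shape)  # torch.Size([8, 1, 16, 256])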