initial work for elpv and subter datasets
elpv serves as the example dataset/implementation; subter is the final dataset
Deep-SAD-PyTorch/src/datasets/elpv.py (new file, 163 lines)
@@ -0,0 +1,163 @@
from torch.utils.data import Subset
from PIL import Image
from torch.utils.data.dataset import ConcatDataset
from torchvision.datasets import VisionDataset
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
from typing import Callable, Optional

import torch
import torchvision.transforms as transforms
import random
import numpy as np

import importlib.util
import sys
from pathlib import Path


def load_function_from_path(root_path, subfolder, module_name, function_name):
    """Dynamically import <root_path>/<subfolder>/<module_name>.py and return one of its functions."""
    root_path = Path(root_path)
    module_path = root_path / subfolder / f"{module_name}.py"

    if not module_path.exists():
        raise FileNotFoundError(f"The module {module_path} does not exist.")

    spec = importlib.util.spec_from_file_location(module_name, str(module_path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    if not hasattr(module, function_name):
        raise AttributeError(
            f"The function {function_name} does not exist in the module {module_name}."
        )

    return getattr(module, function_name)
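For illustration, a minimal usage sketch of the helper above, assuming the ELPV reader ships alongside the downloaded data rather than with this repo; the /data/elpv path and the utils/elpv_reader.py layout are assumptions mirroring the call made in MyELPV below:

# Hypothetical layout: /data/elpv/utils/elpv_reader.py defines load_dataset()
load_dataset = load_function_from_path("/data/elpv", "utils", "elpv_reader", "load_dataset")
images, proba, types = load_dataset()  # numpy arrays: cell images, defect probabilities, cell types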

class ELPV_Dataset(TorchvisionDataset):

    def __init__(
        self,
        root: str,
        ratio_known_normal: float = 0.0,
        ratio_known_outlier: float = 0.0,
        ratio_pollution: float = 0.0,
    ):
        super().__init__(root)

        # Define normal and outlier classes
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([0])
        self.outlier_classes = tuple([1])

        # MNIST preprocessing: feature scaling to [0, 1]
        # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
        transform = transforms.ToTensor()
        target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))

        # Get train set
        train_set = MyELPV(
            root=self.root,
            transform=transform,
            target_transform=target_transform,
            train=True,
        )

        # Create semi-supervised setting
        # (the outlier classes also serve as the known outlier classes here)
        idx, _, semi_targets = create_semisupervised_setting(
            train_set.targets.cpu().data.numpy(),
            self.normal_classes,
            self.outlier_classes,
            self.outlier_classes,
            ratio_known_normal,
            ratio_known_outlier,
            ratio_pollution,
        )
        train_set.semi_targets[idx] = torch.tensor(
            semi_targets
        )  # set respective semi-supervised labels

        # Subset train_set to semi-supervised setup
        self.train_set = Subset(train_set, idx)

        # Get test set
        self.test_set = MyELPV(
            root=self.root,
            train=False,
            transform=transform,
            target_transform=target_transform,
        )


class MyELPV(VisionDataset):

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train=False,
        split=0.7,
        seed=0,
    ):
        super(MyELPV, self).__init__(root, transforms, transform, target_transform)

        load_dataset = load_function_from_path(
            root, "utils", "elpv_reader", "load_dataset"
        )

        images, proba, _ = load_dataset()

        # Shuffle deterministically so the train/test split is reproducible
        np.random.seed(seed)

        shuffled_indices = np.random.permutation(images.shape[0])
        shuffled_data = images[shuffled_indices]
        shuffled_proba = proba[shuffled_indices]

        split_idx = int(split * shuffled_data.shape[0])

        if train:
            self.data = shuffled_data[:split_idx]
            self.targets = shuffled_proba[:split_idx]
        else:
            self.data = shuffled_data[split_idx:]
            self.targets = shuffled_proba[split_idx:]

        self.data = torch.tensor(self.data)
        self.targets[self.targets > 0] = 1  # binarize: any nonzero defect probability counts as an outlier
        self.targets = torch.tensor(self.targets, dtype=torch.int64)

        self.semi_targets = torch.zeros_like(self.targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Override the original method of the VisionDataset class.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, semi_target, index)
        """
        img, target, semi_target = (
            self.data[index],
            int(self.targets[index]),
            int(self.semi_targets[index]),
        )

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, semi_target, index
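For orientation, a sketch of how the dataset class is consumed downstream; loaders() comes from the repo's base dataset class, and the path, ratios, and batch size here are placeholder assumptions:

dataset = ELPV_Dataset(root="/data/elpv", ratio_known_outlier=0.05, ratio_pollution=0.1)
train_loader, test_loader = dataset.loaders(batch_size=128)
inputs, targets, semi_targets, idx = next(iter(train_loader))  # matches __getitem__ above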
Deep-SAD-PyTorch/src/datasets/main.py

@@ -1,4 +1,6 @@
 from .mnist import MNIST_Dataset
+from .elpv import ELPV_Dataset
+from .subter import SubTer_Dataset
 from .fmnist import FashionMNIST_Dataset
 from .cifar10 import CIFAR10_Dataset
 from .odds import ODDSADDataset
@@ -19,6 +21,8 @@ def load_dataset(
     implemented_datasets = (
         "mnist",
+        "elpv",
+        "subter",
         "fmnist",
         "cifar10",
         "arrhythmia",
@@ -32,6 +36,22 @@ def load_dataset(
     dataset = None

+    if dataset_name == "subter":
+        dataset = SubTer_Dataset(
+            root=data_path,
+            ratio_known_normal=ratio_known_normal,
+            ratio_known_outlier=ratio_known_outlier,
+            ratio_pollution=ratio_pollution,
+        )
+
+    if dataset_name == "elpv":
+        dataset = ELPV_Dataset(
+            root=data_path,
+            ratio_known_normal=ratio_known_normal,
+            ratio_known_outlier=ratio_known_outlier,
+            ratio_pollution=ratio_pollution,
+        )
+
     if dataset_name == "mnist":
         dataset = MNIST_Dataset(
             root=data_path,
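A hedged example of routing through this factory; the data path is a placeholder, and the call is a sketch only:

# Hypothetical call; the full Deep-SAD load_dataset signature may require
# further arguments (e.g. normal_class) not visible in this hunk.
dataset = load_dataset(
    "subter",
    data_path="/data/subter",
    ratio_known_normal=0.0,
    ratio_known_outlier=0.01,
    ratio_pollution=0.1,
)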
Deep-SAD-PyTorch/src/datasets/subter.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from torch.utils.data import Subset
from PIL import Image
from torch.utils.data.dataset import ConcatDataset
from torchvision.datasets import VisionDataset
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
from typing import Callable, Optional

import torch
import torchvision.transforms as transforms
import random
import numpy as np

from pathlib import Path


class SubTer_Dataset(TorchvisionDataset):

    def __init__(
        self,
        root: str,
        ratio_known_normal: float = 0.0,
        ratio_known_outlier: float = 0.0,
        ratio_pollution: float = 0.0,
    ):
        super().__init__(root)

        # Define normal and outlier classes
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([0])
        self.outlier_classes = tuple([1])

        # MNIST preprocessing: feature scaling to [0, 1]
        # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for subter
        transform = transforms.ToTensor()
        target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))

        # Get train set
        train_set = MySubTer(
            root=self.root,
            transform=transform,
            target_transform=target_transform,
            train=True,
        )

        # Create semi-supervised setting
        idx, _, semi_targets = create_semisupervised_setting(
            train_set.targets.cpu().data.numpy(),
            self.normal_classes,
            self.outlier_classes,
            self.outlier_classes,
            ratio_known_normal,
            ratio_known_outlier,
            ratio_pollution,
        )
        train_set.semi_targets[idx] = torch.tensor(
            np.array(semi_targets, dtype=np.int8)
        )  # set respective semi-supervised labels

        # Subset train_set to semi-supervised setup
        self.train_set = Subset(train_set, idx)

        # Get test set
        self.test_set = MySubTer(
            root=self.root,
            train=False,
            transform=transform,
            target_transform=target_transform,
        )


class MySubTer(VisionDataset):

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train=False,
        split=0.7,
        seed=0,
    ):
        super(MySubTer, self).__init__(root, transforms, transform, target_transform)

        experiments_data = []
        experiments_targets = []

        # One .npy file per recorded experiment; recordings with "smoke" in the
        # file name are labeled anomalous (1), all others normal (0)
        for experiment_file in Path(root).iterdir():
            if experiment_file.suffix != ".npy":
                continue
            experiment_data = np.load(experiment_file)
            # experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
            experiment_targets = (
                np.ones(experiment_data.shape[0], dtype=np.int8)
                if "smoke" in experiment_file.name
                else np.zeros(experiment_data.shape[0], dtype=np.int8)
            )
            experiments_data.append(experiment_data)
            experiments_targets.append(experiment_targets)

        lidar_projections = np.concatenate(experiments_data)
        smoke_presence = np.concatenate(experiments_targets)

        # Shuffle deterministically so the train/test split is reproducible
        np.random.seed(seed)

        shuffled_indices = np.random.permutation(lidar_projections.shape[0])
        shuffled_lidar_projections = lidar_projections[shuffled_indices]
        shuffled_smoke_presence = smoke_presence[shuffled_indices]

        split_idx = int(split * shuffled_lidar_projections.shape[0])

        if train:
            self.data = shuffled_lidar_projections[:split_idx]
            self.targets = shuffled_smoke_presence[:split_idx]
        else:
            self.data = shuffled_lidar_projections[split_idx:]
            self.targets = shuffled_smoke_presence[split_idx:]

        self.data = np.nan_to_num(self.data)  # replace NaNs (e.g. missing lidar returns) with 0

        self.data = torch.tensor(self.data)
        self.targets = torch.tensor(self.targets, dtype=torch.int8)

        self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Override the original method of the VisionDataset class.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, semi_target, index)
        """
        img, target, semi_target = (
            self.data[index],
            int(self.targets[index]),
            int(self.semi_targets[index]),
        )

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (mode "F": 32-bit float pixels)
        img = Image.fromarray(img.numpy(), mode="F")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, semi_target, index
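The loader labels every frame of a recording by the recording's file name; a sketch of the assumed directory layout (file names and array shapes are illustrative, inferred from the loop above and the 32x2048 input of SubTer_LeNet below):

# /data/subter/
#     tunnel_run_01.npy        -> normal experiment, all frames labeled 0
#     tunnel_run_02_smoke.npy  -> "smoke" in name, all frames labeled 1
# each file: float array of shape (n_frames, 32, 2048) holding lidar range projections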
Deep-SAD-PyTorch/src/main.py

@@ -19,6 +19,8 @@ from datasets.main import load_dataset
     type=click.Choice(
         [
             "mnist",
+            "elpv",
+            "subter",
             "fmnist",
             "cifar10",
             "arrhythmia",
@@ -35,6 +37,8 @@ from datasets.main import load_dataset
     type=click.Choice(
         [
             "mnist_LeNet",
+            "elpv_LeNet",
+            "subter_LeNet",
             "fmnist_LeNet",
             "cifar10_LeNet",
             "arrhythmia_mlp",
@@ -109,7 +113,7 @@ from datasets.main import load_dataset
 @click.option(
     "--lr_milestone",
     type=int,
-    default=0,
+    default=[0],
     multiple=True,
     help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
 )
@@ -149,7 +153,7 @@ from datasets.main import load_dataset
 @click.option(
     "--ae_lr_milestone",
     type=int,
-    default=0,
+    default=[0],
     multiple=True,
     help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
 )
@@ -393,9 +397,9 @@ def main(
         np.argsort(scores[labels == 0])
     ]  # from lowest to highest score

-    if dataset_name in ("mnist", "fmnist", "cifar10"):
+    if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):

-        if dataset_name in ("mnist", "fmnist"):
+        if dataset_name in ("mnist", "fmnist", "elpv"):
             X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
             X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(1)
             X_normal_low = dataset.test_set.data[idx_normal_sorted[:32], ...].unsqueeze(
Deep-SAD-PyTorch/src/networks/elpv_LeNet.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from base.base_net import BaseNet


class ELPV_LeNet(BaseNet):

    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(4 * 75 * 75, self.rep_dim, bias=False)  # 300 -> 150 -> 75 after two poolings

    def forward(self, x):
        x = x.view(-1, 1, 300, 300)
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn1(x)))  # (8, 150, 150)
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2(x)))  # (4, 75, 75)
        x = x.view(int(x.size(0)), -1)
        x = self.fc1(x)
        return x


class ELPV_LeNet_Decoder(BaseNet):

    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim

        # Decoder network
        self.fc3 = nn.Linear(self.rep_dim, 2888, bias=False)  # 2888 = 2 * 38 * 38
        self.bn1d2 = nn.BatchNorm1d(2888, eps=1e-04, affine=False)
        self.deconv1 = nn.ConvTranspose2d(2, 4, 5, bias=False, padding=2)
        self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=3)
        self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.deconv3 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)

    def forward(self, x):
        x = self.bn1d2(self.fc3(x))
        x = x.view(int(x.size(0)), 2, 38, 38)
        x = F.interpolate(F.leaky_relu(x), scale_factor=2)            # 38 -> 76
        x = self.deconv1(x)                                           # 76 (padding=2 preserves size)
        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)  # 76 -> 152
        x = self.deconv2(x)                                           # 152 -> 150 (padding=3 shrinks by 2)
        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)  # 150 -> 300
        x = self.deconv3(x)                                           # 300 (padding=2 preserves size)
        x = torch.sigmoid(x)
        return x


class ELPV_LeNet_Autoencoder(BaseNet):

    def __init__(self, rep_dim=256):
        super().__init__()

        self.rep_dim = rep_dim
        self.encoder = ELPV_LeNet(rep_dim=rep_dim)
        self.decoder = ELPV_LeNet_Decoder(rep_dim=rep_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
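A quick shape sanity check for the networks above (a sketch, assuming 300x300 single-channel ELPV images):

net = ELPV_LeNet()
ae = ELPV_LeNet_Autoencoder()
x = torch.randn(2, 1, 300, 300)
assert net(x).shape == (2, 256)         # rep_dim embedding
assert ae(x).shape == (2, 1, 300, 300)  # reconstruction matches input size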
Deep-SAD-PyTorch/src/networks/main.py

@@ -1,4 +1,6 @@
 from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
+from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
+from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
 from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
 from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Autoencoder
 from .mlp import MLP, MLP_Autoencoder
@@ -11,6 +13,8 @@ def build_network(net_name, ae_net=None):
     implemented_networks = (
         "mnist_LeNet",
+        "elpv_LeNet",
+        "subter_LeNet",
         "mnist_DGM_M2",
         "mnist_DGM_M1M2",
         "fmnist_LeNet",
@@ -39,6 +43,12 @@ def build_network(net_name, ae_net=None):
     if net_name == "mnist_LeNet":
         net = MNIST_LeNet()

+    if net_name == "subter_LeNet":
+        net = SubTer_LeNet()
+
+    if net_name == "elpv_LeNet":
+        net = ELPV_LeNet()
+
     if net_name == "mnist_DGM_M2":
         net = DeepGenerativeModel(
             [1 * 28 * 28, 2, 32, [128, 64]], classifier_net=MNIST_LeNet
@@ -118,6 +128,8 @@ def build_autoencoder(net_name):
     """Builds the corresponding autoencoder network."""

     implemented_networks = (
+        "elpv_LeNet",
+        "subter_LeNet",
         "mnist_LeNet",
         "mnist_DGM_M1M2",
         "fmnist_LeNet",
@@ -139,6 +151,12 @@ def build_autoencoder(net_name):
     if net_name == "mnist_LeNet":
         ae_net = MNIST_LeNet_Autoencoder()

+    if net_name == "subter_LeNet":
+        ae_net = SubTer_LeNet_Autoencoder()
+
+    if net_name == "elpv_LeNet":
+        ae_net = ELPV_LeNet_Autoencoder()
+
     if net_name == "mnist_DGM_M1M2":
         ae_net = VariationalAutoencoder([1 * 28 * 28, 32, [128, 64]])
Deep-SAD-PyTorch/src/networks/subter_LeNet.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from base.base_net import BaseNet


class SubTer_LeNet(BaseNet):

    def __init__(self, rep_dim=1024):
        super().__init__()

        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(4 * 512 * 8, self.rep_dim, bias=False)  # (32, 2048) -> (8, 512) after two poolings

    def forward(self, x):
        x = x.view(-1, 1, 32, 2048)
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn1(x)))  # (8, 16, 1024)
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2(x)))  # (4, 8, 512)
        x = x.view(int(x.size(0)), -1)
        x = self.fc1(x)
        return x


class SubTer_LeNet_Decoder(BaseNet):

    def __init__(self, rep_dim=1024):
        super().__init__()

        self.rep_dim = rep_dim

        # Decoder network
        self.fc3 = nn.Linear(self.rep_dim, 4 * 512 * 8, bias=False)
        self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.deconv1 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=2)
        self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.deconv2 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)

    def forward(self, x):
        x = self.fc3(x)
        x = x.view(int(x.size(0)), 4, 8, 512)
        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)  # (4, 16, 1024)
        x = self.deconv1(x)                                           # padding=2 preserves size
        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)  # (8, 32, 2048)
        x = self.deconv2(x)                                           # (1, 32, 2048)
        x = torch.sigmoid(x)
        return x


class SubTer_LeNet_Autoencoder(BaseNet):

    def __init__(self, rep_dim=1024):
        super().__init__()

        self.rep_dim = rep_dim
        self.encoder = SubTer_LeNet(rep_dim=rep_dim)
        self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
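And the matching shape check for the lidar networks (a sketch, assuming 32x2048 single-channel range projections):

net = SubTer_LeNet()
ae = SubTer_LeNet_Autoencoder()
x = torch.randn(2, 1, 32, 2048)
assert net(x).shape == (2, 1024)        # rep_dim embedding
assert ae(x).shape == (2, 1, 32, 2048)  # reconstruction matches input size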
Deep-SAD-PyTorch/src/onnx_export.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import torch
import torch.onnx
from networks.mnist_LeNet import MNIST_LeNet_Autoencoder


def export_model_to_onnx(model, filepath, input_shape=(1, 1, 28, 28)):
    model.eval()  # Set the model to evaluation mode
    dummy_input = torch.randn(input_shape)  # Create a dummy input tensor
    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        filepath,  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=11,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {0: "batch_size"},  # variable length axes
            "output": {0: "batch_size"},
        },
    )


if __name__ == "__main__":
    # Initialize the autoencoder model
    autoencoder = MNIST_LeNet_Autoencoder(rep_dim=32)

    # Define the file path where the ONNX model will be saved
    onnx_file_path = "mnist_lenet_autoencoder.onnx"

    # Export the model
    export_model_to_onnx(autoencoder, onnx_file_path)

    print(f"Model has been exported to {onnx_file_path}")
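A hedged round-trip check for the exported model, assuming the onnxruntime package is installed:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("mnist_lenet_autoencoder.onnx")
x = np.random.randn(4, 1, 28, 28).astype(np.float32)  # batch of 4, enabled by the dynamic batch axis
(reconstruction,) = session.run(None, {"input": x})
print(reconstruction.shape)  # expected: (4, 1, 28, 28)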