implemented inference
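Adds an end-to-end `infer` mode: a `DeepSAD.inference()` entry point backed by a new `DeepSADTrainer.infer()` scoring loop, an `inference_loader` returned as a third value from the dataset `loaders()` method (all existing call sites updated to unpack it), an `inference` flag threaded through `load_dataset()` into the SubTer dataset, a new `SubTerInference` dataset that wraps a single experiment file, and a new positional `action` CLI argument (`train` or `infer`) in `main.py`. Inference scores are saved as a `.npy` file under the experiment path.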
@@ -86,6 +86,18 @@ class DeepSAD(object):
         self.results["train_time"] = self.trainer.train_time
         self.c = self.trainer.c.cpu().data.numpy().tolist()  # get as list
 
+    def inference(
+        self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
+    ):
+        """Runs inference with the Deep SAD model on the given dataset."""
+
+        if self.trainer is None:
+            self.trainer = DeepSADTrainer(
+                self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
+            )
+
+        return self.trainer.infer(dataset, self.net)
+
     def test(
         self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
     ):
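Note: `inference()` lazily constructs a trainer from the stored center `c` and `eta`, so it also works on a freshly loaded model that was never trained in this process. A minimal sketch of how it is meant to be driven (model path and device are illustrative assumptions, not part of this diff):

    # Hypothetical driver: load a trained model, then score a dataset.
    deepSAD = DeepSAD(eta)                 # same eta as used at training time
    deepSAD.set_network(net_name)          # same net_name as used at training time
    deepSAD.load_model(model_path="model.tar", load_ae=True, map_location="cpu")
    scores = deepSAD.inference(dataset, device="cpu")  # numpy array of anomaly scores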
@@ -14,19 +14,39 @@ class TorchvisionDataset(BaseADDataset):
         shuffle_train=True,
         shuffle_test=False,
         num_workers: int = 0,
-    ) -> (DataLoader, DataLoader):
-        train_loader = DataLoader(
-            dataset=self.train_set,
-            batch_size=batch_size,
-            shuffle=shuffle_train,
-            num_workers=num_workers,
-            drop_last=True,
-        )
-        test_loader = DataLoader(
-            dataset=self.test_set,
-            batch_size=batch_size,
-            shuffle=shuffle_test,
-            num_workers=num_workers,
-            drop_last=False,
-        )
-        return train_loader, test_loader
+    ) -> (DataLoader, DataLoader, DataLoader):
+        train_loader = (
+            DataLoader(
+                dataset=self.train_set,
+                batch_size=batch_size,
+                shuffle=shuffle_train,
+                num_workers=num_workers,
+                drop_last=True,
+            )
+            if self.train_set is not None
+            else None
+        )
+        test_loader = (
+            DataLoader(
+                dataset=self.test_set,
+                batch_size=batch_size,
+                shuffle=shuffle_test,
+                num_workers=num_workers,
+                drop_last=False,
+            )
+            if self.test_set is not None
+            else None
+        )
+
+        inference_loader = (
+            DataLoader(
+                dataset=self.inference_set,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=num_workers,
+                drop_last=False,
+            )
+            if self.inference_set is not None
+            else None
+        )
+        return train_loader, test_loader, inference_loader
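Note: `loaders()` now returns a 3-tuple, and each loader is `None` when its underlying set was never populated, so callers must unpack three values and only touch the split they need. A sketch of the new contract (batch size is an arbitrary example):

    # All three loaders come back at once; unused slots are conventionally
    # discarded with `_`, and any of them can be None.
    train_loader, test_loader, inference_loader = dataset.loaders(batch_size=128)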
@@ -96,7 +96,9 @@ class IsoForest(object):
         """Tests the Isolation Forest model on the test data."""
         logger = logging.getLogger()
 
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         # Get data from loader
         idx_label_score = []
@@ -108,7 +108,9 @@ class KDE(object):
         """Tests the Kernel Density Estimation model on the test data."""
         logger = logging.getLogger()
 
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         # Get data from loader
         idx_label_score = []
@@ -77,7 +77,9 @@ class OCSVM(object):
         best_auc = 0.0
 
         # Sample hold-out set from test set
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         X_test = ()
         labels = []
@@ -163,7 +165,9 @@ class OCSVM(object):
         """Tests the OC-SVM model on the test data."""
         logger = logging.getLogger()
 
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         # Get data from loader
         idx_label_score = []
@@ -91,7 +91,9 @@ class SSAD(object):
         best_auc = 0.0
 
         # Sample hold-out set from test set
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         X_test = ()
         labels = []
@@ -190,7 +192,9 @@ class SSAD(object):
         """Tests the SSAD model on the test data."""
         logger = logging.getLogger()
 
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )
 
         # Get data from loader
         idx_label_score = []
@@ -16,6 +16,7 @@ def load_dataset(
     ratio_known_outlier: float = 0.0,
     ratio_pollution: float = 0.0,
     random_state=None,
+    inference: bool = False,
 ):
     """Loads the dataset."""
 
@@ -42,6 +43,7 @@ def load_dataset(
             ratio_known_normal=ratio_known_normal,
             ratio_known_outlier=ratio_known_outlier,
             ratio_pollution=ratio_pollution,
+            inference=inference,
         )
 
     if dataset_name == "elpv":
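Note: the new `inference` flag defaults to `False`, so all existing training and testing call sites are unaffected; when `True` it is simply forwarded into the dataset constructor (the SubTer dataset below), which then populates only `inference_set` instead of the train/test splits.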
@@ -6,6 +6,7 @@ from base.torchvision_dataset import TorchvisionDataset
 from .preprocessing import create_semisupervised_setting
 from typing import Callable, Optional
 
+import logging
 import torch
 import torchvision.transforms as transforms
 import random
@@ -22,6 +23,7 @@ class SubTer_Dataset(TorchvisionDataset):
         ratio_known_normal: float = 0.0,
         ratio_known_outlier: float = 0.0,
         ratio_pollution: float = 0.0,
+        inference: bool = False,
     ):
         super().__init__(root)
 
@@ -35,41 +37,47 @@ class SubTer_Dataset(TorchvisionDataset):
         transform = transforms.ToTensor()
         target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
 
-        # Get train set
-        train_set = MySubTer(
-            root=self.root,
-            transform=transform,
-            target_transform=target_transform,
-            train=True,
-        )
-
-        # Create semi-supervised setting
-        idx, _, semi_targets = create_semisupervised_setting(
-            train_set.targets.cpu().data.numpy(),
-            self.normal_classes,
-            self.outlier_classes,
-            self.outlier_classes,
-            ratio_known_normal,
-            ratio_known_outlier,
-            ratio_pollution,
-        )
-        train_set.semi_targets[idx] = torch.tensor(
-            np.array(semi_targets, dtype=np.int8)
-        )  # set respective semi-supervised labels
-
-        # Subset train_set to semi-supervised setup
-        self.train_set = Subset(train_set, idx)
-
-        # Get test set
-        self.test_set = MySubTer(
-            root=self.root,
-            train=False,
-            transform=transform,
-            target_transform=target_transform,
-        )
+        if inference:
+            self.inference_set = SubTerInference(
+                root=self.root,
+                transform=transform,
+            )
+        else:
+            # Get train set
+            train_set = SubTerTraining(
+                root=self.root,
+                transform=transform,
+                target_transform=target_transform,
+                train=True,
+            )
+
+            # Create semi-supervised setting
+            idx, _, semi_targets = create_semisupervised_setting(
+                train_set.targets.cpu().data.numpy(),
+                self.normal_classes,
+                self.outlier_classes,
+                self.outlier_classes,
+                ratio_known_normal,
+                ratio_known_outlier,
+                ratio_pollution,
+            )
+            train_set.semi_targets[idx] = torch.tensor(
+                np.array(semi_targets, dtype=np.int8)
+            )  # set respective semi-supervised labels
+
+            # Subset train_set to semi-supervised setup
+            self.train_set = Subset(train_set, idx)
+
+            # Get test set
+            self.test_set = SubTerTraining(
+                root=self.root,
+                train=False,
+                transform=transform,
+                target_transform=target_transform,
+            )
 
 
-class MySubTer(VisionDataset):
+class SubTerTraining(VisionDataset):
 
     def __init__(
         self,
@@ -81,7 +89,9 @@ class MySubTer(VisionDataset):
         split=0.7,
         seed=0,
     ):
-        super(MySubTer, self).__init__(root, transforms, transform, target_transform)
+        super(SubTerTraining, self).__init__(
+            root, transforms, transform, target_transform
+        )
 
         experiments_data = []
         experiments_targets = []
@@ -153,3 +163,49 @@ class MySubTer(VisionDataset):
             target = self.target_transform(target)
 
         return img, target, semi_target, index
+
+
+class SubTerInference(VisionDataset):
+
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+    ):
+        super(SubTerInference, self).__init__(root, transforms, transform)
+        logger = logging.getLogger()
+
+        self.experiment_file_path = Path(root)
+
+        if not self.experiment_file_path.is_file():
+            logger.error(
+                "For inference the data path has to be a single experiment file!"
+            )
+            raise Exception("Inference data is not a loadable file!")
+
+        self.data = np.load(self.experiment_file_path)
+        self.data = np.nan_to_num(self.data)
+        self.data = torch.tensor(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        """Return a single unlabeled sample for inference.
+
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, index)
+        """
+        img = self.data[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode="F")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, index
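Note: `SubTerInference` expects `root` to point at a single `.npy` experiment file whose samples are 2-D float arrays (mode "F" in `Image.fromarray` implies one 32-bit float channel per sample), and it returns no targets, only the sample and its index. A hedged smoke test (the file name is illustrative, not from this diff):

    # Hypothetical: wrap one recorded experiment for scoring.
    ds = SubTerInference(root="experiment_001.npy", transform=transforms.ToTensor())
    img, index = ds[0]  # image tensor plus its index; no labels in inference mode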
@@ -3,6 +3,7 @@ import torch
 import logging
 import random
 import numpy as np
+from pathlib import Path
 
 from utils.config import Config
 from utils.visualization.plot_images_grid import plot_images_grid
@@ -14,6 +15,15 @@ from datasets.main import load_dataset
 # Settings
 ################################################################################
 @click.command()
+@click.argument(
+    "action",
+    type=click.Choice(
+        [
+            "train",
+            "infer",
+        ]
+    ),
+)
 @click.argument(
     "dataset_name",
     type=click.Choice(
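Note: `action` is declared before `dataset_name`, so it becomes the first positional argument; the script is now invoked as `main.py train <dataset_name> ...` or `main.py infer <dataset_name> ...`, and existing scripts that omit the action will fail argument parsing.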
@@ -203,6 +213,7 @@ from datasets.main import load_dataset
     "If > 1, the specified number of outlier classes will be sampled at random.",
 )
 def main(
+    action,
     dataset_name,
     net_name,
     xp_path,
@@ -303,138 +314,194 @@ def main(
     logger.info("Number of threads: %d" % num_threads)
     logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
 
-    # Load data
-    dataset = load_dataset(
-        dataset_name,
-        data_path,
-        normal_class,
-        known_outlier_class,
-        n_known_outlier_classes,
-        ratio_known_normal,
-        ratio_known_outlier,
-        ratio_pollution,
-        random_state=np.random.RandomState(cfg.settings["seed"]),
-    )
-    # Log random sample of known anomaly classes if more than 1 class
-    if n_known_outlier_classes > 1:
-        logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
-
-    # Initialize DeepSAD model and set neural network phi
-    deepSAD = DeepSAD(cfg.settings["eta"])
-    deepSAD.set_network(net_name)
-
-    # If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
-    if load_model:
-        deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
-        logger.info("Loading model from %s." % load_model)
-
-    logger.info("Pretraining: %s" % pretrain)
-    if pretrain:
-        # Log pretraining details
-        logger.info("Pretraining optimizer: %s" % cfg.settings["ae_optimizer_name"])
-        logger.info("Pretraining learning rate: %g" % cfg.settings["ae_lr"])
-        logger.info("Pretraining epochs: %d" % cfg.settings["ae_n_epochs"])
-        logger.info(
-            "Pretraining learning rate scheduler milestones: %s"
-            % (cfg.settings["ae_lr_milestone"],)
-        )
-        logger.info("Pretraining batch size: %d" % cfg.settings["ae_batch_size"])
-        logger.info("Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"])
-
-        # Pretrain model on dataset (via autoencoder)
-        deepSAD.pretrain(
-            dataset,
-            optimizer_name=cfg.settings["ae_optimizer_name"],
-            lr=cfg.settings["ae_lr"],
-            n_epochs=cfg.settings["ae_n_epochs"],
-            lr_milestones=cfg.settings["ae_lr_milestone"],
-            batch_size=cfg.settings["ae_batch_size"],
-            weight_decay=cfg.settings["ae_weight_decay"],
-            device=device,
-            n_jobs_dataloader=n_jobs_dataloader,
-        )
-
-        # Save pretraining results
-        deepSAD.save_ae_results(export_json=xp_path + "/ae_results.json")
-
-    # Log training details
-    logger.info("Training optimizer: %s" % cfg.settings["optimizer_name"])
-    logger.info("Training learning rate: %g" % cfg.settings["lr"])
-    logger.info("Training epochs: %d" % cfg.settings["n_epochs"])
-    logger.info(
-        "Training learning rate scheduler milestones: %s"
-        % (cfg.settings["lr_milestone"],)
-    )
-    logger.info("Training batch size: %d" % cfg.settings["batch_size"])
-    logger.info("Training weight decay: %g" % cfg.settings["weight_decay"])
-
-    # Train model on dataset
-    deepSAD.train(
-        dataset,
-        optimizer_name=cfg.settings["optimizer_name"],
-        lr=cfg.settings["lr"],
-        n_epochs=cfg.settings["n_epochs"],
-        lr_milestones=cfg.settings["lr_milestone"],
-        batch_size=cfg.settings["batch_size"],
-        weight_decay=cfg.settings["weight_decay"],
-        device=device,
-        n_jobs_dataloader=n_jobs_dataloader,
-    )
-
-    # Test model
-    deepSAD.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)
-
-    # Save results, model, and configuration
-    deepSAD.save_results(export_json=xp_path + "/results.json")
-    deepSAD.save_model(export_model=xp_path + "/model.tar")
-    cfg.save_config(export_json=xp_path + "/config.json")
-
-    # Plot most anomalous and most normal test samples
-    indices, labels, scores = zip(*deepSAD.results["test_scores"])
-    indices, labels, scores = np.array(indices), np.array(labels), np.array(scores)
-    idx_all_sorted = indices[np.argsort(scores)]  # from lowest to highest score
-    idx_normal_sorted = indices[labels == 0][
-        np.argsort(scores[labels == 0])
-    ]  # from lowest to highest score
-
-    if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
-
-        if dataset_name in ("mnist", "fmnist", "elpv"):
-            X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
-            X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(1)
-            X_normal_low = dataset.test_set.data[idx_normal_sorted[:32], ...].unsqueeze(
-                1
-            )
-            X_normal_high = dataset.test_set.data[
-                idx_normal_sorted[-32:], ...
-            ].unsqueeze(1)
-
-        if dataset_name == "cifar10":
-            X_all_low = torch.tensor(
-                np.transpose(
-                    dataset.test_set.data[idx_all_sorted[:32], ...], (0, 3, 1, 2)
-                )
-            )
-            X_all_high = torch.tensor(
-                np.transpose(
-                    dataset.test_set.data[idx_all_sorted[-32:], ...], (0, 3, 1, 2)
-                )
-            )
-            X_normal_low = torch.tensor(
-                np.transpose(
-                    dataset.test_set.data[idx_normal_sorted[:32], ...], (0, 3, 1, 2)
-                )
-            )
-            X_normal_high = torch.tensor(
-                np.transpose(
-                    dataset.test_set.data[idx_normal_sorted[-32:], ...], (0, 3, 1, 2)
-                )
-            )
-
-        plot_images_grid(X_all_low, export_img=xp_path + "/all_low", padding=2)
-        plot_images_grid(X_all_high, export_img=xp_path + "/all_high", padding=2)
-        plot_images_grid(X_normal_low, export_img=xp_path + "/normals_low", padding=2)
-        plot_images_grid(X_normal_high, export_img=xp_path + "/normals_high", padding=2)
+    if action == "train":
+        # Load data
+        dataset = load_dataset(
+            dataset_name,
+            data_path,
+            normal_class,
+            known_outlier_class,
+            n_known_outlier_classes,
+            ratio_known_normal,
+            ratio_known_outlier,
+            ratio_pollution,
+            random_state=np.random.RandomState(cfg.settings["seed"]),
+        )
+        # Log random sample of known anomaly classes if more than 1 class
+        if n_known_outlier_classes > 1:
+            logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
+
+        # Initialize DeepSAD model and set neural network phi
+        deepSAD = DeepSAD(cfg.settings["eta"])
+        deepSAD.set_network(net_name)
+
+        # If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
+        if load_model:
+            deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
+            logger.info("Loading model from %s." % load_model)
+
+        logger.info("Pretraining: %s" % pretrain)
+        if pretrain:
+            # Log pretraining details
+            logger.info("Pretraining optimizer: %s" % cfg.settings["ae_optimizer_name"])
+            logger.info("Pretraining learning rate: %g" % cfg.settings["ae_lr"])
+            logger.info("Pretraining epochs: %d" % cfg.settings["ae_n_epochs"])
+            logger.info(
+                "Pretraining learning rate scheduler milestones: %s"
+                % (cfg.settings["ae_lr_milestone"],)
+            )
+            logger.info("Pretraining batch size: %d" % cfg.settings["ae_batch_size"])
+            logger.info(
+                "Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"]
+            )
+
+            # Pretrain model on dataset (via autoencoder)
+            deepSAD.pretrain(
+                dataset,
+                optimizer_name=cfg.settings["ae_optimizer_name"],
+                lr=cfg.settings["ae_lr"],
+                n_epochs=cfg.settings["ae_n_epochs"],
+                lr_milestones=cfg.settings["ae_lr_milestone"],
+                batch_size=cfg.settings["ae_batch_size"],
+                weight_decay=cfg.settings["ae_weight_decay"],
+                device=device,
+                n_jobs_dataloader=n_jobs_dataloader,
+            )
+
+            # Save pretraining results
+            deepSAD.save_ae_results(export_json=xp_path + "/ae_results.json")
+
+        # Log training details
+        logger.info("Training optimizer: %s" % cfg.settings["optimizer_name"])
+        logger.info("Training learning rate: %g" % cfg.settings["lr"])
+        logger.info("Training epochs: %d" % cfg.settings["n_epochs"])
+        logger.info(
+            "Training learning rate scheduler milestones: %s"
+            % (cfg.settings["lr_milestone"],)
+        )
+        logger.info("Training batch size: %d" % cfg.settings["batch_size"])
+        logger.info("Training weight decay: %g" % cfg.settings["weight_decay"])
+
+        # Train model on dataset
+        deepSAD.train(
+            dataset,
+            optimizer_name=cfg.settings["optimizer_name"],
+            lr=cfg.settings["lr"],
+            n_epochs=cfg.settings["n_epochs"],
+            lr_milestones=cfg.settings["lr_milestone"],
+            batch_size=cfg.settings["batch_size"],
+            weight_decay=cfg.settings["weight_decay"],
+            device=device,
+            n_jobs_dataloader=n_jobs_dataloader,
+        )
+
+        # Test model
+        deepSAD.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)
+
+        # Save results, model, and configuration
+        deepSAD.save_results(export_json=xp_path + "/results.json")
+        deepSAD.save_model(export_model=xp_path + "/model.tar")
+        cfg.save_config(export_json=xp_path + "/config.json")
+
+        # Plot most anomalous and most normal test samples
+        indices, labels, scores = zip(*deepSAD.results["test_scores"])
+        indices, labels, scores = np.array(indices), np.array(labels), np.array(scores)
+        idx_all_sorted = indices[np.argsort(scores)]  # from lowest to highest score
+        idx_normal_sorted = indices[labels == 0][
+            np.argsort(scores[labels == 0])
+        ]  # from lowest to highest score
+
+        if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
+
+            if dataset_name in ("mnist", "fmnist", "elpv"):
+                X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
+                X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
+                    1
+                )
+                X_normal_low = dataset.test_set.data[
+                    idx_normal_sorted[:32], ...
+                ].unsqueeze(1)
+                X_normal_high = dataset.test_set.data[
+                    idx_normal_sorted[-32:], ...
+                ].unsqueeze(1)
+
+            if dataset_name == "cifar10":
+                X_all_low = torch.tensor(
+                    np.transpose(
+                        dataset.test_set.data[idx_all_sorted[:32], ...], (0, 3, 1, 2)
+                    )
+                )
+                X_all_high = torch.tensor(
+                    np.transpose(
+                        dataset.test_set.data[idx_all_sorted[-32:], ...], (0, 3, 1, 2)
+                    )
+                )
+                X_normal_low = torch.tensor(
+                    np.transpose(
+                        dataset.test_set.data[idx_normal_sorted[:32], ...], (0, 3, 1, 2)
+                    )
+                )
+                X_normal_high = torch.tensor(
+                    np.transpose(
+                        dataset.test_set.data[idx_normal_sorted[-32:], ...],
+                        (0, 3, 1, 2),
+                    )
+                )
+
+            plot_images_grid(X_all_low, export_img=xp_path + "/all_low", padding=2)
+            plot_images_grid(X_all_high, export_img=xp_path + "/all_high", padding=2)
+            plot_images_grid(
+                X_normal_low, export_img=xp_path + "/normals_low", padding=2
+            )
+            plot_images_grid(
+                X_normal_high, export_img=xp_path + "/normals_high", padding=2
+            )
+    elif action == "infer":
+        dataset = load_dataset(
+            dataset_name,
+            data_path,
+            normal_class,
+            known_outlier_class,
+            n_known_outlier_classes,
+            ratio_known_normal,
+            ratio_known_outlier,
+            ratio_pollution,
+            random_state=np.random.RandomState(cfg.settings["seed"]),
+            inference=True,
+        )
+        # Log random sample of known anomaly classes if more than 1 class
+        if n_known_outlier_classes > 1:
+            logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
+
+        # Initialize DeepSAD model and set neural network phi
+        deepSAD = DeepSAD(cfg.settings["eta"])
+        deepSAD.set_network(net_name)
+
+        # For inference a trained Deep SAD model (center c and network weights) must be loaded
+        if not load_model:
+            logger.error(
+                "For inference mode a model has to be loaded! Pass the --load_model option with the model path!"
+            )
+            return
+
+        deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
+        logger.info("Loading model from %s." % load_model)
+
+        inference_results = deepSAD.inference(
+            dataset, device=device, n_jobs_dataloader=n_jobs_dataloader
+        )
+        inference_results_path = (
+            Path(xp_path) / "inference" / Path(dataset.root).with_suffix(".npy").stem
+        )
+        inference_results_path.parent.mkdir(parents=True, exist_ok=True)
+        np.save(inference_results_path, inference_results, fix_imports=False)
+
+        logger.info(
+            f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
+        )
+    else:
+        logger.error(f"Unknown action: {action}")
 
 
 if __name__ == "__main__":
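Note: in `infer` mode a checkpoint is mandatory (`--load_model`), the scores returned by `deepSAD.inference()` are written under `<xp_path>/inference/` using the stem of the experiment file name, and only summary statistics (median/mean/min/max) are logged. Since `np.save` appends the `.npy` suffix itself, building the path from the stem avoids a doubled extension.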
@@ -54,7 +54,7 @@ class DeepSADTrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -130,11 +130,49 @@ class DeepSADTrainer(BaseTrainer):
 
         return net
 
+    def infer(self, dataset: BaseADDataset, net: BaseNet):
+        logger = logging.getLogger()
+
+        # Get inference data loader
+        _, _, inference_loader = dataset.loaders(
+            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
+        )
+
+        # Set device for network
+        net = net.to(self.device)
+
+        # Inference
+        logger.info("Starting inference...")
+        n_batches = 0
+        start_time = time.time()
+        scores = []
+        net.eval()
+        with torch.no_grad():
+            for data in inference_loader:
+                inputs, idx = data
+
+                inputs = inputs.to(self.device)
+                idx = idx.to(self.device)
+
+                outputs = net(inputs)
+                dist = torch.sum((outputs - self.c) ** 2, dim=1)
+                scores += dist.cpu().data.numpy().tolist()
+
+                n_batches += 1
+
+        self.inference_time = time.time() - start_time
+
+        # Log results
+        logger.info("Inference Time: {:.3f}s".format(self.inference_time))
+        logger.info("Finished inference.")
+
+        return np.array(scores)
+
     def test(self, dataset: BaseADDataset, net: BaseNet):
         logger = logging.getLogger()
 
         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
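Note: as in `test()`, the anomaly score is the squared Euclidean distance of each embedding to the hypersphere center `c`; `infer()` simply skips labels and AUC computation and returns the raw scores. A sketch of consuming them (the 95th-percentile threshold is an assumption for illustration, not part of this commit):

    scores = trainer.infer(dataset, net)   # np.ndarray, one score per sample
    threshold = np.quantile(scores, 0.95)  # hypothetical operating point
    flagged = scores > threshold           # boolean anomaly mask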
@@ -49,7 +49,7 @@ class SemiDeepGenerativeTrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -152,7 +152,7 @@ class SemiDeepGenerativeTrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -44,7 +44,7 @@ class AETrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -115,7 +115,7 @@ class AETrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -44,7 +44,7 @@ class VAETrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 
@@ -117,7 +117,7 @@ class VAETrainer(BaseTrainer):
         logger = logging.getLogger()
 
         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
 