implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -86,6 +86,18 @@ class DeepSAD(object):
         self.results["train_time"] = self.trainer.train_time
         self.c = self.trainer.c.cpu().data.numpy().tolist()  # get as list

+    def inference(
+        self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
+    ):
+        """Runs inference with the Deep SAD model on the inference data."""
+        if self.trainer is None:
+            self.trainer = DeepSADTrainer(
+                self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
+            )
+        return self.trainer.infer(dataset, self.net)
+
     def test(
         self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
     ):
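
For orientation, a minimal usage sketch of the new entry point; the network name, checkpoint path, and dataset construction below are illustrative assumptions, not part of this commit:

    # Hypothetical end-to-end use of DeepSAD.inference; assumes a trained
    # checkpoint and a dataset built with inference=True (see hunks below).
    deepSAD = DeepSAD(eta=1.0)
    deepSAD.set_network("subter_LeNet")                 # assumed net_name
    deepSAD.load_model(model_path="model.tar", load_ae=True, map_location="cuda")
    scores = deepSAD.inference(dataset, device="cuda")  # numpy array of anomaly scores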

View File

@@ -14,19 +14,39 @@ class TorchvisionDataset(BaseADDataset):
         shuffle_train=True,
         shuffle_test=False,
         num_workers: int = 0,
-    ) -> (DataLoader, DataLoader):
-        train_loader = DataLoader(
-            dataset=self.train_set,
-            batch_size=batch_size,
-            shuffle=shuffle_train,
-            num_workers=num_workers,
-            drop_last=True,
-        )
-        test_loader = DataLoader(
-            dataset=self.test_set,
-            batch_size=batch_size,
-            shuffle=shuffle_test,
-            num_workers=num_workers,
-            drop_last=False,
-        )
-        return train_loader, test_loader
+    ) -> (DataLoader, DataLoader, DataLoader):
+        train_loader = (
+            DataLoader(
+                dataset=self.train_set,
+                batch_size=batch_size,
+                shuffle=shuffle_train,
+                num_workers=num_workers,
+                drop_last=True,
+            )
+            if self.train_set is not None
+            else None
+        )
+        test_loader = (
+            DataLoader(
+                dataset=self.test_set,
+                batch_size=batch_size,
+                shuffle=shuffle_test,
+                num_workers=num_workers,
+                drop_last=False,
+            )
+            if self.test_set is not None
+            else None
+        )
+        inference_loader = (
+            DataLoader(
+                dataset=self.inference_set,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=num_workers,
+                drop_last=False,
+            )
+            if self.inference_set is not None
+            else None
+        )
+        return train_loader, test_loader, inference_loader
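
Every call site must now unpack a 3-tuple and treat a missing split as None. A minimal sketch of the new caller contract (the RuntimeError guard is illustrative, not from the repository):

    # loaders() now always returns three loaders; a split the dataset was
    # not built with comes back as None instead of raising inside loaders().
    train_loader, test_loader, inference_loader = dataset.loaders(batch_size=128)
    if inference_loader is None:
        raise RuntimeError("dataset was not constructed with inference=True")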

View File

@@ -96,7 +96,9 @@ class IsoForest(object):
         """Tests the Isolation Forest model on the test data."""
         logger = logging.getLogger()

-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         # Get data from loader
         idx_label_score = []

View File

@@ -108,7 +108,9 @@ class KDE(object):
         """Tests the Kernel Density Estimation model on the test data."""
         logger = logging.getLogger()

-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         # Get data from loader
         idx_label_score = []

View File

@@ -77,7 +77,9 @@ class OCSVM(object):
         best_auc = 0.0

         # Sample hold-out set from test set
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         X_test = ()
         labels = []
@@ -163,7 +165,9 @@ class OCSVM(object):
         """Tests the OC-SVM model on the test data."""
         logger = logging.getLogger()

-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         # Get data from loader
         idx_label_score = []

View File

@@ -91,7 +91,9 @@ class SSAD(object):
         best_auc = 0.0

         # Sample hold-out set from test set
-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         X_test = ()
         labels = []
@@ -190,7 +192,9 @@ class SSAD(object):
         """Tests the SSAD model on the test data."""
         logger = logging.getLogger()

-        _, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
+        _, test_loader, _ = dataset.loaders(
+            batch_size=128, num_workers=n_jobs_dataloader
+        )

         # Get data from loader
         idx_label_score = []

View File

@@ -16,6 +16,7 @@ def load_dataset(
     ratio_known_outlier: float = 0.0,
     ratio_pollution: float = 0.0,
     random_state=None,
+    inference: bool = False,
 ):
     """Loads the dataset."""

@@ -42,6 +43,7 @@ def load_dataset(
             ratio_known_normal=ratio_known_normal,
             ratio_known_outlier=ratio_known_outlier,
             ratio_pollution=ratio_pollution,
+            inference=inference,
         )

     if dataset_name == "elpv":
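
The flag is threaded through unchanged to the dataset constructor. A hedged sketch of the intended call (argument values are illustrative; the signature follows the hunk above):

    # Build an inference-only dataset; the sampling ratios are unused in this
    # mode, and the data path must point to a single experiment file.
    dataset = load_dataset(
        "subter", "data/run_01.npy",
        normal_class=0, known_outlier_class=1, n_known_outlier_classes=0,
        inference=True,
    )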

View File

@@ -6,6 +6,7 @@ from base.torchvision_dataset import TorchvisionDataset
 from .preprocessing import create_semisupervised_setting
 from typing import Callable, Optional

+import logging
 import torch
 import torchvision.transforms as transforms
 import random
@@ -22,6 +23,7 @@ class SubTer_Dataset(TorchvisionDataset):
         ratio_known_normal: float = 0.0,
         ratio_known_outlier: float = 0.0,
         ratio_pollution: float = 0.0,
+        inference: bool = False,
     ):
         super().__init__(root)
@@ -35,8 +37,14 @@ class SubTer_Dataset(TorchvisionDataset):
         transform = transforms.ToTensor()
         target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))

+        if inference:
+            self.inference_set = SubTerInference(
+                root=self.root,
+                transform=transform,
+            )
+        else:
             # Get train set
-            train_set = MySubTer(
+            train_set = SubTerTraining(
                 root=self.root,
                 transform=transform,
                 target_transform=target_transform,
@@ -61,7 +69,7 @@ class SubTer_Dataset(TorchvisionDataset):
             self.train_set = Subset(train_set, idx)

             # Get test set
-            self.test_set = MySubTer(
+            self.test_set = SubTerTraining(
                 root=self.root,
                 train=False,
                 transform=transform,
@@ -69,7 +77,7 @@ class SubTer_Dataset(TorchvisionDataset):
             )


-class MySubTer(VisionDataset):
+class SubTerTraining(VisionDataset):
     def __init__(
         self,
@@ -81,7 +89,9 @@ class MySubTer(VisionDataset):
         split=0.7,
         seed=0,
     ):
-        super(MySubTer, self).__init__(root, transforms, transform, target_transform)
+        super(SubTerTraining, self).__init__(
+            root, transforms, transform, target_transform
+        )

         experiments_data = []
         experiments_targets = []
@@ -153,3 +163,49 @@ class MySubTer(VisionDataset):
             target = self.target_transform(target)

         return img, target, semi_target, index
+
+
+class SubTerInference(VisionDataset):
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+    ):
+        super(SubTerInference, self).__init__(root, transforms, transform)
+        logger = logging.getLogger()
+
+        self.experiment_file_path = Path(root)
+        if not self.experiment_file_path.is_file():
+            logger.error(
+                "For inference the data path has to be a single experiment file!"
+            )
+            raise Exception("Inference data is not a loadable file!")
+
+        self.data = np.load(self.experiment_file_path)
+        self.data = np.nan_to_num(self.data)
+        self.data = torch.tensor(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        """Return a single inference sample.
+
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, index)
+        """
+        img = self.data[index]
+
+        # Convert to a PIL image so the transform pipeline is consistent
+        # with all the other datasets.
+        img = Image.fromarray(img.numpy(), mode="F")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, index
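
SubTerInference yields (image, index) pairs with no labels or semi-targets, so scores can later be mapped back to frame positions in the experiment file. A small sketch of direct use (the file path is illustrative):

    # Load one frame from a hypothetical experiment file; each sample is the
    # transformed image plus its index into the file.
    inference_set = SubTerInference(
        root="data/run_01.npy", transform=transforms.ToTensor()
    )
    img, idx = inference_set[0]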

View File

@@ -3,6 +3,7 @@ import torch
 import logging
 import random
 import numpy as np
+from pathlib import Path

 from utils.config import Config
 from utils.visualization.plot_images_grid import plot_images_grid
@@ -14,6 +15,15 @@ from datasets.main import load_dataset
 # Settings
 ################################################################################
 @click.command()
+@click.argument(
+    "action",
+    type=click.Choice(
+        [
+            "train",
+            "infer",
+        ]
+    ),
+)
 @click.argument(
     "dataset_name",
     type=click.Choice(
@@ -203,6 +213,7 @@ from datasets.main import load_dataset
     "If > 1, the specified number of outlier classes will be sampled at random.",
 )
 def main(
+    action,
     dataset_name,
     net_name,
     xp_path,
@@ -303,6 +314,8 @@ def main(
     logger.info("Number of threads: %d" % num_threads)
     logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)

+    if action == "train":
         # Load data
         dataset = load_dataset(
             dataset_name,
@@ -339,7 +352,9 @@ def main(
             % (cfg.settings["ae_lr_milestone"],)
         )
         logger.info("Pretraining batch size: %d" % cfg.settings["ae_batch_size"])
-        logger.info("Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"])
+        logger.info(
+            "Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"]
+        )

         # Pretrain model on dataset (via autoencoder)
         deepSAD.pretrain(
@@ -401,10 +416,12 @@ def main(
         if dataset_name in ("mnist", "fmnist", "elpv"):
             X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
-            X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(1)
-            X_normal_low = dataset.test_set.data[idx_normal_sorted[:32], ...].unsqueeze(
-                1
-            )
+            X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
+                1
+            )
+            X_normal_low = dataset.test_set.data[
+                idx_normal_sorted[:32], ...
+            ].unsqueeze(1)
             X_normal_high = dataset.test_set.data[
                 idx_normal_sorted[-32:], ...
             ].unsqueeze(1)
@@ -427,14 +444,64 @@ def main(
             )
             X_normal_high = torch.tensor(
                 np.transpose(
-                    dataset.test_set.data[idx_normal_sorted[-32:], ...], (0, 3, 1, 2)
+                    dataset.test_set.data[idx_normal_sorted[-32:], ...],
+                    (0, 3, 1, 2),
                 )
             )

         plot_images_grid(X_all_low, export_img=xp_path + "/all_low", padding=2)
         plot_images_grid(X_all_high, export_img=xp_path + "/all_high", padding=2)
-        plot_images_grid(X_normal_low, export_img=xp_path + "/normals_low", padding=2)
-        plot_images_grid(X_normal_high, export_img=xp_path + "/normals_high", padding=2)
+        plot_images_grid(
+            X_normal_low, export_img=xp_path + "/normals_low", padding=2
+        )
+        plot_images_grid(
+            X_normal_high, export_img=xp_path + "/normals_high", padding=2
+        )
+
+    elif action == "infer":
+        dataset = load_dataset(
+            dataset_name,
+            data_path,
+            normal_class,
+            known_outlier_class,
+            n_known_outlier_classes,
+            ratio_known_normal,
+            ratio_known_outlier,
+            ratio_pollution,
+            random_state=np.random.RandomState(cfg.settings["seed"]),
+            inference=True,
+        )
+
+        # Log random sample of known anomaly classes if more than 1 class
+        if n_known_outlier_classes > 1:
+            logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
+
+        # Initialize DeepSAD model and set neural network phi
+        deepSAD = DeepSAD(cfg.settings["eta"])
+        deepSAD.set_network(net_name)
+
+        # A trained Deep SAD model (center c, network weights, and autoencoder
+        # weights) must be loaded for inference
+        if not load_model:
+            logger.error(
+                "For inference mode a model has to be loaded! Pass the --load_model option with the model path!"
+            )
+            return
+        deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
+        logger.info("Loading model from %s." % load_model)
+
+        inference_results = deepSAD.inference(
+            dataset, device=device, n_jobs_dataloader=n_jobs_dataloader
+        )
+
+        inference_results_path = (
+            Path(xp_path) / "inference" / Path(dataset.root).with_suffix(".npy").stem
+        )
+        inference_results_path.parent.mkdir(parents=True, exist_ok=True)
+        np.save(inference_results_path, inference_results, fix_imports=False)
+        logger.info(
+            f"Inference: median={np.median(inference_results)} "
+            f"mean={np.mean(inference_results)} "
+            f"min={inference_results.min()} max={inference_results.max()}"
+        )
+    else:
+        logger.error(f"Unknown action: {action}")


 if __name__ == "__main__":
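
Assuming the positional arguments keep the upstream Deep SAD order (dataset_name, net_name, xp_path, data_path), a hypothetical invocation of the new mode looks like:

    python main.py infer subter subter_LeNet ./log ./data/run_01.npy \
        --load_model ./log/model.tar

The dataset and network names above are placeholders; the score array is written to <xp_path>/inference/<experiment file stem>.npy.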

View File

@@ -54,7 +54,7 @@ class DeepSADTrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
@@ -130,11 +130,49 @@ class DeepSADTrainer(BaseTrainer):
         return net

+    def infer(self, dataset: BaseADDataset, net: BaseNet):
+        logger = logging.getLogger()
+
+        # Get inference data loader
+        _, _, inference_loader = dataset.loaders(
+            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
+        )
+
+        # Set device for network
+        net = net.to(self.device)
+
+        # Inference
+        logger.info("Starting inference...")
+        n_batches = 0
+        start_time = time.time()
+        scores = []
+        net.eval()
+        with torch.no_grad():
+            for data in inference_loader:
+                inputs, idx = data
+                inputs = inputs.to(self.device)
+                idx = idx.to(self.device)
+
+                outputs = net(inputs)
+                dist = torch.sum((outputs - self.c) ** 2, dim=1)
+                scores += dist.cpu().data.numpy().tolist()
+                n_batches += 1
+
+        self.inference_time = time.time() - start_time
+
+        # Log results
+        logger.info("Inference Time: {:.3f}s".format(self.inference_time))
+        logger.info("Finished inference.")
+
+        return np.array(scores)
+
     def test(self, dataset: BaseADDataset, net: BaseNet):
         logger = logging.getLogger()

         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
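
The values returned by infer are the standard Deep SAD anomaly scores, the squared Euclidean distance of each embedding to the hypersphere center: s(x) = ||phi(x; W) - c||^2, so larger scores mean more anomalous. Because inference data carries no labels, infer returns the raw score array rather than the AUC computed in test. An equivalent one-liner for a single embedding z and center c (illustrative variable names):

    score = float(((z - c) ** 2).sum())  # squared distance to the center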

View File

@@ -49,7 +49,7 @@ class SemiDeepGenerativeTrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
@@ -152,7 +152,7 @@ class SemiDeepGenerativeTrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )

View File

@@ -44,7 +44,7 @@ class AETrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
@@ -115,7 +115,7 @@ class AETrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )

View File

@@ -44,7 +44,7 @@ class VAETrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get train data loader
-        train_loader, _ = dataset.loaders(
+        train_loader, _, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )
@@ -117,7 +117,7 @@ class VAETrainer(BaseTrainer):
         logger = logging.getLogger()

         # Get test data loader
-        _, test_loader = dataset.loaders(
+        _, test_loader, _ = dataset.loaders(
             batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
         )