2024-06-28 07:42:12 +02:00
|
|
|
import click
|
|
|
|
|
import torch
|
|
|
|
|
import logging
|
|
|
|
|
import random
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from utils.config import Config
|
|
|
|
|
from utils.visualization.plot_images_grid import plot_images_grid
|
|
|
|
|
from baselines.ocsvm import OCSVM
|
|
|
|
|
from datasets.main import load_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
################################################################################
|
|
|
|
|
# Settings
|
|
|
|
|
################################################################################
|
|
|
|
|
def _setup_logger(xp_path):
    """Configure the root logger to write INFO messages to ``<xp_path>/log.txt``.

    ``logging.basicConfig`` is called first so messages also reach its default
    stream handler (it does nothing if the root logger already has handlers).

    :param xp_path: Experiment export directory.
    :return: Tuple ``(logger, log_file)`` with the root logger and the log-file path.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    log_file = xp_path + "/log.txt"
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger, log_file


def _set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def _to_image_batch(dataset_name, data, idx):
    """Select test images at ``idx`` and return them channels-first for plotting.

    :param dataset_name: One of ``"mnist"``, ``"fmnist"`` or ``"cifar10"``.
    :param data: The test set's ``data`` attribute.
    :param idx: Array of test-set indices to select.
    """
    if dataset_name == "cifar10":
        # Presumably stored as a channels-last numpy array: move channels to axis 1.
        return torch.tensor(np.transpose(data[idx, ...], (0, 3, 1, 2)))
    # (F)MNIST data is a tensor without a channel axis; insert one at axis 1.
    return data[idx, ...].unsqueeze(1)


def _plot_extreme_samples(dataset, dataset_name, test_scores, xp_path, n_samples=32):
    """Save image grids of the lowest- and highest-scoring test samples.

    Four grids are written under ``xp_path``: ``all_low``/``all_high`` over all
    test samples and ``normals_low``/``normals_high`` over normal (label 0) test
    samples only. Non-image datasets and empty score lists are skipped.

    :param dataset: Loaded dataset exposing ``test_set.data``.
    :param dataset_name: Name of the dataset; only image datasets are plotted.
    :param test_scores: Iterable of ``(index, label, score)`` triples.
    :param xp_path: Export directory for the images.
    :param n_samples: Number of samples per grid (default 32).
    """
    if dataset_name not in ("mnist", "fmnist", "cifar10"):
        return
    if len(test_scores) == 0:
        return  # zip(*[]) below would fail to unpack

    indices, labels, scores = (np.array(col) for col in zip(*test_scores))
    idx_all_sorted = indices[np.argsort(scores)]  # from lowest to highest score
    is_normal = labels == 0
    idx_normal_sorted = indices[is_normal][
        np.argsort(scores[is_normal])
    ]  # from lowest to highest score

    data = dataset.test_set.data
    grids = {
        "all_low": idx_all_sorted[:n_samples],
        "all_high": idx_all_sorted[-n_samples:],
        "normals_low": idx_normal_sorted[:n_samples],
        "normals_high": idx_normal_sorted[-n_samples:],
    }
    for name, idx in grids.items():
        plot_images_grid(
            _to_image_batch(dataset_name, data, idx),
            export_img=xp_path + "/" + name,
            padding=2,
        )


@click.command()
@click.argument(
    "dataset_name",
    type=click.Choice(
        [
            "mnist",
            "fmnist",
            "cifar10",
            "arrhythmia",
            "cardio",
            "satellite",
            "satimage-2",
            "shuttle",
            "thyroid",
        ]
    ),
)
@click.argument("xp_path", type=click.Path(exists=True))
@click.argument("data_path", type=click.Path(exists=True))
@click.option(
    "--load_config",
    type=click.Path(exists=True),
    default=None,
    help="Config JSON-file path (default: None).",
)
@click.option(
    "--load_model",
    type=click.Path(exists=True),
    default=None,
    help="Model file path (default: None).",
)
@click.option(
    "--ratio_known_normal",
    type=float,
    default=0.0,
    help="Ratio of known (labeled) normal training examples.",
)
@click.option(
    "--ratio_known_outlier",
    type=float,
    default=0.0,
    help="Ratio of known (labeled) anomalous training examples.",
)
@click.option(
    "--ratio_pollution",
    type=float,
    default=0.0,
    help="Pollution ratio of unlabeled training data with unknown (unlabeled) anomalies.",
)
@click.option(
    "--seed", type=int, default=-1, help="Set seed. If -1, use randomization."
)
@click.option(
    "--kernel",
    type=click.Choice(["rbf", "linear", "poly"]),
    default="rbf",
    help="Kernel for the OC-SVM",
)
@click.option(
    "--nu",
    type=float,
    default=0.1,
    help="OC-SVM hyperparameter nu (must be 0 < nu <= 1).",
)
@click.option(
    "--hybrid",
    type=bool,
    default=False,
    help="Train OC-SVM on features extracted from an autoencoder. If True, load_ae must be specified.",
)
@click.option(
    "--load_ae",
    type=click.Path(exists=True),
    default=None,
    help="Model file path to load autoencoder weights (default: None).",
)
@click.option(
    "--n_jobs_dataloader",
    type=int,
    default=0,
    help="Number of workers for data loading. 0 means that the data will be loaded in the main process.",
)
@click.option(
    "--normal_class",
    type=int,
    default=0,
    help="Specify the normal class of the dataset (all other classes are considered anomalous).",
)
@click.option(
    "--known_outlier_class",
    type=int,
    default=1,
    help="Specify the known outlier class of the dataset for semi-supervised anomaly detection.",
)
@click.option(
    "--n_known_outlier_classes",
    type=int,
    default=0,
    # Implicitly concatenated: each sentence ends with a space so they don't run together.
    help="Number of known outlier classes. "
    "If 0, no anomalies are known. "
    "If 1, outlier class as specified in --known_outlier_class option. "
    "If > 1, the specified number of outlier classes will be sampled at random.",
)
def main(
    dataset_name,
    xp_path,
    data_path,
    load_config,
    load_model,
    ratio_known_normal,
    ratio_known_outlier,
    ratio_pollution,
    seed,
    kernel,
    nu,
    hybrid,
    load_ae,
    n_jobs_dataloader,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes,
):
    """
    (Hybrid) One-Class SVM for anomaly detection.

    :arg DATASET_NAME: Name of the dataset to load.
    :arg XP_PATH: Export path for logging the experiment.
    :arg DATA_PATH: Root path of data.
    """
    # NOTE: the docstring above doubles as click's --help text.

    # Get configuration. Must remain the first statement: at this point locals()
    # holds exactly the CLI parameters, which become the experiment settings.
    cfg = Config(locals().copy())

    # Set up logging
    logger, log_file = _setup_logger(xp_path)

    # Print paths
    logger.info("Log file is %s.", log_file)
    logger.info("Data path is %s.", data_path)
    logger.info("Export path is %s.", xp_path)

    # Print experimental setup
    logger.info("Dataset: %s", dataset_name)
    logger.info("Normal class: %d", normal_class)
    logger.info("Ratio of labeled normal train samples: %.2f", ratio_known_normal)
    logger.info("Ratio of labeled anomalous samples: %.2f", ratio_known_outlier)
    logger.info("Pollution ratio of unlabeled train data: %.2f", ratio_pollution)
    if n_known_outlier_classes == 1:
        logger.info("Known anomaly class: %d", known_outlier_class)
    else:
        logger.info("Number of known anomaly classes: %d", n_known_outlier_classes)

    # If specified, load experiment config from JSON-file
    if load_config:
        cfg.load_config(import_json=load_config)
        logger.info("Loaded configuration from %s.", load_config)

    # Print OC-SVM configuration
    logger.info("OC-SVM kernel: %s", cfg.settings["kernel"])
    logger.info("Nu-parameter: %.2f", cfg.settings["nu"])
    logger.info("Hybrid model: %s", cfg.settings["hybrid"])

    # Set seed (read from cfg, which may differ from --seed if a config was loaded)
    run_seed = cfg.settings["seed"]
    if run_seed != -1:
        _set_seed(run_seed)
        logger.info("Set seed to %d.", run_seed)

    # Use 'cpu' as device for OC-SVM
    device = "cpu"
    torch.multiprocessing.set_sharing_strategy(
        "file_system"
    )  # fix multiprocessing issue for ubuntu
    logger.info("Computation device: %s", device)
    logger.info("Number of dataloader workers: %d", n_jobs_dataloader)

    # Load data. A seed of -1 means "randomize"; np.random.RandomState(-1) raises
    # ValueError, so use an unseeded (OS-entropy) RandomState in that case.
    random_state = np.random.RandomState(None if run_seed == -1 else run_seed)
    dataset = load_dataset(
        dataset_name,
        data_path,
        normal_class,
        known_outlier_class,
        n_known_outlier_classes,
        ratio_known_normal,
        ratio_known_outlier,
        ratio_pollution,
        random_state=random_state,
    )
    # Log random sample of known anomaly classes if more than 1 class
    if n_known_outlier_classes > 1:
        logger.info("Known anomaly classes: %s", (dataset.known_outlier_classes,))

    # Initialize OC-SVM model
    ocsvm = OCSVM(cfg.settings["kernel"], cfg.settings["nu"], cfg.settings["hybrid"])

    # If specified, load model parameters from already trained model
    if load_model:
        ocsvm.load_model(import_path=load_model, device=device)
        logger.info("Loading model from %s.", load_model)

    # For a hybrid model, load autoencoder weights. Use the same cfg value the
    # model was constructed with so a loaded config and the AE stay consistent.
    if cfg.settings["hybrid"]:
        if load_ae is not None:
            ocsvm.load_ae(dataset_name, model_path=load_ae)
            logger.info("Loaded pretrained autoencoder for features from %s.", load_ae)
        else:
            logger.warning(
                "Hybrid model requested but no autoencoder weights given "
                "(--load_ae); none were loaded."
            )

    # Train model on dataset
    ocsvm.train(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)

    # Test model
    ocsvm.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)

    # Save results and configuration
    ocsvm.save_results(export_json=xp_path + "/results.json")
    cfg.save_config(export_json=xp_path + "/config.json")

    # Plot most anomalous and most normal test samples (image datasets only)
    _plot_extreme_samples(dataset, dataset_name, ocsvm.results["test_scores"], xp_path)
|
2024-06-28 07:42:12 +02:00
|
|
|
|
|
|
|
|
|
2024-06-28 11:36:46 +02:00
|
|
|
if __name__ == "__main__":
    # Script entry point (skipped on import): click parses the command-line
    # arguments and invokes main().
    main()
|