full upload so not to lose anything important
This commit is contained in:
@@ -5,6 +5,9 @@ from pathlib import Path
|
||||
import click
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from baselines.isoforest import IsoForest
|
||||
from baselines.ocsvm import OCSVM
|
||||
from datasets.main import load_dataset
|
||||
from DeepSAD import DeepSAD
|
||||
from utils.config import Config
|
||||
@@ -64,6 +67,30 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
||||
)
|
||||
@click.argument("xp_path", type=click.Path(exists=True))
|
||||
@click.argument("data_path", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--k_fold",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use k-fold cross-validation for training (default: False).",
|
||||
)
|
||||
@click.option(
|
||||
"--k_fold_num",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of folds for k-fold cross-validation (default: 5).",
|
||||
)
|
||||
@click.option(
|
||||
"--num_known_normal",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of max known normal samples (semi-supervised-setting) (default: 0).",
|
||||
)
|
||||
@click.option(
|
||||
"--num_known_outlier",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of max known outlier samples (semi-supervised-setting) (default: 0).",
|
||||
)
|
||||
@click.option(
|
||||
"--load_config",
|
||||
type=click.Path(exists=True),
|
||||
@@ -214,12 +241,52 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
||||
"If 1, outlier class as specified in --known_outlier_class option."
|
||||
"If > 1, the specified number of outlier classes will be sampled at random.",
|
||||
)
|
||||
@click.option(
|
||||
"--ocsvm_kernel",
|
||||
type=click.Choice(["rbf", "linear", "poly"]),
|
||||
default="rbf",
|
||||
help="Kernel for the OC-SVM",
|
||||
)
|
||||
@click.option(
|
||||
"--ocsvm_nu",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="OC-SVM hyperparameter nu (must be 0 < nu <= 1).",
|
||||
)
|
||||
@click.option(
|
||||
"--isoforest_n_estimators",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Set the number of base estimators in the ensemble (default: 100).",
|
||||
)
|
||||
@click.option(
|
||||
"--isoforest_max_samples",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Set the number of samples drawn to train each base estimator (default: 256).",
|
||||
)
|
||||
@click.option(
|
||||
"--isoforest_contamination",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Expected fraction of anomalies in the training set. (default: 0.1).",
|
||||
)
|
||||
@click.option(
|
||||
"--isoforest_n_jobs_model",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of jobs for model training.",
|
||||
)
|
||||
def main(
|
||||
action,
|
||||
dataset_name,
|
||||
net_name,
|
||||
xp_path,
|
||||
data_path,
|
||||
k_fold,
|
||||
k_fold_num,
|
||||
num_known_normal,
|
||||
num_known_outlier,
|
||||
load_config,
|
||||
load_model,
|
||||
eta,
|
||||
@@ -246,6 +313,12 @@ def main(
|
||||
normal_class,
|
||||
known_outlier_class,
|
||||
n_known_outlier_classes,
|
||||
ocsvm_kernel,
|
||||
ocsvm_nu,
|
||||
isoforest_n_estimators,
|
||||
isoforest_max_samples,
|
||||
isoforest_contamination,
|
||||
isoforest_n_jobs_model,
|
||||
):
|
||||
"""
|
||||
Deep SAD, a method for deep semi-supervised anomaly detection.
|
||||
@@ -318,6 +391,7 @@ def main(
|
||||
|
||||
if action == "train":
|
||||
# Load data
|
||||
# TODO: pass num of folds
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
data_path,
|
||||
@@ -328,135 +402,297 @@ def main(
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||
k_fold=k_fold,
|
||||
num_known_normal=num_known_normal,
|
||||
num_known_outlier=num_known_outlier,
|
||||
)
|
||||
# Log random sample of known anomaly classes if more than 1 class
|
||||
if n_known_outlier_classes > 1:
|
||||
logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
|
||||
|
||||
# Initialize DeepSAD model and set neural network phi
|
||||
deepSAD = DeepSAD(cfg.settings["eta"])
|
||||
deepSAD.set_network(net_name)
|
||||
train_passes = range(k_fold_num) if k_fold else [None]
|
||||
|
||||
# If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
|
||||
if load_model:
|
||||
deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
train_isoforest = True
|
||||
train_ocsvm = False
|
||||
train_deepsad = True
|
||||
|
||||
logger.info("Pretraining: %s" % pretrain)
|
||||
if pretrain:
|
||||
# Log pretraining details
|
||||
logger.info("Pretraining optimizer: %s" % cfg.settings["ae_optimizer_name"])
|
||||
logger.info("Pretraining learning rate: %g" % cfg.settings["ae_lr"])
|
||||
logger.info("Pretraining epochs: %d" % cfg.settings["ae_n_epochs"])
|
||||
for fold_idx in train_passes:
|
||||
if fold_idx is None:
|
||||
logger.info("Single training without k-fold")
|
||||
else:
|
||||
logger.info(f"Fold {fold_idx + 1}/{k_fold_num}")
|
||||
|
||||
# Initialize OC-SVM model
|
||||
if train_ocsvm:
|
||||
ocsvm = OCSVM(kernel=ocsvm_kernel, nu=ocsvm_nu, hybrid=False)
|
||||
|
||||
# Initialize Isolation Forest model
|
||||
if train_isoforest:
|
||||
Isoforest = IsoForest(
|
||||
hybrid=False,
|
||||
n_estimators=isoforest_n_estimators,
|
||||
max_samples=isoforest_max_samples,
|
||||
contamination=isoforest_contamination,
|
||||
n_jobs=isoforest_n_jobs_model,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Initialize DeepSAD model and set neural network phi
|
||||
if train_deepsad:
|
||||
deepSAD = DeepSAD(cfg.settings["eta"])
|
||||
deepSAD.set_network(net_name)
|
||||
|
||||
# If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
|
||||
if train_deepsad and load_model:
|
||||
deepSAD.load_model(
|
||||
model_path=load_model, load_ae=True, map_location=device
|
||||
)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
|
||||
logger.info("Pretraining: %s" % pretrain)
|
||||
if train_deepsad and pretrain:
|
||||
# Log pretraining details
|
||||
logger.info(
|
||||
"Pretraining optimizer: %s" % cfg.settings["ae_optimizer_name"]
|
||||
)
|
||||
logger.info("Pretraining learning rate: %g" % cfg.settings["ae_lr"])
|
||||
logger.info("Pretraining epochs: %d" % cfg.settings["ae_n_epochs"])
|
||||
logger.info(
|
||||
"Pretraining learning rate scheduler milestones: %s"
|
||||
% (cfg.settings["ae_lr_milestone"],)
|
||||
)
|
||||
logger.info(
|
||||
"Pretraining batch size: %d" % cfg.settings["ae_batch_size"]
|
||||
)
|
||||
logger.info(
|
||||
"Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"]
|
||||
)
|
||||
|
||||
# Pretrain model on dataset (via autoencoder)
|
||||
deepSAD.pretrain(
|
||||
dataset,
|
||||
optimizer_name=cfg.settings["ae_optimizer_name"],
|
||||
lr=cfg.settings["ae_lr"],
|
||||
n_epochs=cfg.settings["ae_n_epochs"],
|
||||
lr_milestones=cfg.settings["ae_lr_milestone"],
|
||||
batch_size=cfg.settings["ae_batch_size"],
|
||||
weight_decay=cfg.settings["ae_weight_decay"],
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Save pretraining results
|
||||
if fold_idx is None:
|
||||
deepSAD.save_ae_results(export_json=xp_path + "/ae_results.json")
|
||||
else:
|
||||
deepSAD.save_ae_results(
|
||||
export_json=xp_path + f"/ae_results_{fold_idx}.json"
|
||||
)
|
||||
|
||||
# Log training details
|
||||
logger.info("Training optimizer: %s" % cfg.settings["optimizer_name"])
|
||||
logger.info("Training learning rate: %g" % cfg.settings["lr"])
|
||||
logger.info("Training epochs: %d" % cfg.settings["n_epochs"])
|
||||
logger.info(
|
||||
"Pretraining learning rate scheduler milestones: %s"
|
||||
% (cfg.settings["ae_lr_milestone"],)
|
||||
)
|
||||
logger.info("Pretraining batch size: %d" % cfg.settings["ae_batch_size"])
|
||||
logger.info(
|
||||
"Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"]
|
||||
"Training learning rate scheduler milestones: %s"
|
||||
% (cfg.settings["lr_milestone"],)
|
||||
)
|
||||
logger.info("Training batch size: %d" % cfg.settings["batch_size"])
|
||||
logger.info("Training weight decay: %g" % cfg.settings["weight_decay"])
|
||||
|
||||
# Pretrain model on dataset (via autoencoder)
|
||||
deepSAD.pretrain(
|
||||
dataset,
|
||||
optimizer_name=cfg.settings["ae_optimizer_name"],
|
||||
lr=cfg.settings["ae_lr"],
|
||||
n_epochs=cfg.settings["ae_n_epochs"],
|
||||
lr_milestones=cfg.settings["ae_lr_milestone"],
|
||||
batch_size=cfg.settings["ae_batch_size"],
|
||||
weight_decay=cfg.settings["ae_weight_decay"],
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
)
|
||||
|
||||
# Save pretraining results
|
||||
deepSAD.save_ae_results(export_json=xp_path + "/ae_results.json")
|
||||
|
||||
# Log training details
|
||||
logger.info("Training optimizer: %s" % cfg.settings["optimizer_name"])
|
||||
logger.info("Training learning rate: %g" % cfg.settings["lr"])
|
||||
logger.info("Training epochs: %d" % cfg.settings["n_epochs"])
|
||||
logger.info(
|
||||
"Training learning rate scheduler milestones: %s"
|
||||
% (cfg.settings["lr_milestone"],)
|
||||
)
|
||||
logger.info("Training batch size: %d" % cfg.settings["batch_size"])
|
||||
logger.info("Training weight decay: %g" % cfg.settings["weight_decay"])
|
||||
|
||||
# Train model on dataset
|
||||
deepSAD.train(
|
||||
dataset,
|
||||
optimizer_name=cfg.settings["optimizer_name"],
|
||||
lr=cfg.settings["lr"],
|
||||
n_epochs=cfg.settings["n_epochs"],
|
||||
lr_milestones=cfg.settings["lr_milestone"],
|
||||
batch_size=cfg.settings["batch_size"],
|
||||
weight_decay=cfg.settings["weight_decay"],
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
)
|
||||
|
||||
# Test model
|
||||
deepSAD.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)
|
||||
|
||||
# Save results, model, and configuration
|
||||
deepSAD.save_results(export_json=xp_path + "/results.json")
|
||||
deepSAD.save_model(export_model=xp_path + "/model.tar")
|
||||
cfg.save_config(export_json=xp_path + "/config.json")
|
||||
|
||||
# Plot most anomalous and most normal test samples
|
||||
indices, labels, scores = zip(*deepSAD.results["test_scores"])
|
||||
indices, labels, scores = np.array(indices), np.array(labels), np.array(scores)
|
||||
idx_all_sorted = indices[np.argsort(scores)] # from lowest to highest score
|
||||
idx_normal_sorted = indices[labels == 0][
|
||||
np.argsort(scores[labels == 0])
|
||||
] # from lowest to highest score
|
||||
|
||||
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
|
||||
if dataset_name in ("mnist", "fmnist", "elpv"):
|
||||
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
|
||||
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
|
||||
1
|
||||
# Train model on dataset
|
||||
if train_deepsad:
|
||||
deepSAD.train(
|
||||
dataset,
|
||||
optimizer_name=cfg.settings["optimizer_name"],
|
||||
lr=cfg.settings["lr"],
|
||||
n_epochs=cfg.settings["n_epochs"],
|
||||
lr_milestones=cfg.settings["lr_milestone"],
|
||||
batch_size=cfg.settings["batch_size"],
|
||||
weight_decay=cfg.settings["weight_decay"],
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
X_normal_low = dataset.test_set.data[
|
||||
idx_normal_sorted[:32], ...
|
||||
].unsqueeze(1)
|
||||
X_normal_high = dataset.test_set.data[
|
||||
idx_normal_sorted[-32:], ...
|
||||
].unsqueeze(1)
|
||||
|
||||
if dataset_name == "cifar10":
|
||||
X_all_low = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_all_sorted[:32], ...], (0, 3, 1, 2)
|
||||
# Train model on dataset
|
||||
if train_ocsvm:
|
||||
ocsvm.train(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
batch_size=8,
|
||||
)
|
||||
|
||||
# Train model on dataset
|
||||
if train_isoforest:
|
||||
Isoforest.train(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Test model
|
||||
if train_deepsad:
|
||||
deepSAD.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Test model
|
||||
if train_ocsvm:
|
||||
ocsvm.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
batch_size=8,
|
||||
)
|
||||
|
||||
# Test model
|
||||
if train_isoforest:
|
||||
Isoforest.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Save results, model, and configuration
|
||||
if fold_idx is None:
|
||||
if train_deepsad:
|
||||
deepSAD.save_results(export_pkl=xp_path + "/results.pkl")
|
||||
deepSAD.save_model(export_model=xp_path + "/model.tar")
|
||||
if train_ocsvm:
|
||||
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
|
||||
if train_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=xp_path + "/results_isoforest.pkl"
|
||||
)
|
||||
)
|
||||
X_all_high = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_all_sorted[-32:], ...], (0, 3, 1, 2)
|
||||
else:
|
||||
if train_deepsad:
|
||||
deepSAD.save_results(
|
||||
export_pkl=xp_path + f"/results_{fold_idx}.pkl"
|
||||
)
|
||||
)
|
||||
X_normal_low = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_normal_sorted[:32], ...], (0, 3, 1, 2)
|
||||
deepSAD.save_model(export_model=xp_path + f"/model_{fold_idx}.tar")
|
||||
if train_ocsvm:
|
||||
ocsvm.save_results(
|
||||
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
|
||||
)
|
||||
)
|
||||
X_normal_high = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_normal_sorted[-32:], ...],
|
||||
(0, 3, 1, 2),
|
||||
if train_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=xp_path + f"/results_isoforest_{fold_idx}.pkl"
|
||||
)
|
||||
)
|
||||
|
||||
plot_images_grid(X_all_low, export_img=xp_path + "/all_low", padding=2)
|
||||
plot_images_grid(X_all_high, export_img=xp_path + "/all_high", padding=2)
|
||||
plot_images_grid(
|
||||
X_normal_low, export_img=xp_path + "/normals_low", padding=2
|
||||
)
|
||||
plot_images_grid(
|
||||
X_normal_high, export_img=xp_path + "/normals_high", padding=2
|
||||
)
|
||||
cfg.save_config(export_json=xp_path + "/config.json")
|
||||
|
||||
# Plot most anomalous and most normal test samples
|
||||
if train_deepsad:
|
||||
indices, labels, scores = zip(*deepSAD.results["test_scores"])
|
||||
indices, labels, scores = (
|
||||
np.array(indices),
|
||||
np.array(labels),
|
||||
np.array(scores),
|
||||
)
|
||||
idx_all_sorted = indices[
|
||||
np.argsort(scores)
|
||||
] # from lowest to highest score
|
||||
idx_normal_sorted = indices[labels == 0][
|
||||
np.argsort(scores[labels == 0])
|
||||
] # from lowest to highest score
|
||||
|
||||
if dataset_name in (
|
||||
"mnist",
|
||||
"fmnist",
|
||||
"cifar10",
|
||||
"elpv",
|
||||
):
|
||||
if dataset_name in (
|
||||
"mnist",
|
||||
"fmnist",
|
||||
"elpv",
|
||||
):
|
||||
X_all_low = dataset.test_set.data[
|
||||
idx_all_sorted[:32], ...
|
||||
].unsqueeze(1)
|
||||
X_all_high = dataset.test_set.data[
|
||||
idx_all_sorted[-32:], ...
|
||||
].unsqueeze(1)
|
||||
X_normal_low = dataset.test_set.data[
|
||||
idx_normal_sorted[:32], ...
|
||||
].unsqueeze(1)
|
||||
X_normal_high = dataset.test_set.data[
|
||||
idx_normal_sorted[-32:], ...
|
||||
].unsqueeze(1)
|
||||
|
||||
if dataset_name == "cifar10":
|
||||
X_all_low = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_all_sorted[:32], ...],
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
)
|
||||
X_all_high = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_all_sorted[-32:], ...],
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
)
|
||||
X_normal_low = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_normal_sorted[:32], ...],
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
)
|
||||
X_normal_high = torch.tensor(
|
||||
np.transpose(
|
||||
dataset.test_set.data[idx_normal_sorted[-32:], ...],
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if fold_idx is None:
|
||||
plot_images_grid(
|
||||
X_all_low, export_img=xp_path + "/all_low", padding=2
|
||||
)
|
||||
plot_images_grid(
|
||||
X_all_high, export_img=xp_path + "/all_high", padding=2
|
||||
)
|
||||
plot_images_grid(
|
||||
X_normal_low, export_img=xp_path + "/normals_low", padding=2
|
||||
)
|
||||
plot_images_grid(
|
||||
X_normal_high,
|
||||
export_img=xp_path + "/normals_high",
|
||||
padding=2,
|
||||
)
|
||||
else:
|
||||
plot_images_grid(
|
||||
X_all_low,
|
||||
export_img=xp_path + f"/all_low_{fold_idx}",
|
||||
padding=2,
|
||||
)
|
||||
plot_images_grid(
|
||||
X_all_high,
|
||||
export_img=xp_path + f"/all_high_{fold_idx}",
|
||||
padding=2,
|
||||
)
|
||||
plot_images_grid(
|
||||
X_normal_low,
|
||||
export_img=xp_path + f"/normals_low_{fold_idx}",
|
||||
padding=2,
|
||||
)
|
||||
plot_images_grid(
|
||||
X_normal_high,
|
||||
export_img=xp_path + f"/normals_high_{fold_idx}",
|
||||
padding=2,
|
||||
)
|
||||
|
||||
elif action == "infer":
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
@@ -488,14 +724,23 @@ def main(
|
||||
deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
|
||||
inference_results = deepSAD.inference(
|
||||
inference_results, all_outputs = deepSAD.inference(
|
||||
dataset, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||
)
|
||||
inference_results_path = (
|
||||
Path(xp_path) / "inference" / Path(dataset.root).with_suffix(".npy").stem
|
||||
Path(xp_path)
|
||||
/ "inference"
|
||||
/ Path(Path(dataset.root).stem).with_suffix(".npy")
|
||||
)
|
||||
inference_outputs_path = (
|
||||
Path(xp_path)
|
||||
/ "inference"
|
||||
/ Path(Path(dataset.root).stem + "_outputs").with_suffix(".npy")
|
||||
)
|
||||
|
||||
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.save(inference_results_path, inference_results, fix_imports=False)
|
||||
np.save(inference_outputs_path, all_outputs, fix_imports=False)
|
||||
|
||||
logger.info(
|
||||
f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
|
||||
|
||||
Reference in New Issue
Block a user