"""Command-line entry point for Deep SAD experiments (mt/Deep-SAD-PyTorch/src/main.py)."""
import logging
import pickle
import random
from pathlib import Path

import click
import numpy as np
import torch

from baselines.isoforest import IsoForest
from baselines.ocsvm import OCSVM
from datasets.main import load_dataset
from DeepSAD import DeepSAD
from utils.config import Config
from utils.visualization.plot_images_grid import plot_images_grid
################################################################################
# Settings
################################################################################
@click.command()
@click.argument(
    "action",
    type=click.Choice(
        [
            "train",
            "infer",
            "ae_elbow_test",
        ]
    ),
)
@click.argument(
    "dataset_name",
    type=click.Choice(
        [
            "mnist",
            "elpv",
            "subter",
            "subtersplit",
            "fmnist",
            "cifar10",
            "arrhythmia",
            "cardio",
            "satellite",
            "satimage-2",
            "shuttle",
            "thyroid",
        ]
    ),
)
@click.argument(
    "net_name",
    type=click.Choice(
        [
            "mnist_LeNet",
            "elpv_LeNet",
            "subter_LeNet",
            "subter_efficient",
            "subter_LeNet_Split",
            "fmnist_LeNet",
            "cifar10_LeNet",
            "arrhythmia_mlp",
            "cardio_mlp",
            "satellite_mlp",
            "satimage-2_mlp",
            "shuttle_mlp",
            "thyroid_mlp",
        ]
    ),
)
@click.argument("xp_path", type=click.Path(exists=True))
@click.argument("data_path", type=click.Path(exists=True))
@click.option(
    "--k_fold",
    type=bool,
    default=False,
    help="Use k-fold cross-validation for training (default: False).",
)
@click.option(
    "--k_fold_num",
    type=int,
    default=None,
    help="Number of folds for k-fold cross-validation (default: None).",
)
@click.option(
    "--num_known_normal",
    type=int,
    default=0,
    help="Number of max known normal samples (semi-supervised-setting) (default: 0).",
)
@click.option(
    "--num_known_outlier",
    type=int,
    default=0,
    help="Number of max known outlier samples (semi-supervised-setting) (default: 0).",
)
@click.option(
    "--load_config",
    type=click.Path(exists=True),
    default=None,
    help="Config JSON-file path (default: None).",
)
@click.option(
    "--load_model",
    type=click.Path(exists=True),
    default=None,
    help="Model file path (default: None).",
)
@click.option(
    "--eta",
    type=float,
    default=1.0,
    help="Deep SAD hyperparameter eta (must be 0 < eta).",
)
@click.option(
    "--ratio_known_normal",
    type=float,
    default=0.0,
    help="Ratio of known (labeled) normal training examples.",
)
@click.option(
    "--ratio_known_outlier",
    type=float,
    default=0.0,
    help="Ratio of known (labeled) anomalous training examples.",
)
@click.option(
    "--ratio_pollution",
    type=float,
    default=0.0,
    help="Pollution ratio of unlabeled training data with unknown (unlabeled) anomalies.",
)
@click.option(
    "--device",
    type=str,
    default="cuda",
    help='Computation device to use ("cpu", "cuda", "cuda:2", etc.).',
)
@click.option(
    "--seed", type=int, default=-1, help="Set seed. If -1, use randomization."
)
@click.option(
    "--optimizer_name",
    type=click.Choice(["adam"]),
    default="adam",
    help="Name of the optimizer to use for Deep SAD network training.",
)
@click.option(
    "--lr",
    type=float,
    default=0.001,
    help="Initial learning rate for Deep SAD network training. Default=0.001",
)
@click.option("--n_epochs", type=int, default=50, help="Number of epochs to train.")
@click.option(
    "--lr_milestone",
    type=int,
    default=[0],
    multiple=True,
    help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
)
@click.option(
    "--batch_size", type=int, default=128, help="Batch size for mini-batch training."
)
@click.option(
    "--weight_decay",
    type=float,
    default=1e-6,
    help="Weight decay (L2 penalty) hyperparameter for Deep SAD objective.",
)
@click.option(
    "--latent_space_dim",
    type=int,
    default=128,
    help="Dimensionality of the latent space for the autoencoder.",
)
@click.option(
    "--pretrain",
    type=bool,
    default=True,
    help="Pretrain neural network parameters via autoencoder.",
)
@click.option(
    "--ae_optimizer_name",
    type=click.Choice(["adam"]),
    default="adam",
    help="Name of the optimizer to use for autoencoder pretraining.",
)
@click.option(
    "--ae_lr",
    type=float,
    default=0.001,
    help="Initial learning rate for autoencoder pretraining. Default=0.001",
)
@click.option(
    "--ae_n_epochs",
    type=int,
    default=100,
    help="Number of epochs to train autoencoder.",
)
@click.option(
    "--ae_lr_milestone",
    type=int,
    default=[0],
    multiple=True,
    help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
)
@click.option(
    "--ae_batch_size",
    type=int,
    default=128,
    help="Batch size for mini-batch autoencoder training.",
)
@click.option(
    "--ae_weight_decay",
    type=float,
    default=1e-6,
    help="Weight decay (L2 penalty) hyperparameter for autoencoder objective.",
)
@click.option(
    "--num_threads",
    type=int,
    default=0,
    help="Number of threads used for parallelizing CPU operations. 0 means that all resources are used.",
)
@click.option(
    "--n_jobs_dataloader",
    type=int,
    default=0,
    help="Number of workers for data loading. 0 means that the data will be loaded in the main process.",
)
@click.option(
    "--normal_class",
    type=int,
    default=0,
    help="Specify the normal class of the dataset (all other classes are considered anomalous).",
)
@click.option(
    "--known_outlier_class",
    type=int,
    default=1,
    help="Specify the known outlier class of the dataset for semi-supervised anomaly detection.",
)
@click.option(
    "--n_known_outlier_classes",
    type=int,
    default=0,
    help="Number of known outlier classes."
    "If 0, no anomalies are known."
    "If 1, outlier class as specified in --known_outlier_class option."
    "If > 1, the specified number of outlier classes will be sampled at random.",
)
@click.option(
    "--ocsvm_kernel",
    type=click.Choice(["rbf", "linear", "poly"]),
    default="rbf",
    help="Kernel for the OC-SVM",
)
@click.option(
    "--ocsvm_nu",
    type=float,
    default=0.1,
    help="OC-SVM hyperparameter nu (must be 0 < nu <= 1).",
)
@click.option(
    "--isoforest_n_estimators",
    type=int,
    default=100,
    help="Set the number of base estimators in the ensemble (default: 100).",
)
@click.option(
    "--isoforest_max_samples",
    type=int,
    default=256,
    help="Set the number of samples drawn to train each base estimator (default: 256).",
)
@click.option(
    "--isoforest_contamination",
    type=float,
    default=0.1,
    help="Expected fraction of anomalies in the training set. (default: 0.1).",
)
@click.option(
    "--isoforest_n_jobs_model",
    type=int,
    default=-1,
    help="Number of jobs for model training.",
)
def main(
    action,
    dataset_name,
    net_name,
    xp_path,
    data_path,
    k_fold,
    k_fold_num,
    num_known_normal,
    num_known_outlier,
    load_config,
    load_model,
    eta,
    ratio_known_normal,
    ratio_known_outlier,
    ratio_pollution,
    device,
    seed,
    optimizer_name,
    lr,
    n_epochs,
    lr_milestone,
    batch_size,
    weight_decay,
    latent_space_dim,
    pretrain,
    ae_optimizer_name,
    ae_lr,
    ae_n_epochs,
    ae_lr_milestone,
    ae_batch_size,
    ae_weight_decay,
    num_threads,
    n_jobs_dataloader,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes,
    ocsvm_kernel,
    ocsvm_nu,
    isoforest_n_estimators,
    isoforest_max_samples,
    isoforest_contamination,
    isoforest_n_jobs_model,
):
    """
    Deep SAD, a method for deep semi-supervised anomaly detection.

    :arg ACTION: One of "train" (train/test Deep SAD plus OC-SVM and
        Isolation Forest baselines, optionally k-fold), "infer" (run a
        loaded model on a dataset and save anomaly scores), or
        "ae_elbow_test" (pretrain autoencoders over a sweep of latent
        dimensions and save their results).
    :arg DATASET_NAME: Name of the dataset to load.
    :arg NET_NAME: Name of the neural network to use.
    :arg XP_PATH: Export path for logging the experiment.
    :arg DATA_PATH: Root path of data.
    """
    # Get configuration. NOTE: must stay the first statement so that
    # locals() contains exactly the CLI arguments.
    cfg = Config(locals().copy())

    # Set up logging to console and to <xp_path>/log.txt
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    log_file = xp_path + "/log.txt"
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Print paths
    logger.info("Log file is %s" % log_file)
    logger.info("Data path is %s" % data_path)
    logger.info("Export path is %s" % xp_path)

    # Print experimental setup
    logger.info("Dataset: %s" % dataset_name)
    logger.info("Normal class: %d" % normal_class)
    logger.info("Ratio of labeled normal train samples: %.2f" % ratio_known_normal)
    logger.info("Ratio of labeled anomalous samples: %.2f" % ratio_known_outlier)
    logger.info("Pollution ratio of unlabeled train data: %.2f" % ratio_pollution)
    if n_known_outlier_classes == 1:
        logger.info("Known anomaly class: %d" % known_outlier_class)
    else:
        logger.info("Number of known anomaly classes: %d" % n_known_outlier_classes)
    logger.info("Network: %s" % net_name)

    # If specified, load experiment config from JSON-file
    if load_config:
        cfg.load_config(import_json=load_config)
        logger.info("Loaded configuration from %s." % load_config)

    # Print model configuration
    logger.info("Eta-parameter: %.2f" % cfg.settings["eta"])

    # Set seed (seed == -1 means no fixed seed / full randomization)
    if cfg.settings["seed"] != -1:
        random.seed(cfg.settings["seed"])
        np.random.seed(cfg.settings["seed"])
        torch.manual_seed(cfg.settings["seed"])
        torch.cuda.manual_seed(cfg.settings["seed"])
        torch.backends.cudnn.deterministic = True
        logger.info("Set seed to %d." % cfg.settings["seed"])

    # Default device to 'cpu' if cuda is not available
    if not torch.cuda.is_available():
        device = "cpu"
    # Set the number of threads used for parallelizing CPU operations
    if num_threads > 0:
        torch.set_num_threads(num_threads)
    logger.info("Computation device: %s" % device)
    logger.info("Number of threads: %d" % num_threads)
    logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)

    if action == "train":
        # Load data
        # TODO: pass num of folds
        dataset = load_dataset(
            dataset_name,
            data_path,
            normal_class,
            known_outlier_class,
            n_known_outlier_classes,
            ratio_known_normal,
            ratio_known_outlier,
            ratio_pollution,
            random_state=np.random.RandomState(cfg.settings["seed"]),
            k_fold_num=k_fold_num,
            num_known_normal=num_known_normal,
            num_known_outlier=num_known_outlier,
        )
        # Log random sample of known anomaly classes if more than 1 class
        if n_known_outlier_classes > 1:
            logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))

        # One pass per fold when k-fold is enabled, otherwise a single pass.
        # NOTE(review): with --k_fold=True but no --k_fold_num, range(None)
        # raises TypeError — callers must pass both.
        train_passes = range(k_fold_num) if k_fold else [None]
        # Toggles to selectively disable one of the three models.
        train_isoforest = True
        train_ocsvm = True
        train_deepsad = True
        for fold_idx in train_passes:
            if fold_idx is None:
                logger.info("Single training without k-fold")
            else:
                logger.info(f"Fold {fold_idx + 1}/{k_fold_num}")

            # Initialize Isolation Forest model
            if train_isoforest:
                Isoforest = IsoForest(
                    hybrid=False,
                    n_estimators=isoforest_n_estimators,
                    max_samples=isoforest_max_samples,
                    contamination=isoforest_contamination,
                    n_jobs=isoforest_n_jobs_model,
                    seed=seed,
                )

            # Initialize DeepSAD model and set neural network phi
            if train_deepsad:
                deepSAD = DeepSAD(latent_space_dim, cfg.settings["eta"])
                deepSAD.set_network(net_name)
            # If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
            if train_deepsad and load_model:
                deepSAD.load_model(
                    model_path=load_model, load_ae=True, map_location=device
                )
                logger.info("Loading model from %s." % load_model)

            logger.info("Pretraining: %s" % pretrain)
            if train_deepsad and pretrain:
                # Log pretraining details
                logger.info(
                    "Pretraining optimizer: %s" % cfg.settings["ae_optimizer_name"]
                )
                logger.info("Pretraining learning rate: %g" % cfg.settings["ae_lr"])
                logger.info("Pretraining epochs: %d" % cfg.settings["ae_n_epochs"])
                logger.info(
                    "Pretraining learning rate scheduler milestones: %s"
                    % (cfg.settings["ae_lr_milestone"],)
                )
                logger.info(
                    "Pretraining batch size: %d" % cfg.settings["ae_batch_size"]
                )
                logger.info(
                    "Pretraining weight decay: %g" % cfg.settings["ae_weight_decay"]
                )
                # Pretrain model on dataset (via autoencoder)
                deepSAD.pretrain(
                    dataset,
                    optimizer_name=cfg.settings["ae_optimizer_name"],
                    lr=cfg.settings["ae_lr"],
                    n_epochs=cfg.settings["ae_n_epochs"],
                    lr_milestones=cfg.settings["ae_lr_milestone"],
                    batch_size=cfg.settings["ae_batch_size"],
                    weight_decay=cfg.settings["ae_weight_decay"],
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )
                # Save pretraining results (fold suffix when k-fold is active)
                if fold_idx is None:
                    deepSAD.save_ae_results(export_pkl=xp_path + "/results_ae.pkl")
                    ae_model_path = xp_path + "/model_ae.tar"
                    deepSAD.save_model(export_model=ae_model_path, save_ae=True)
                else:
                    deepSAD.save_ae_results(
                        export_pkl=xp_path + f"/results_ae_{fold_idx}.pkl"
                    )
                    ae_model_path = xp_path + f"/model_ae_{fold_idx}.tar"
                    deepSAD.save_model(export_model=ae_model_path, save_ae=True)

            # Initialize OC-SVM model (after pretraining to use autoencoder features)
            if train_ocsvm:
                ocsvm = OCSVM(
                    kernel=ocsvm_kernel,
                    nu=ocsvm_nu,
                    hybrid=True,
                    latent_space_dim=latent_space_dim,
                )
                # ae_model_path comes from the pretraining branch above unless a
                # model is loaded instead. NOTE(review): with neither --pretrain
                # nor --load_model, ae_model_path is undefined here.
                if load_model and not pretrain:
                    ae_model_path = load_model
                ocsvm.load_ae(
                    net_name=net_name, model_path=ae_model_path, device=device
                )
                logger.info(
                    f"Loaded pretrained autoencoder for features from {ae_model_path}."
                )

            # Log training details
            logger.info("Training optimizer: %s" % cfg.settings["optimizer_name"])
            logger.info("Training learning rate: %g" % cfg.settings["lr"])
            logger.info("Training epochs: %d" % cfg.settings["n_epochs"])
            logger.info(
                "Training learning rate scheduler milestones: %s"
                % (cfg.settings["lr_milestone"],)
            )
            logger.info("Training batch size: %d" % cfg.settings["batch_size"])
            logger.info("Training weight decay: %g" % cfg.settings["weight_decay"])

            # Train model on dataset
            if train_deepsad:
                deepSAD.train(
                    dataset,
                    optimizer_name=cfg.settings["optimizer_name"],
                    lr=cfg.settings["lr"],
                    n_epochs=cfg.settings["n_epochs"],
                    lr_milestones=cfg.settings["lr_milestone"],
                    batch_size=cfg.settings["batch_size"],
                    weight_decay=cfg.settings["weight_decay"],
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )
            # Train model on dataset
            if train_ocsvm:
                ocsvm.train(
                    dataset,
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                    batch_size=256,
                )
            # Train model on dataset
            if train_isoforest:
                Isoforest.train(
                    dataset,
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )

            # Test model
            if train_deepsad:
                deepSAD.test(
                    dataset,
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )
            # Test model
            if train_ocsvm:
                ocsvm.test(
                    dataset,
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                    batch_size=256,
                )
            # Test model
            if train_isoforest:
                Isoforest.test(
                    dataset,
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )

            # Save results, model, and configuration
            if fold_idx is None:
                if train_deepsad:
                    deepSAD.save_results(export_pkl=xp_path + "/results_deepsad.pkl")
                    deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
                if train_ocsvm:
                    ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
                    ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin")
                if train_isoforest:
                    Isoforest.save_results(
                        export_pkl=xp_path + "/results_isoforest.pkl"
                    )
                    Isoforest.save_model(export_path=xp_path + "/model_isoforest.pkl")
            else:
                if train_deepsad:
                    deepSAD.save_results(
                        export_pkl=xp_path + f"/results_deepsad_{fold_idx}.pkl"
                    )
                    deepSAD.save_model(
                        export_model=xp_path + f"/model_deepsad_{fold_idx}.tar"
                    )
                if train_ocsvm:
                    ocsvm.save_results(
                        export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
                    )
                    ocsvm.save_model(
                        export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin"
                    )
                if train_isoforest:
                    Isoforest.save_results(
                        export_pkl=xp_path + f"/results_isoforest_{fold_idx}.pkl"
                    )
                    Isoforest.save_model(
                        export_path=xp_path + f"/model_isoforest_{fold_idx}.pkl"
                    )

        cfg.save_config(export_json=xp_path + "/config.json")

    elif action == "infer":
        dataset = load_dataset(
            dataset_name,
            data_path,
            normal_class,
            known_outlier_class,
            n_known_outlier_classes,
            ratio_known_normal,
            ratio_known_outlier,
            ratio_pollution,
            random_state=np.random.RandomState(cfg.settings["seed"]),
            inference=True,
        )
        # Log random sample of known anomaly classes if more than 1 class
        if n_known_outlier_classes > 1:
            logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))

        # Initialize DeepSAD model and set neural network phi
        deepSAD = DeepSAD(latent_space_dim, cfg.settings["eta"])
        deepSAD.set_network(net_name)
        # If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
        if not load_model:
            logger.error(
                "For inference mode a model has to be loaded! Pass the --load_model option with the model path!"
            )
            return
        deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
        logger.info("Loading model from %s." % load_model)

        inference_results, all_outputs = deepSAD.inference(
            dataset, device=device, n_jobs_dataloader=n_jobs_dataloader
        )
        # Scores and raw outputs are saved under <xp_path>/inference/, named
        # after the dataset root.
        inference_results_path = (
            Path(xp_path)
            / "inference"
            / Path(Path(dataset.root).stem).with_suffix(".npy")
        )
        inference_outputs_path = (
            Path(xp_path)
            / "inference"
            / Path(Path(dataset.root).stem + "_outputs").with_suffix(".npy")
        )
        inference_results_path.parent.mkdir(parents=True, exist_ok=True)
        np.save(inference_results_path, inference_results, fix_imports=False)
        np.save(inference_outputs_path, all_outputs, fix_imports=False)
        logger.info(
            f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
        )

    elif action == "ae_elbow_test":
        # Load data once
        dataset = load_dataset(
            dataset_name,
            data_path,
            normal_class,
            known_outlier_class,
            n_known_outlier_classes,
            ratio_known_normal,
            ratio_known_outlier,
            ratio_pollution,
            random_state=np.random.RandomState(cfg.settings["seed"]),
            k_fold_num=k_fold_num,
        )
        # Set up k-fold passes
        train_passes = range(k_fold_num) if k_fold else [None]
        # Test dimensions
        ae_elbow_dims = [32, 64, 128, 256, 384, 512, 768, 1024]
        # Test each dimension
        for rep_dim in ae_elbow_dims:
            logger.info(f"Testing autoencoder with latent dimension: {rep_dim}")
            # Results dictionary for this dimension
            dim_results = {
                "dimension": rep_dim,
                "ae_results": {},
                "k_fold": k_fold,
                "k_fold_num": k_fold_num,
            }
            # For each fold
            for fold_idx in train_passes:
                if fold_idx is None:
                    logger.info(f"Dimension {rep_dim}: Single training without k-fold")
                else:
                    logger.info(
                        f"Dimension {rep_dim}: Fold {fold_idx + 1}/{k_fold_num}"
                    )
                # Initialize DeepSAD model with current dimension
                deepSAD = DeepSAD(rep_dim, cfg.settings["eta"])
                deepSAD.set_network(net_name)

                # Pretrain autoencoder with current dimension
                deepSAD.pretrain(
                    dataset,
                    optimizer_name=cfg.settings["ae_optimizer_name"],
                    lr=cfg.settings["ae_lr"],
                    n_epochs=cfg.settings["ae_n_epochs"],
                    lr_milestones=cfg.settings["ae_lr_milestone"],
                    batch_size=cfg.settings["ae_batch_size"],
                    weight_decay=cfg.settings["ae_weight_decay"],
                    device=device,
                    n_jobs_dataloader=n_jobs_dataloader,
                    k_fold_idx=fold_idx,
                )
                # Store results for this fold
                fold_key = "single" if fold_idx is None else f"fold_{fold_idx}"
                dim_results["ae_results"][fold_key] = deepSAD.ae_results

                logger.info(
                    f"Finished testing dimension {rep_dim} "
                    + (
                        f"fold {fold_idx + 1}/{k_fold_num}"
                        if fold_idx is not None
                        else "single pass"
                    )
                )

                # Clear some memory
                del deepSAD
                torch.cuda.empty_cache()

            # Save results for this dimension (includes all folds)
            results_filename = (
                f"ae_elbow_results_{net_name}_dim_{rep_dim}"
                + ("_kfold" if k_fold else "")
                + ".pkl"
            )
            results_path = Path(xp_path) / results_filename

            with open(results_path, "wb") as f:
                pickle.dump(dim_results, f)

            logger.info(
                f"Saved elbow test results for dimension {rep_dim} to {results_path}"
            )
    else:
        logger.error(f"Unknown action: {action}")
if __name__ == "__main__":
    # Click parses sys.argv and dispatches to main().
    main()