save ocsvm and isoforest models during training

This commit is contained in:
Jan Kowalczyk
2025-06-20 09:16:08 +02:00
parent bbd093da0c
commit c552173cb2
3 changed files with 27 additions and 132 deletions

View File

@@ -490,14 +490,14 @@ def main(
# Save pretraining results
if fold_idx is None:
deepSAD.save_ae_results(export_pkl=xp_path + "/ae_results.pkl")
ae_model_path = xp_path + "/ae_model.tar"
deepSAD.save_ae_results(export_pkl=xp_path + "/results_ae.pkl")
ae_model_path = xp_path + "/model_ae.tar"
deepSAD.save_model(export_model=ae_model_path, save_ae=True)
else:
deepSAD.save_ae_results(
export_pkl=xp_path + f"/ae_results_{fold_idx}.pkl"
export_pkl=xp_path + f"/results_ae_{fold_idx}.pkl"
)
ae_model_path = xp_path + f"/ae_model_{fold_idx}.tar"
ae_model_path = xp_path + f"/model_ae_{fold_idx}.tar"
deepSAD.save_model(export_model=ae_model_path, save_ae=True)
# Initialize OC-SVM model (after pretraining to use autoencoder features)
@@ -593,150 +593,41 @@ def main(
# Save results, model, and configuration
if fold_idx is None:
if train_deepsad:
deepSAD.save_results(export_pkl=xp_path + "/results.pkl")
deepSAD.save_model(export_model=xp_path + "/model.tar")
deepSAD.save_results(export_pkl=xp_path + "/results_deepsad.pkl")
deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
if train_ocsvm:
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin")
if train_isoforest:
Isoforest.save_results(
export_pkl=xp_path + "/results_isoforest.pkl"
)
Isoforest.save_model(export_path=xp_path + "/model_isoforest.pkl")
else:
if train_deepsad:
deepSAD.save_results(
export_pkl=xp_path + f"/results_{fold_idx}.pkl"
export_pkl=xp_path + f"/results_deepsad_{fold_idx}.pkl"
)
deepSAD.save_model(
export_model=xp_path + f"/model_deepsad_{fold_idx}.tar"
)
deepSAD.save_model(export_model=xp_path + f"/model_{fold_idx}.tar")
if train_ocsvm:
ocsvm.save_results(
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
)
ocsvm.save_model(
export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin"
)
if train_isoforest:
Isoforest.save_results(
export_pkl=xp_path + f"/results_isoforest_{fold_idx}.pkl"
)
Isoforest.save_model(
export_path=xp_path + f"/model_isoforest_{fold_idx}.pkl"
)
cfg.save_config(export_json=xp_path + "/config.json")
# Plot most anomalous and most normal test samples
if train_deepsad:
# Use experiment-based scores for plotting
indices, labels, scores = zip(
*deepSAD.results["test"]["exp_based"]["scores"]
)
indices, labels, scores = (
np.array(indices),
np.array(labels),
np.array(scores),
)
# Filter out samples with unknown labels (0)
valid_mask = labels != 0
indices = indices[valid_mask]
labels = labels[valid_mask]
scores = scores[valid_mask]
# Convert labels from -1/1 to 0/1 for plotting
labels = (labels == -1).astype(int) # -1 (anomaly) → 1, 1 (normal) → 0
idx_all_sorted = indices[
np.argsort(scores)
] # from lowest to highest score
idx_normal_sorted = indices[labels == 0][
np.argsort(scores[labels == 0])
]
# Optionally plot manual-based results:
# indices_m, labels_m, scores_m = zip(*deepSAD.results["test"]["manual_based"]["scores"])
# ...same processing as above...
if dataset_name in (
"mnist",
"fmnist",
"cifar10",
"elpv",
):
if dataset_name in (
"mnist",
"fmnist",
"elpv",
):
X_all_low = dataset.test_set.data[
idx_all_sorted[:32], ...
].unsqueeze(1)
X_all_high = dataset.test_set.data[
idx_all_sorted[-32:], ...
].unsqueeze(1)
X_normal_low = dataset.test_set.data[
idx_normal_sorted[:32], ...
].unsqueeze(1)
X_normal_high = dataset.test_set.data[
idx_normal_sorted[-32:], ...
].unsqueeze(1)
if dataset_name == "cifar10":
X_all_low = torch.tensor(
np.transpose(
dataset.test_set.data[idx_all_sorted[:32], ...],
(0, 3, 1, 2),
)
)
X_all_high = torch.tensor(
np.transpose(
dataset.test_set.data[idx_all_sorted[-32:], ...],
(0, 3, 1, 2),
)
)
X_normal_low = torch.tensor(
np.transpose(
dataset.test_set.data[idx_normal_sorted[:32], ...],
(0, 3, 1, 2),
)
)
X_normal_high = torch.tensor(
np.transpose(
dataset.test_set.data[idx_normal_sorted[-32:], ...],
(0, 3, 1, 2),
)
)
if fold_idx is None:
plot_images_grid(
X_all_low, export_img=xp_path + "/all_low", padding=2
)
plot_images_grid(
X_all_high, export_img=xp_path + "/all_high", padding=2
)
plot_images_grid(
X_normal_low, export_img=xp_path + "/normals_low", padding=2
)
plot_images_grid(
X_normal_high,
export_img=xp_path + "/normals_high",
padding=2,
)
else:
plot_images_grid(
X_all_low,
export_img=xp_path + f"/all_low_{fold_idx}",
padding=2,
)
plot_images_grid(
X_all_high,
export_img=xp_path + f"/all_high_{fold_idx}",
padding=2,
)
plot_images_grid(
X_normal_low,
export_img=xp_path + f"/normals_low_{fold_idx}",
padding=2,
)
plot_images_grid(
X_normal_high,
export_img=xp_path + f"/normals_high_{fold_idx}",
padding=2,
)
elif action == "infer":
dataset = load_dataset(
dataset_name,