# Loads autoencoder training results from pickle files, evaluates them, and
# visualizes the per-class test losses to find the training elbow.
import pickle
import unittest
from pathlib import Path
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
results_folder = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/test/DeepSAD/subter_ae_elbow/"
)
# Find all result files matching the pattern
result_files = sorted(
    results_folder.glob("ae_elbow_results_subter_LeNet_dim_*_kfold.pkl")
)
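
# Each result pickle is expected to hold a dict of roughly this shape
# (field meanings inferred from how they are accessed below):
#   {
#       "dimension": <latent dimension>,
#       "ae_results": {
#           <fold_key>: {
#               "test": {
#                   "scores": [...],            # per-sample reconstruction losses
#                   "labels_exp_based": [...],  # 1 = normal, -1 = anomaly
#                   "loss": <float>,            # batch-averaged test loss
#               }
#           },
#           ...
#       },
#   }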
# Initialize data structures for both classes
dimensions = []
normal_means = []
normal_stds = []
anomaly_means = []
anomaly_stds = []
BATCH_SIZE = 256  # batch size assumed to match the original testing code


def calculate_batch_mean_loss(scores, batch_size=BATCH_SIZE):
    """Calculate mean loss over batches similar to the original testing code."""
    n_samples = len(scores)
    n_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division

    # Split scores into batches
    batch_losses = []
    for i in range(0, n_samples, batch_size):
        batch_scores = scores[i : i + batch_size]
        batch_losses.append(np.mean(batch_scores))

    return np.sum(batch_losses) / n_batches
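
# Example: for scores [1, 2, 3, 4, 5] with batch_size=2 the batch means are
# [1.5, 3.5, 5.0] and n_batches == 3, giving 10.0 / 3 ≈ 3.3333. Note that the
# final partial batch is weighted the same as full batches, mirroring a
# testing loop that averages per-batch losses over an epoch.
assert abs(calculate_batch_mean_loss(np.array([1, 2, 3, 4, 5]), batch_size=2) - 10 / 3) < 1e-9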


def test_loss_calculation(results: Dict, batch_size: int = BATCH_SIZE) -> None:
    """Test if our loss calculation matches the original implementation."""
    test = unittest.TestCase()
    folds = results["ae_results"]
    dim = results["dimension"]

    for fold_key in folds:
        fold_data = folds[fold_key]["test"]
        scores = np.array(fold_data["scores"])
        original_loss = fold_data["loss"]
        calculated_loss = calculate_batch_mean_loss(scores, batch_size)

        try:
            test.assertAlmostEqual(
                original_loss,
                calculated_loss,
                places=5,
                msg=f"Loss mismatch for dim={dim}, {fold_key}",
            )
        except AssertionError as e:
            print(f"Warning: {e}")
            print(f"Original: {original_loss:.6f}, Calculated: {calculated_loss:.6f}")
            raise


# Load and verify data
print("Verifying loss calculation implementation...")
for result_file in result_files:
    with open(result_file, "rb") as f:
        results = pickle.load(f)
    test_loss_calculation(results)
print("Loss calculation verified successfully!")
# Continue with actual data processing
for result_file in result_files:
    with open(result_file, "rb") as f:
        results = pickle.load(f)
    dim = int(results["dimension"])
    folds = results["ae_results"]

    normal_fold_losses = []
    anomaly_fold_losses = []

    for fold_key in folds:
        fold_data = folds[fold_key]["test"]
        scores = np.array(fold_data["scores"])
        labels = np.array(fold_data["labels_exp_based"])

        # Split scores into normal and anomaly samples
        normal_scores = scores[labels == 1]
        anomaly_scores = scores[labels == -1]

        # Calculate per-class losses using batch means
        normal_fold_losses.append(calculate_batch_mean_loss(normal_scores))
        anomaly_fold_losses.append(calculate_batch_mean_loss(anomaly_scores))

    dimensions.append(dim)
    normal_means.append(np.mean(normal_fold_losses))
    normal_stds.append(np.std(normal_fold_losses))
    anomaly_means.append(np.mean(anomaly_fold_losses))
    anomaly_stds.append(np.std(anomaly_fold_losses))
# Sort by dimension
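# (zip(*sorted(zip(...))) sorts all five sequences jointly by the first key,
# the latent dimension, and unpacks them back into parallel tuples)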
dims, n_means, n_stds, a_means, a_stds = zip(
    *sorted(zip(dimensions, normal_means, normal_stds, anomaly_means, anomaly_stds))
)
# Calculate overall means and stds
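# (these overall curves are simple unweighted averages of the two class-wise
# curves, not statistics pooled over all samples, so the averaged std is only
# a rough spread indicator)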
means = [(n + a) / 2 for n, a in zip(n_means, a_means)]
stds = [(ns + as_) / 2 for ns, as_ in zip(n_stds, a_stds)]


def plot_loss_curve(dims, means, stds, title, color, output_path):
    """Create and save a single loss curve plot.

    Args:
        dims: List of latent dimensions
        means: List of mean losses
        stds: List of standard deviations
        title: Plot title
        color: Color for plot and fill
        output_path: Where to save the plot
    """
    plt.figure(figsize=(8, 5))
    plt.plot(dims, means, marker="o", color=color, label="Mean Test Loss")
    plt.fill_between(
        dims,
        np.array(means) - np.array(stds),
        np.array(means) + np.array(stds),
        color=color,
        alpha=0.2,
        label="Std Dev",
    )
    plt.xlabel("Latent Dimension")
    plt.ylabel("Test Loss")
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xticks(dims)  # Set x-ticks exactly at all data points
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()


# Create the three plots
plot_loss_curve(
    dims,
    means,
    stds,
    "Overall Test Loss vs. Latent Dimension",
    "blue",
    results_folder / "ae_elbow_test_loss_overall.png",
)
plot_loss_curve(
    dims,
    n_means,
    n_stds,
    "Normal Class Test Loss vs. Latent Dimension",
    "green",
    results_folder / "ae_elbow_test_loss_normal.png",
)
plot_loss_curve(
    dims,
    a_means,
    a_stds,
    "Anomaly Class Test Loss vs. Latent Dimension",
    "red",
    results_folder / "ae_elbow_test_loss_anomaly.png",
)