Files
mt/tools/plot_scripts/ae_elbow_lenet.py

270 lines
8.5 KiB
Python
Raw Normal View History

# ae_elbow_from_df.py
from __future__ import annotations
import json
2025-08-13 14:17:12 +02:00
import shutil
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
2025-09-10 19:41:00 +02:00
from plot_scripts.load_results import load_pretraining_results_dataframe
# ----------------------------
# Config
# ----------------------------
# Experiments root; main() hands this to load_pretraining_results_dataframe().
ROOT = Path("/home/fedex/mt/results/done") # experiments root you pass to the loader
# Base output directory; main() writes into "archive/<timestamp>" and "latest" below it.
OUTPUT_DIR = Path("/home/fedex/mt/plots/ae_elbow_lenet_from_df")
# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
# (passed to build_arch_curves_from_df() as `label_field`).
LABEL_FIELD = "labels_exp_based"
# ----------------------------
# Helpers
# ----------------------------
def canonicalize_network(name: str) -> str:
    """Translate a raw network name into the short label used in the plots.

    Matching is a case-insensitive substring test and "lenet" is checked
    before "efficient". A name matching neither is returned unchanged; an
    empty or missing name becomes "unknown".
    """
    lowered = name.lower() if name else ""
    # Order matters: the first needle found decides the label.
    for needle, label in (("lenet", "LeNet"), ("efficient", "Efficient")):
        if needle in lowered:
            return label
    # No known architecture recognised: show whatever was stored.
    return name if name else "unknown"
def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Return the mean of the per-batch mean scores.

    `scores` is cut into consecutive chunks of `batch_size` elements (the
    last chunk may be shorter); each chunk is averaged and the chunk averages
    are averaged again with equal weight. An empty `scores` yields NaN, and a
    non-positive `batch_size` treats all scores as one batch.
    """
    total = len(scores)
    if not total:
        return np.nan
    step = batch_size if batch_size > 0 else total
    batch_means = [
        float(np.mean(scores[start : start + step]))
        for start in range(0, total, step)
    ]
    # Plain left-to-right accumulation (not sum()) keeps the floating-point
    # result independent of the interpreter's summation algorithm.
    running = 0.0
    for value in batch_means:
        running += value
    return running / len(batch_means)
def extract_batch_size(cfg_json: str | None) -> int:
    """Read the autoencoder batch size from a JSON-encoded config string.

    Lookup order: "ae_batch_size", then "batch_size", then the default 256.
    A missing or falsy value (e.g. 0 or null) falls through to the next
    candidate. Only the config JSON is consulted (no lifted fields).

    Args:
        cfg_json: JSON text of the run configuration; may be None or empty.

    Returns:
        The batch size as an int; 256 when the config is absent, is not
        valid JSON, or does not decode to a JSON object.
    """
    try:
        cfg = json.loads(cfg_json) if cfg_json else {}
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-text input.
        cfg = {}
    if not isinstance(cfg, dict):
        # Valid JSON that is not an object ("null", "[...]", "3") has no keys
        # to look up; treat it like an empty config instead of failing on .get.
        cfg = {}
    return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
def _mean_std_per_dim(
    per_dim: dict[int, dict[str, list[float]]],
    dims_sorted: list[int],
    kind: str,
) -> tuple[list[int], list[float], list[float]]:
    """Return (dims, means, stds) of the non-NaN `kind` losses for each dim."""
    means: list[float] = []
    stds: list[float] = []
    for dim in dims_sorted:
        valid = [
            x for x in per_dim[dim][kind] if x is not None and not np.isnan(x)
        ]
        if valid:
            means.append(float(np.mean(valid)))
            # np.std default (ddof=0), i.e. the population standard deviation.
            stds.append(float(np.std(valid)))
        else:
            # No usable sample for this dim: keep the slot so dims/means/stds
            # stay aligned.
            means.append(np.nan)
            stds.append(np.nan)
    return dims_sorted, means, stds


def build_arch_curves_from_df(
    df: pl.DataFrame,
    label_field: str = "labels_exp_based",
    only_nets: set[str] | None = None,
):
    """
    From the AE pretraining DF, compute (dims, means, stds) for
    normal/anomaly/overall grouped by network and latent_dim.

    Only rows whose "split" equals "test" are used. Each row contributes one
    batch-mean loss per kind (see calculate_batch_mean_loss); rows sharing a
    (network label, latent_dim) are then reduced to mean and std. In the
    label column, 1 marks normal samples and -1 marks anomalies. When a row
    has no labels, or their count differs from the number of scores, only its
    overall loss is used.

    Args:
        df: dataframe with "split", "scores", "network", "latent_dim" and
            `label_field` columns; "config_json" is read when present.
        label_field: name of the label column to split scores by.
        only_nets: canonical network labels to keep; None or an empty set
            keeps every network.

    Returns:
        { net_label: {
            "normal": (dims, means, stds),
            "anomaly": (dims, means, stds),
            "overall": (dims, means, stds),
        } }
        with dims sorted ascending and NaN where no valid loss exists.

    Raises:
        ValueError: if a required column is missing.
    """
    if "split" not in df.columns:
        raise ValueError("Expected 'split' column in AE dataframe.")
    if "scores" not in df.columns:
        raise ValueError("Expected 'scores' column in AE dataframe.")
    if "network" not in df.columns or "latent_dim" not in df.columns:
        raise ValueError("Expected 'network' and 'latent_dim' columns in AE dataframe.")
    if label_field not in df.columns:
        raise ValueError(f"Expected '{label_field}' column in AE dataframe.")

    # Keep only test split
    df = df.filter(pl.col("split") == "test")

    # net_label -> latent_dim -> kind -> one loss per contributing row.
    # Nesting by network first lets the aggregation below look values up
    # directly instead of rescanning every group.
    groups: dict[str, dict[int, dict[str, list[float]]]] = {}
    for row in df.iter_rows(named=True):
        net_label = canonicalize_network(row["network"])
        if only_nets and net_label not in only_nets:
            continue
        dim = int(row["latent_dim"])
        batch_size = extract_batch_size(row.get("config_json"))
        scores = np.asarray(row["scores"] or [], dtype=float)
        labels = row.get(label_field)
        labels = np.asarray(labels, dtype=int) if labels is not None else None

        overall_loss = calculate_batch_mean_loss(scores, batch_size)

        # Split by labels if available; otherwise we only aggregate overall
        normal_loss = np.nan
        anomaly_loss = np.nan
        if labels is not None and labels.size == scores.size:
            normal_scores = scores[labels == 1]
            anomaly_scores = scores[labels == -1]
            if normal_scores.size > 0:
                normal_loss = calculate_batch_mean_loss(normal_scores, batch_size)
            if anomaly_scores.size > 0:
                anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size)

        bucket = groups.setdefault(net_label, {}).setdefault(
            dim, {"normal": [], "anomaly": [], "overall": []}
        )
        bucket["overall"].append(overall_loss)
        bucket["normal"].append(normal_loss)
        bucket["anomaly"].append(anomaly_loss)

    # Aggregate the rows (e.g. folds) of each (net, dim) -> mean/std
    result: dict[str, dict[str, tuple[list[int], list[float], list[float]]]] = {}
    for net, per_dim in groups.items():
        dims_sorted = sorted(per_dim)
        result[net] = {
            kind: _mean_std_per_dim(per_dim, dims_sorted, kind)
            for kind in ("normal", "anomaly", "overall")
        }
    return result
2025-08-13 14:17:12 +02:00
def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
    """
    Plot mean test loss against latent dimensionality, one curve per
    architecture with a shaded +/- one std band, and save it to disk.

    Args:
        arch_results: {arch_name: (dims, means, stds)}
        title: figure title.
        output_path: destination passed to savefig.
        colors: optional {arch_name: matplotlib color}. Architectures missing
            from the mapping get the next color of the default cycle. When
            None, LeNet is "tab:blue" and Efficient is "tab:orange".
    """
    # default color map if not provided
    if colors is None:
        colors = {
            "LeNet": "tab:blue",
            "Efficient": "tab:orange",
        }

    # Get unique dimensions across all architectures (used as x ticks)
    all_dims = sorted(
        {dim for dims, _, _ in arch_results.values() for dim in dims}
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for arch_name, (dims, means, stds) in arch_results.items():
            center = np.asarray(means, dtype=float)
            spread = np.asarray(stds, dtype=float)
            # color=None lets matplotlib pick the next color of its cycle.
            (line,) = ax.plot(
                dims,
                center,
                marker="o",
                color=colors.get(arch_name),
                label=arch_name,
            )
            # Reuse the color actually drawn so the band always matches its line.
            ax.fill_between(
                dims,
                center - spread,
                center + spread,
                color=line.get_color(),
                alpha=0.2,
            )

        ax.set_xlabel("Latent Dimensionality")
        ax.set_ylabel("Test Loss")
        ax.set_title(title)
        if arch_results:
            # An empty legend only triggers a matplotlib warning.
            ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(all_dims)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)
def main():
    """Load the AE pretraining results and write the three elbow plots.

    Plots go to a fresh timestamped folder under OUTPUT_DIR / "archive",
    together with a copy of this script, and every file of that folder is then
    copied into OUTPUT_DIR / "latest".
    """
    # Load AE DF (allow_cache=True is forwarded to the loader)
    frame = load_pretraining_results_dataframe(ROOT, allow_cache=True)

    # Restricted to LeNet vs Efficient; pass an empty set to plot all nets
    curves = build_arch_curves_from_df(
        frame,
        label_field=LABEL_FIELD,
        only_nets={"LeNet", "Efficient"},
    )

    # Prepare output dirs
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir = OUTPUT_DIR / "archive" / stamp
    ts_dir.mkdir(parents=True, exist_ok=True)

    # (curve kind, figure title, output file name), drawn in this order
    plot_specs = (
        (
            "normal",
            "Normal Class Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_normal.png",
        ),
        (
            "anomaly",
            "Anomaly Class Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_anomaly.png",
        ),
        (
            "overall",
            "Overall Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_overall.png",
        ),
    )
    for kind, title, filename in plot_specs:
        per_arch = {name: payload[kind] for name, payload in curves.items()}
        plot_multi_loss_curve(per_arch, title, ts_dir / filename)

    # Copy this script to preserve the code used for the outputs
    shutil.copy2(Path(__file__), ts_dir)

    # Mirror the run's files into "latest"
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for entry in ts_dir.iterdir():
        if entry.is_file():
            shutil.copy2(entry, latest / entry.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")
2025-08-13 14:17:12 +02:00
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()