# Python source (270 lines, 8.5 KiB)
# ae_elbow_from_df.py
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import polars as pl
|
|
|
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
|
from load_results import load_pretraining_results_dataframe
|
|
|
|
# ----------------------------
# Config
# ----------------------------
# Directory handed to load_pretraining_results_dataframe() in main().
ROOT = Path("/home/fedex/mt/results/done")  # experiments root you pass to the loader
# Base directory for the generated plots; main() writes into
# "archive/<timestamp>" and "latest" underneath it.
OUTPUT_DIR = Path("/home/fedex/mt/plots/ae_elbow_lenet_from_df")

# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
# (passed as `label_field` to build_arch_curves_from_df() in main()).
LABEL_FIELD = "labels_exp_based"
|
|
|
|
|
|
# ----------------------------
|
|
# Helpers
|
|
# ----------------------------
|
|
def canonicalize_network(name: str) -> str:
    """Translate a stored network name into the short label used for plotting.

    Matching is a case-insensitive substring test, checked in this order:
    names containing "lenet" become "LeNet", names containing "efficient"
    become "Efficient". Any other name is returned unchanged, and a falsy
    name (``None`` or ``""``) becomes "unknown".
    """
    lowered = name.lower() if name else ""
    for needle, label in (("lenet", "LeNet"), ("efficient", "Efficient")):
        if needle in lowered:
            return label
    # No known architecture matched: show whatever was stored.
    return name if name else "unknown"
|
|
|
|
|
|
def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Return the mean of per-batch means of *scores*.

    The scores are cut into consecutive chunks of ``batch_size`` elements,
    each chunk is averaged, and those chunk averages are averaged again.
    Every chunk carries equal weight, so a shorter final chunk counts as much
    as a full one (this mirrors how the original test loss was computed).

    Returns ``nan`` for empty input. A non-positive ``batch_size`` treats the
    whole input as a single batch.
    """
    total = len(scores)
    if total == 0:
        return np.nan

    # Non-positive batch size: fall back to one batch spanning everything.
    step = batch_size if batch_size > 0 else total

    running = 0.0
    batches = 0
    start = 0
    while start < total:
        running += float(np.mean(scores[start : start + step]))
        batches += 1
        start += step
    return running / batches
|
|
|
|
|
|
def extract_batch_size(cfg_json: str) -> int:
|
|
"""
|
|
Prefer AE batch size; fall back to general batch_size; then a safe default.
|
|
We only rely on config_json (no lifted fields).
|
|
"""
|
|
try:
|
|
cfg = json.loads(cfg_json) if cfg_json else {}
|
|
except Exception:
|
|
cfg = {}
|
|
return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
|
|
|
|
|
|
def _summarize_losses(values):
    """Return ``(mean, std)`` over the non-NaN entries of *values*, or ``(nan, nan)`` if none remain."""
    usable = [v for v in values if v is not None and not np.isnan(v)]
    if not usable:
        return np.nan, np.nan
    # np.std with its default arguments, as before.
    return float(np.mean(usable)), float(np.std(usable))


def build_arch_curves_from_df(
    df: pl.DataFrame,
    label_field: str = "labels_exp_based",
    only_nets: set[str] | None = None,
) -> dict[str, dict[str, tuple[list[int], list[float], list[float]]]]:
    """Compute per-network loss curves over latent dimension from the AE pretraining DF.

    Only rows whose ``split`` equals ``"test"`` are used. For every remaining
    row a batch-mean loss (see ``calculate_batch_mean_loss``) is computed over
    all scores ("overall"), over the scores whose label equals ``1``
    ("normal") and over the scores whose label equals ``-1`` ("anomaly").
    The per-class losses are left as NaN when the row has no labels, when the
    label count differs from the score count, or when the class is empty.
    Rows sharing the same (network label, latent_dim) are then aggregated
    into a mean and standard deviation, ignoring NaN entries.

    Args:
        df: Dataframe with at least the columns ``split``, ``scores``,
            ``network``, ``latent_dim`` and ``label_field``. ``config_json``
            is read when present to determine the batch size.
        label_field: Name of the column holding the per-sample labels.
        only_nets: If given and non-empty, only networks whose canonical
            label (see ``canonicalize_network``) is in this set are kept.

    Returns:
        ``{net_label: {"normal": (dims, means, stds),
                       "anomaly": (dims, means, stds),
                       "overall": (dims, means, stds)}}``
        where ``dims`` is the sorted list of latent dimensions seen for that
        network and ``means`` / ``stds`` are aligned with it (NaN where no
        usable value exists).

    Raises:
        ValueError: If one of the required columns is missing.
    """
    if "split" not in df.columns:
        raise ValueError("Expected 'split' column in AE dataframe.")
    if "scores" not in df.columns:
        raise ValueError("Expected 'scores' column in AE dataframe.")
    if "network" not in df.columns or "latent_dim" not in df.columns:
        raise ValueError("Expected 'network' and 'latent_dim' columns in AE dataframe.")
    if label_field not in df.columns:
        raise ValueError(f"Expected '{label_field}' column in AE dataframe.")

    kinds = ("normal", "anomaly", "overall")

    # Keep only test split
    test_df = df.filter(pl.col("split") == "test")

    # (net_label, latent_dim) -> per-kind list with one loss value per row.
    groups: dict[tuple[str, int], dict[str, list[float]]] = {}

    for row in test_df.iter_rows(named=True):
        net_label = canonicalize_network(row["network"])
        if only_nets and net_label not in only_nets:
            continue

        dim = int(row["latent_dim"])
        batch_size = extract_batch_size(row.get("config_json"))
        scores = np.asarray(row["scores"] or [], dtype=float)

        raw_labels = row.get(label_field)
        labels = np.asarray(raw_labels, dtype=int) if raw_labels is not None else None

        losses = {
            "normal": np.nan,
            "anomaly": np.nan,
            "overall": calculate_batch_mean_loss(scores, batch_size),
        }

        # Split by labels only when they line up one-to-one with the scores;
        # otherwise just the overall loss is recorded for this row.
        if labels is not None and labels.size == scores.size:
            for kind, label_value in (("normal", 1), ("anomaly", -1)):
                subset = scores[labels == label_value]
                if subset.size > 0:
                    losses[kind] = calculate_batch_mean_loss(subset, batch_size)

        bucket = groups.setdefault((net_label, dim), {kind: [] for kind in kinds})
        for kind in kinds:
            bucket[kind].append(losses[kind])

    # Aggregate across rows -> per (net, dim) mean/std.
    dims_by_net: dict[str, list[int]] = {}
    for net, dim in groups:
        # Keys of `groups` are unique, so no dimension is listed twice per net.
        dims_by_net.setdefault(net, []).append(dim)

    result: dict[str, dict[str, tuple[list[int], list[float], list[float]]]] = {}
    for net, dims in dims_by_net.items():
        dims_sorted = sorted(dims)
        curves = {}
        for kind in kinds:
            # Direct lookup by key instead of rescanning every group per dimension.
            stats = [_summarize_losses(groups[(net, d)][kind]) for d in dims_sorted]
            means = [mean for mean, _ in stats]
            stds = [std for _, std in stats]
            curves[kind] = (dims_sorted, means, stds)
        result[net] = curves

    return result
|
|
|
|
|
|
def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
    """Plot one mean +/- std loss curve per architecture and save it to disk.

    Each architecture is drawn as a line with circle markers through
    ``means`` over ``dims``, plus a translucent band spanning
    ``means - stds`` to ``means + stds`` in the same color as the line.

    Args:
        arch_results: ``{arch_name: (dims, means, stds)}`` with three
            equally long sequences per architecture.
        title: Figure title.
        output_path: Destination passed to ``savefig`` (saved at 150 dpi).
        colors: Optional ``{arch_name: color}`` mapping. Defaults to blue for
            "LeNet" and orange for "Efficient". Architectures missing from
            the mapping take the next color of matplotlib's cycle.
    """
    # default color map if not provided
    if colors is None:
        colors = {
            "LeNet": "tab:blue",
            "Efficient": "tab:orange",
        }

    # Union of the dimensions of all architectures; used as x tick positions.
    all_dims = sorted({dim for dims, _, _ in arch_results.values() for dim in dims})

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for arch_name, (dims, means, stds) in arch_results.items():
            center = np.asarray(means, dtype=float)
            spread = np.asarray(stds, dtype=float)

            # Only pass a color when one is configured so that the line
            # otherwise picks the next color from the default cycle.
            line_style = {}
            configured = colors.get(arch_name)
            if configured is not None:
                line_style["color"] = configured

            (line,) = ax.plot(dims, center, marker="o", label=arch_name, **line_style)
            # Reuse the line's actual color so band and line always match.
            ax.fill_between(
                dims,
                center - spread,
                center + spread,
                color=line.get_color(),
                alpha=0.2,
            )

        ax.set_xlabel("Latent Dimensionality")
        ax.set_ylabel("Test Loss")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(all_dims)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def main():
    """Load the AE results, build the elbow curves and write the three plots.

    Plots for the normal, anomaly and overall test loss are written into a
    fresh ``archive/<timestamp>`` directory below ``OUTPUT_DIR`` together with
    a copy of this script. Every file of that directory is then copied into
    ``OUTPUT_DIR / "latest"``; files already present there are overwritten by
    name but never removed.
    """
    # Load AE DF (uses your cache if enabled in the loader)
    frame = load_pretraining_results_dataframe(ROOT, allow_cache=True, include_train=False)

    # Optional: filter to just LeNet vs Efficient; drop this set() to plot all nets
    wanted_nets = {"LeNet", "Efficient"}
    curves = build_arch_curves_from_df(
        frame,
        label_field=LABEL_FIELD,
        only_nets=wanted_nets,
    )

    # Prepare output dirs
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir = OUTPUT_DIR / "archive" / stamp
    ts_dir.mkdir(parents=True, exist_ok=True)

    # (curve kind, plot title, output file name), drawn in this order.
    plot_specs = (
        (
            "normal",
            "Normal Class Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_normal.png",
        ),
        (
            "anomaly",
            "Anomaly Class Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_anomaly.png",
        ),
        (
            "overall",
            "Overall Test Loss vs. Latent Dimensionality",
            "ae_elbow_test_loss_overall.png",
        ),
    )
    for kind, plot_title, file_name in plot_specs:
        per_arch = {name: payload[kind] for name, payload in curves.items()}
        plot_multi_loss_curve(per_arch, plot_title, ts_dir / file_name)

    # Copy this script to preserve the code used for the outputs
    shutil.copy2(Path(__file__), ts_dir)

    # Optionally mirror latest
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for entry in ts_dir.iterdir():
        if entry.is_file():
            shutil.copy2(entry, latest / entry.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")
|
|
|
|
|
|
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|