# ae_elbow_from_df.py
"""Plot autoencoder test loss against latent dimensionality ("elbow" plots).

Loads the AE pretraining results DataFrame via
``load_pretraining_results_dataframe``, reduces every row to batch-mean test
losses (overall, normal-only, anomaly-only), aggregates those losses across
rows per (network, latent_dim), and writes one PNG per loss kind into a
timestamped archive directory, which is then mirrored into
``OUTPUT_DIR / "latest"``.
"""

from __future__ import annotations

import json
import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_pretraining_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/done")  # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/ae_elbow_lenet_from_df")

# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
LABEL_FIELD = "labels_exp_based"

# Loss kinds produced per row / per curve, in the order they appear in results.
_CURVE_KINDS = ("normal", "anomaly", "overall")


# ----------------------------
# Helpers
# ----------------------------
def canonicalize_network(name: str | None) -> str:
    """Map a raw ``net_name`` string to a clean label for plotting.

    Matching is a case-insensitive substring test, and "lenet" is checked
    before "efficient". Unrecognised names are returned unchanged. An empty
    or missing name becomes ``"unknown"``.
    """
    low = (name or "").lower()
    if "lenet" in low:
        return "LeNet"
    if "efficient" in low:
        return "Efficient"
    # fallback: show whatever was stored
    return name or "unknown"


def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Return the mean of per-batch means of ``scores``.

    ``scores`` is cut into consecutive chunks of ``batch_size`` (the last chunk
    may be shorter). The chunk means are then averaged, so a short final batch
    weighs as much as a full one. Per the original author, this matches how the
    original test loss was computed.

    Args:
        scores: Per-sample losses (a 1-D float array as built in this script).
        batch_size: Chunk size; values <= 0 fall back to a single batch.

    Returns:
        The mean of the batch means, or NaN when ``scores`` is empty.
    """
    n = len(scores)
    if n == 0:
        return np.nan
    if batch_size <= 0:
        batch_size = n  # single-batch fallback

    n_batches = (n + batch_size - 1) // batch_size  # ceil(n / batch_size)
    acc = 0.0
    for start in range(0, n, batch_size):
        acc += float(np.mean(scores[start : start + batch_size]))
    return acc / n_batches


def extract_batch_size(cfg_json: str | None) -> int:
    """Read the batch size from a run's ``config_json`` string.

    The lookup order is ``ae_batch_size``, then ``batch_size``, then 256. A
    falsy value (missing, ``None``, 0) falls through to the next option.
    Missing, malformed, or non-object JSON is treated as an empty config.
    Only ``config_json`` is consulted (no lifted DataFrame fields).
    """
    try:
        cfg = json.loads(cfg_json) if cfg_json else {}
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError. TypeError is raised for a
        # non-str/bytes payload.
        cfg = {}
    if not isinstance(cfg, dict):
        # Valid JSON that is not an object (e.g. a list) has no keys to read.
        cfg = {}
    return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)


def _split_losses(
    scores: np.ndarray, labels: np.ndarray | None, batch_size: int
) -> dict[str, float]:
    """Return batch-mean losses keyed by kind; NaN where a kind is unavailable."""
    losses = {
        "normal": np.nan,
        "anomaly": np.nan,
        "overall": calculate_batch_mean_loss(scores, batch_size),
    }
    # Per-class losses need exactly one label per score; otherwise only the
    # overall loss is known.
    if labels is not None and labels.size == scores.size:
        # Label convention used here: 1 = normal, -1 = anomaly. Any other label
        # value contributes to "overall" only.
        for kind, label_value in (("normal", 1), ("anomaly", -1)):
            subset = scores[labels == label_value]
            if subset.size > 0:
                losses[kind] = calculate_batch_mean_loss(subset, batch_size)
    return losses


def _mean_std(values: list[float]) -> tuple[float, float]:
    """Return (mean, population std) of the non-NaN ``values``; (NaN, NaN) if none."""
    valid = [v for v in values if v is not None and not np.isnan(v)]
    if not valid:
        return np.nan, np.nan
    return float(np.mean(valid)), float(np.std(valid))


def build_arch_curves_from_df(
    df: pl.DataFrame,
    label_field: str = "labels_exp_based",
    only_nets: set[str] | None = None,
) -> dict[str, dict[str, tuple[list[int], list[float], list[float]]]]:
    """Build per-network loss curves over latent dimensionality.

    Each row contributes one overall, normal and anomaly loss. These are
    grouped by (canonical network label, latent_dim) and reduced to mean/std
    across rows (e.g. folds), ignoring NaN entries.

    Args:
        df: AE pretraining DataFrame. It needs the columns ``scores``,
            ``network``, ``latent_dim`` and ``label_field``; ``config_json`` is
            optional.
        label_field: Column holding per-sample labels aligned with ``scores``.
        only_nets: If non-empty, keep only these canonical network labels.

    Returns:
        ``{net_label: {"normal" | "anomaly" | "overall": (dims, means, stds)}}``
        with ``dims`` sorted ascending. A dim with no valid losses for a kind
        has NaN mean/std.

    Raises:
        ValueError: If a required column is missing.
    """
    if "scores" not in df.columns:
        raise ValueError("Expected 'scores' column in AE dataframe.")
    if "network" not in df.columns or "latent_dim" not in df.columns:
        raise ValueError("Expected 'network' and 'latent_dim' columns in AE dataframe.")
    if label_field not in df.columns:
        raise ValueError(f"Expected '{label_field}' column in AE dataframe.")

    # NOTE: no split filter is applied. An earlier `split == "test"` filter was
    # disabled, so every row is treated as test data.

    groups: dict[tuple[str, int], dict[str, list[float]]] = {}
    for row in df.iter_rows(named=True):
        net_label = canonicalize_network(row["network"])
        if only_nets and net_label not in only_nets:
            continue

        dim = int(row["latent_dim"])
        batch_size = extract_batch_size(row.get("config_json"))

        raw_scores = row["scores"]
        scores = np.asarray([] if raw_scores is None else raw_scores, dtype=float)
        raw_labels = row.get(label_field)
        labels = None if raw_labels is None else np.asarray(raw_labels, dtype=int)

        losses = _split_losses(scores, labels, batch_size)
        bucket = groups.setdefault((net_label, dim), {kind: [] for kind in _CURVE_KINDS})
        for kind in _CURVE_KINDS:
            bucket[kind].append(losses[kind])

    # Dims observed per network (networks in first-seen order).
    per_net_dims: dict[str, set[int]] = {}
    for net, dim in groups:
        per_net_dims.setdefault(net, set()).add(dim)

    result: dict[str, dict[str, tuple[list[int], list[float], list[float]]]] = {}
    for net, dims in per_net_dims.items():
        dims_sorted = sorted(dims)
        curves: dict[str, tuple[list[int], list[float], list[float]]] = {}
        for kind in _CURVE_KINDS:
            # Every (net, dim) here is a key of `groups`, so look it up directly
            # instead of rescanning all groups.
            stats = [_mean_std(groups[(net, dim)][kind]) for dim in dims_sorted]
            curves[kind] = (
                dims_sorted,
                [mean for mean, _ in stats],
                [std for _, std in stats],
            )
        result[net] = curves
    return result


def plot_multi_loss_curve(
    arch_results, title, output_path, colors=None, *, show_title=False
):
    """Plot mean loss vs. latent dim per architecture, with a ±1 std band.

    Args:
        arch_results: ``{arch_name: (dims, means, stds)}``.
        title: Figure title. It is drawn only when ``show_title`` is True; the
            default keeps untitled figures.
        output_path: Destination passed to ``savefig`` (saved at 150 dpi).
        colors: ``{arch_name: color}``. Names not in the mapping use the next
            matplotlib cycle colour. Defaults to LeNet=tab:blue,
            Efficient=tab:orange.
        show_title: Draw ``title`` on the axes.
    """
    # default color map if not provided
    if colors is None:
        colors = {
            "LeNet": "tab:blue",
            "Efficient": "tab:orange",
        }

    # Unique dimensions across all architectures, used as x ticks.
    all_dims = sorted({dim for dims, _, _ in arch_results.values() for dim in dims})

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for arch_name, (dims, means, stds) in arch_results.items():
            means_arr = np.asarray(means, dtype=float)
            stds_arr = np.asarray(stds, dtype=float)
            # color=None lets matplotlib pick from its property cycle. The band
            # reuses the line's actual colour so the two always match.
            (line,) = ax.plot(
                dims, means_arr, marker="o", color=colors.get(arch_name), label=arch_name
            )
            ax.fill_between(
                dims,
                means_arr - stds_arr,
                means_arr + stds_arr,
                color=line.get_color(),
                alpha=0.2,
            )

        ax.set_xlabel("Latent Dimensionality")
        ax.set_ylabel("Test Loss")
        if show_title:
            ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(all_dims)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main():
    """Load AE results, write the three elbow plots, and archive this script."""
    # Load AE DF (uses your cache if enabled in the loader)
    df = load_pretraining_results_dataframe(ROOT, allow_cache=True)

    # Optional: filter to just LeNet vs Efficient; pass None (or an empty set)
    # to plot all nets.
    wanted_nets = {"LeNet", "Efficient"}

    curves = build_arch_curves_from_df(
        df,
        label_field=LABEL_FIELD,
        only_nets=wanted_nets,
    )

    # Prepare output dirs: one timestamped archive dir per run.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    plot_specs = (
        ("normal", "Normal Class Test Loss vs. Latent Dimensionality"),
        ("anomaly", "Anomaly Class Test Loss vs. Latent Dimensionality"),
        ("overall", "Overall Test Loss vs. Latent Dimensionality"),
    )
    for kind, title in plot_specs:
        plot_multi_loss_curve(
            {name: payload[kind] for name, payload in curves.items()},
            title,
            ts_dir / f"ae_elbow_test_loss_{kind}.png",
        )

    # Copy this script to preserve the code used for the outputs
    shutil.copy2(Path(__file__), ts_dir)

    # Mirror the run's files into "latest" (overwrites same-named files only).
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()