# curves_2x1_by_net_with_regimes_from_df.py
"""Plot mean ROC / PRC curves per latent dimension and evaluation type.

For every (eval type, latent dim) pair two figures are written (one ROC, one
PRC). Each figure has two side-by-side panels: DeepSAD with the LeNet network
and DeepSAD with the Efficient network, each drawn for all semi-supervised
label regimes in ``SEMI_REGIMES`` together with the isoforest / ocsvm
baselines. Figures are saved into a timestamped archive directory which is
then mirrored into ``OUTPUT_DIR / "latest"``.

The results dataframe is expected to provide the columns this script reads:
``network``, ``model``, ``eval``, ``latent_dim``, ``semi_normals``,
``semi_anomalous``, ``auc``, ``ap``, ``roc_curve`` (with ``fpr`` / ``tpr``)
and ``prc_curve`` (with ``recall`` / ``precision``).
"""

from __future__ import annotations

import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.lines import Line2D
from scipy.stats import sem, t

# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
from load_results import load_results_dataframe

# ---------------------------------
# Config
# ---------------------------------
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")

LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
# Each regime is a (semi_normals, semi_anomalous) pair.
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
EVALS = ["exp_based", "manual_based"]

# Interp grids
ROC_GRID = np.linspace(0.0, 1.0, 200)
PRC_GRID = np.linspace(0.0, 1.0, 200)

# Baselines are duplicated across nets; use Efficient-only to avoid repetition
BASELINE_NET = "Efficient"

# Colors/styles
COLOR_BASELINES = {
    "isoforest": "tab:purple",
    "ocsvm": "tab:green",
}
COLOR_REGIMES = {
    (0, 0): "tab:blue",
    (50, 10): "tab:orange",
    (500, 100): "tab:red",
}
LINESTYLES = {
    (0, 0): "-",
    (50, 10): "--",
    (500, 100): "-.",
}


# ---------------------------------
# Helpers
# ---------------------------------
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
    """Return *df* with an added ``net_label`` column.

    ``network`` values containing "lenet" (case-insensitive) map to "LeNet",
    values containing "efficient" map to "Efficient"; anything else keeps its
    original ``network`` string.
    """
    network_lower = pl.col("network").cast(pl.Utf8).str.to_lowercase()
    return df.with_columns(
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(pl.col("network").cast(pl.Utf8))
        .alias("net_label")
    )


def mean_ci(values: list[float], confidence: float = 0.95) -> tuple[float, float]:
    """Return the mean and half-width of the Student-t confidence interval.

    ``None`` and NaN entries are ignored. With no usable values the result is
    ``(nan, nan)``; with exactly one it is ``(value, 0.0)``.
    """
    arr = np.asarray([v for v in values if v is not None], dtype=float)
    # Drop NaNs up front so the mean, the standard error and the degrees of
    # freedom are all computed from the same set of samples.
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return np.nan, np.nan
    if arr.size == 1:
        return float(arr[0]), 0.0
    m = float(arr.mean())
    s = sem(arr)
    h = s * t.ppf((1 + confidence) / 2.0, arr.size - 1)
    return m, float(h)


def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray):
    """Interpolate many (x, y) curves onto *grid* and return (mean, std).

    Curves that are ``None``, empty, or have mismatched lengths are skipped.
    Each curve is sorted by x and reduced to one sample per distinct x before
    interpolation. If no usable curve remains, both returned arrays are NaN.
    """
    interps = []
    for x, y in curves:
        if x is None or y is None:
            continue
        x = np.asarray(x, float)
        y = np.asarray(y, float)
        if x.size == 0 or y.size == 0 or x.size != y.size:
            continue
        # Stable sort: for duplicate x values the sample that appeared first
        # in the input is the one np.unique keeps, making the result
        # deterministic.
        order = np.argsort(x, kind="stable")
        x, y = x[order], y[order]
        x, uniq_idx = np.unique(x, return_index=True)
        y = y[uniq_idx]
        # Clamp the grid to the curve's own x-range so values outside it take
        # the curve's end-point values.
        g = np.clip(grid, x[0], x[-1])
        interps.append(np.interp(g, x, y))
    if not interps:
        return np.full_like(grid, np.nan, float), np.full_like(grid, np.nan, float)
    stacked = np.vstack(interps)
    return np.nanmean(stacked, axis=0), np.nanstd(stacked, axis=0)


def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
    """Pull (x, y) arrays out of row dicts.

    For ``kind == "roc"`` the pair is (fpr, tpr) from ``roc_curve``; for any
    other *kind* it is (recall, precision) from ``prc_curve``. Rows whose
    curve entry is missing/empty or lacks either axis are skipped.
    """
    if kind == "roc":
        curve_key, x_key, y_key = "roc_curve", "fpr", "tpr"
    else:
        curve_key, x_key, y_key = "prc_curve", "recall", "precision"

    curves = []
    for r in rows:
        c = r.get(curve_key)
        if not c:
            continue
        x, y = c.get(x_key), c.get(y_key)
        if x is None or y is None:
            continue
        curves.append((np.asarray(x, float), np.asarray(y, float)))
    return curves


def _select_rows(
    df: pl.DataFrame,
    *,
    model: str,
    eval_type: str,
    latent_dim: int,
    semi_normals: int | None = None,
    semi_anomalous: int | None = None,
    net_label: str | None = None,
) -> pl.DataFrame:
    """Filter *df* to rows matching all given criteria.

    ``model``, ``eval_type`` and ``latent_dim`` are always applied; the
    optional arguments only constrain the result when they are not ``None``.
    """
    exprs = [
        pl.col("model") == model,
        pl.col("eval") == eval_type,
        pl.col("latent_dim") == latent_dim,
    ]
    if semi_normals is not None:
        exprs.append(pl.col("semi_normals") == semi_normals)
    if semi_anomalous is not None:
        exprs.append(pl.col("semi_anomalous") == semi_anomalous)
    if net_label is not None:
        exprs.append(pl.col("net_label") == net_label)
    return df.filter(pl.all_horizontal(exprs))


def _auc_list(sub: pl.DataFrame) -> list[float]:
    """Return the non-null values of the ``auc`` column."""
    return [x for x in sub.select("auc").to_series().to_list() if x is not None]


def _ap_list(sub: pl.DataFrame) -> list[float]:
    """Return the non-null values of the ``ap`` column."""
    return [x for x in sub.select("ap").to_series().to_list() if x is not None]


def _plot_mean_curve(
    ax,
    sub: pl.DataFrame,
    *,
    kind: str,
    grid: np.ndarray,
    label_prefix: str,
    color: str,
    linestyle: str = "-",
) -> None:
    """Draw the mean curve ± std band for the rows in *sub* onto *ax*.

    Nothing is drawn when *sub* is empty or no curve could be interpolated.
    The legend label is *label_prefix* followed by the mean±CI of AUC (ROC)
    or AP (PRC).
    """
    if sub.height == 0:
        return
    curve_col = "roc_curve" if kind == "roc" else "prc_curve"
    curves = _extract_curves(sub.select(curve_col).to_dicts(), kind)
    mean_y, std_y = _interp_mean_std(curves, grid)
    if np.all(np.isnan(mean_y)):
        return

    # Metric for legend
    metric_vals = _auc_list(sub) if kind == "roc" else _ap_list(sub)
    m, ci = mean_ci(metric_vals)
    lab = f"{label_prefix} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"

    ax.plot(grid, mean_y, lw=2, color=color, linestyle=linestyle, label=lab)
    ax.fill_between(grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color)


def _plot_panel(
    ax,
    df: pl.DataFrame,
    *,
    eval_type: str,
    net_for_deepsad: str,
    latent_dim: int,
    kind: str,
):
    """
    Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).
    Legend entries include mean±CI of AUC/AP.
    """
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    if kind == "roc":
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        grid = ROC_GRID
    else:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        grid = PRC_GRID

    # ----- Baselines (Efficient)
    for model in ("isoforest", "ocsvm"):
        sub_b = _select_rows(
            df,
            model=model,
            eval_type=eval_type,
            latent_dim=latent_dim,
            net_label=BASELINE_NET,
        )
        _plot_mean_curve(
            ax,
            sub_b,
            kind=kind,
            grid=grid,
            label_prefix=model,
            color=COLOR_BASELINES[model],
        )

    # ----- DeepSAD (this panel's net) across semi-regimes
    for regime in SEMI_REGIMES:
        sn, sa = regime
        sub_d = _select_rows(
            df,
            model="deepsad",
            eval_type=eval_type,
            latent_dim=latent_dim,
            semi_normals=sn,
            semi_anomalous=sa,
            net_label=net_for_deepsad,
        )
        _plot_mean_curve(
            ax,
            sub_d,
            kind=kind,
            grid=grid,
            label_prefix=f"DeepSAD {net_for_deepsad} — semi {sn}/{sa}",
            color=COLOR_REGIMES[regime],
            linestyle=LINESTYLES[regime],
        )

    # Chance line for ROC
    if kind == "roc":
        ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")

    # Legend
    ax.legend(loc="lower right", fontsize=9, frameon=True)


def make_figures_for_dim(
    df: pl.DataFrame, eval_type: str, latent_dim: int, out_dir: Path
):
    """Write the ROC and PRC figures for one (eval_type, latent_dim) pair.

    Each figure has two side-by-side panels (LeNet on the left, Efficient on
    the right) and is saved to *out_dir* as ``roc_<dim>_<eval>.png`` /
    ``prc_<dim>_<eval>.png``.
    """
    for kind in ("roc", "prc"):
        # One row, two columns: LeNet | Efficient
        fig, axes = plt.subplots(
            nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
        )
        try:
            fig.suptitle(
                f"{kind.upper()} — {eval_type} — latent_dim={latent_dim}",
                fontsize=14,
            )
            for ax, net in zip(axes, ("LeNet", "Efficient")):
                _plot_panel(
                    ax,
                    df,
                    eval_type=eval_type,
                    net_for_deepsad=net,
                    latent_dim=latent_dim,
                    kind=kind,
                )
                ax.set_title(f"DeepSAD ({net}) + baselines")

            out_path = out_dir / f"{kind}_{latent_dim}_{eval_type}.png"
            fig.savefig(out_path, dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)


def main():
    """Load results, render all figures into a timestamped archive, refresh ``latest``."""
    # Load dataframe and prep
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _net_label_col(df)

    # Filter to relevant models/evals only once
    df = df.filter(
        (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
        & (pl.col("eval").is_in(EVALS))
    )

    # Output/archiving like your AE script
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive = OUTPUT_DIR / "archive"
    archive.mkdir(parents=True, exist_ok=True)
    ts_dir = archive / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    # Generate figures
    for eval_type in EVALS:
        for dim in LATENT_DIMS:
            make_figures_for_dim(
                df, eval_type=eval_type, latent_dim=dim, out_dir=ts_dir
            )

    # Copy this script for provenance (best effort). The __file__ lookup is
    # inside the try so environments where it is undefined do not crash.
    try:
        shutil.copy2(Path(__file__), ts_dir)
    except (NameError, OSError) as exc:
        print(f"Warning: could not copy script for provenance: {exc}")

    # Update "latest"
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()