from __future__ import annotations import shutil from datetime import datetime from pathlib import Path import matplotlib.pyplot as plt import numpy as np import polars as pl from matplotlib.lines import Line2D # CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY from load_results import load_results_dataframe # ---------------------------- # Config # ---------------------------- ROOT = Path("/home/fedex/mt/results/done") # experiments root you pass to the loader OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_comparisons") SEMI_LABELING_REGIMES = [(0, 0), (50, 10), (500, 100)] # Semi-supervised setting to select SEMI_NORMALS = 50 SEMI_ANOMALOUS = 10 # Which evaluation columns to plot EVALS = ["exp_based", "manual_based"] # Latent dimensions to show as 7 subplots LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024] # Interpolation grids ROC_GRID = np.linspace(0.0, 1.0, 200) PRC_GRID = np.linspace(0.0, 1.0, 200) # ---------------------------- # Helpers # ---------------------------- def canonicalize_network(name: str) -> str: """Map net_name strings to clean labels for plotting.""" low = (name or "").lower() if "lenet" in low: return "LeNet" if "efficient" in low: return "Efficient" return name or "unknown" def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray): """ Interpolate a list of (x, y) curves onto a common grid. Returns mean_y, std_y on the grid. Skips empty or invalid curves. """ if not curves: return np.full_like(grid, np.nan, dtype=float), np.full_like( grid, np.nan, dtype=float ) interps = [] for x, y in curves: if x is None or y is None: continue x = np.asarray(x, dtype=float) y = np.asarray(y, dtype=float) if x.size == 0 or y.size == 0 or x.size != y.size: continue # ensure sorted by x and unique order = np.argsort(x) x = x[order] y = y[order] # deduplicate x (np.interp requires ascending x) uniq_x, uniq_idx = np.unique(x, return_index=True) y = y[uniq_idx] x = uniq_x # bound grid to valid interp range gmin = max(grid[0], x[0]) gmax = min(grid[-1], x[-1]) g = np.clip(grid, gmin, gmax) yi = np.interp(g, x, y) interps.append(yi) if not interps: return np.full_like(grid, np.nan, dtype=float), np.full_like( grid, np.nan, dtype=float ) A = np.vstack(interps) return np.nanmean(A, axis=0), np.nanstd(A, axis=0) def _net_label_col(df: pl.DataFrame) -> pl.DataFrame: """Add 'net_label' column (LeNet/Efficient/fallback).""" return df.with_columns( pl.when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet") ) .then(pl.lit("LeNet")) .when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient") ) .then(pl.lit("Efficient")) .otherwise(pl.col("network").cast(pl.Utf8)) .alias("net_label") ) def _select_rows( df: pl.DataFrame, *, model: str, eval_type: str, latent_dim: int, net_label: str | None, semi_normals: int, semi_anomalous: int, ) -> pl.DataFrame: """Polars filter: by model/eval/latent and optionally net_label.""" exprs = [ pl.col("model") == model, pl.col("eval") == eval_type, pl.col("latent_dim") == latent_dim, pl.col("semi_normals") == semi_normals, pl.col("semi_anomalous") == semi_anomalous, ] if net_label is not None: exprs.append(pl.col("net_label") == net_label) return df.filter(pl.all_horizontal(exprs)) def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]: """ From a list of rows (Python dicts), return list of (x, y) curves for given kind. kind: "roc" or "prc" """ curves = [] for r in rows: if kind == "roc": c = r.get("roc_curve") if not c: continue x, y = c.get("fpr"), c.get("tpr") else: c = r.get("prc_curve") if not c: continue x, y = c.get("recall"), c.get("precision") if x is None or y is None: continue curves.append((np.asarray(x, dtype=float), np.asarray(y, dtype=float))) return curves def _ensure_dim_axes(fig_title: str): """Return figure, axes array laid out 2x4; last axis is for legend.""" fig, axes = plt.subplots( nrows=4, ncols=2, figsize=(12, 16), constrained_layout=True ) fig.suptitle(fig_title, fontsize=14) axes = axes.ravel() return fig, axes def _add_legend_to_axis(ax, handles_labels): ax.axis("off") handles, labels = handles_labels ax.legend( handles, labels, loc="center", frameon=False, ncol=1, fontsize=11, borderaxespad=0.5, ) def plot_grid_from_df( df: pl.DataFrame, eval_type: str, kind: str, semi_normals: int, semi_anomalous: int, out_path: Path, ): """ Create a 2x4 grid of subplots, one per latent dim; 8th panel holds legend. kind: 'roc' or 'prc' """ fig_title = f"{kind.upper()} — {eval_type} (semi = {semi_normals}/{semi_anomalous})" fig, axes = _ensure_dim_axes(fig_title) # plotting order & colors series = [ ( "isoforest", None, "IsolationForest", "tab:purple", ), # baselines from Efficient only (handled below) ("ocsvm", None, "OC-SVM", "tab:green"), ("deepsad", "LeNet", "DeepSAD (LeNet)", "tab:blue"), ("deepsad", "Efficient", "DeepSAD (Efficient)", "tab:orange"), ] # Handles for legend (build from first subplot that has data) legend_handles = [] legend_labels = [] have_legend = False for i, dim in enumerate(LATENT_DIMS): if i >= 7: break # last slot reserved for legend ax = axes[i] ax.set_title(f"latent_dim = {dim}") ax.grid(True, alpha=0.3) if kind == "roc": ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.set_xlabel("FPR") ax.set_ylabel("TPR") grid = ROC_GRID else: ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.set_xlabel("Recall") ax.set_ylabel("Precision") grid = PRC_GRID plotted_any = False for model, net_needed, label, color in series: # baselines: use Efficient only net_filter = net_needed if model in ("isoforest", "ocsvm"): net_filter = "Efficient" sub = _select_rows( df, model=model, eval_type=eval_type, latent_dim=dim, net_label=net_filter, semi_normals=semi_normals, semi_anomalous=semi_anomalous, ) if sub.height == 0: continue rows = sub.select("roc_curve" if kind == "roc" else "prc_curve").to_dicts() curves = _extract_curves(rows, kind) if not curves: continue mean_y, std_y = _interp_mean_std(curves, grid) # Guard for all-NaN if np.all(np.isnan(mean_y)): continue ax.plot(grid, mean_y, label=label, color=color) ax.fill_between( grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color ) plotted_any = True if not have_legend: legend_handles.append(Line2D([0], [0], color=color, lw=2)) legend_labels.append(label) if not plotted_any: ax.text( 0.5, 0.5, "No data", ha="center", va="center", fontsize=12, alpha=0.7 ) ax.set_xlim(0, 1) ax.set_ylim(0, 1) if not have_legend and legend_handles: have_legend = True # Legend in 8th slot _add_legend_to_axis(axes[7], (legend_handles, legend_labels)) # Save out_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) def main(): # Load main results DF (uses your cache if enabled in the loader) df = load_results_dataframe(ROOT, allow_cache=True) # Add clean network labels complete_df = _net_label_col(df) # Prepare output dirs OUTPUT_DIR.mkdir(parents=True, exist_ok=True) archive_dir = OUTPUT_DIR / "archive" archive_dir.mkdir(parents=True, exist_ok=True) ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir.mkdir(parents=True, exist_ok=True) for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES: # Restrict to our semi-supervised setting df = complete_df.filter( (pl.col("semi_normals") == semi_normals) & (pl.col("semi_anomalous") == semi_anomalous) & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])) & (pl.col("eval").is_in(EVALS)) & (pl.col("latent_dim").is_in(LATENT_DIMS)) ) # Plot 4 figures for eval_type in EVALS: # ROC plot_grid_from_df( df, eval_type=eval_type, kind="roc", semi_normals=semi_normals, semi_anomalous=semi_anomalous, out_path=ts_dir / f"roc_semi_{semi_normals}_{semi_anomalous}_{eval_type}.png", ) # PRC plot_grid_from_df( df, eval_type=eval_type, kind="prc", semi_normals=semi_normals, semi_anomalous=semi_anomalous, out_path=ts_dir / f"prc_{semi_normals}_{semi_anomalous}_{eval_type}.png", ) # Copy this script to preserve the code used for the outputs script_path = Path(__file__) shutil.copy2(script_path, ts_dir) # Mirror latest latest = OUTPUT_DIR / "latest" latest.mkdir(exist_ok=True, parents=True) for f in latest.iterdir(): if f.is_file(): f.unlink() for f in ts_dir.iterdir(): if f.is_file(): shutil.copy2(f, latest / f.name) print(f"Saved plots to: {ts_dir}") print(f"Also updated: {latest}") if __name__ == "__main__": main()