#!/usr/bin/env python3 from __future__ import annotations import shutil from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Dict, List, Tuple import matplotlib.pyplot as plt import numpy as np import polars as pl from matplotlib.ticker import MaxNLocator # ========================= # Config # ========================= ROOT = Path("/home/fedex/mt/results/copy") OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_latent") # Labeling regimes (shown as separate subplots) SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)] # Evaluations: separate figure per eval EVALS: list[str] = ["exp_based", "manual_based"] # X-axis (latent dims) LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] # Visual style FIGSIZE = (8, 8) # one tall figure with 3 compact subplots MARKERSIZE = 7 SCATTER_ALPHA = 0.95 LINEWIDTH = 2.0 TREND_LINEWIDTH = 2.2 BAND_ALPHA = 0.18 # Toggle: show ±1 std bands (k-fold variability) SHOW_STD_BANDS = True # <<< set to False to hide the bands # Colors for the two DeepSAD backbones COLOR_LENET = "#1f77b4" # blue COLOR_EFFICIENT = "#ff7f0e" # orange # ========================= # Loader # ========================= from load_results import load_results_dataframe # ========================= # Helpers # ========================= def _with_net_label(df: pl.DataFrame) -> pl.DataFrame: return df.with_columns( pl.when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet") ) .then(pl.lit("LeNet")) .when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient") ) .then(pl.lit("Efficient")) .otherwise(pl.col("network").cast(pl.Utf8)) .alias("net_label") ) def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame: return df.filter( (pl.col("model") == "deepsad") & (pl.col("eval").is_in(EVALS)) & (pl.col("latent_dim").is_in(LATENT_DIMS)) & (pl.col("net_label").is_in(["LeNet", "Efficient"])) ).select( "eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous", "fold", "ap", ) @dataclass(frozen=True) class Agg: mean: float std: float def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]: out: Dict[Tuple[str, str, int, int, int], Agg] = {} gb = ( df.group_by( ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"] ) .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std")) .to_dicts() ) for row in gb: key = ( str(row["eval"]), str(row["net_label"]), int(row["latent_dim"]), int(row["semi_normals"]), int(row["semi_anomalous"]), ) m = float(row["mean"]) if row["mean"] == row["mean"] else np.nan s = float(row["std"]) if row["std"] == row["std"] else np.nan out[key] = Agg(mean=m, std=s) return out def _lin_trend(xs: List[int], ys: List[float]) -> Tuple[np.ndarray, np.ndarray]: if len(xs) < 2: return np.array(xs, dtype=float), np.array(ys, dtype=float) x = np.array(xs, dtype=float) y = np.array(ys, dtype=float) a, b = np.polyfit(x, y, 1) x_fit = np.linspace(x.min(), x.max(), 200) y_fit = a * x_fit + b return x_fit, y_fit def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float, float]: vals = np.array(all_vals, dtype=float) errs = np.array(all_errs, dtype=float) if SHOW_STD_BANDS else np.zeros_like(vals) valid = np.isfinite(vals) if not np.any(valid): return (0.0, 1.0) v, e = vals[valid], errs[valid] lo = np.min(v - e) hi = np.max(v + e) span = max(1e-3, hi - lo) pad = 0.08 * span y0 = max(0.0, lo - pad) y1 = min(1.0, hi + pad) if (y1 - y0) < 0.08: mid = 0.5 * (y0 + y1) y0 = max(0.0, mid - 0.04) y1 = min(1.0, mid + 0.04) return (float(y0), float(y1)) def _get_dim_mapping(dims: list[int]) -> dict[int, int]: """Map actual dimensions to evenly spaced positions (0, 1, 2, ...)""" return {dim: i for i, dim in enumerate(dims)} def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path): fig, axes = plt.subplots( len(SEMI_LABELING_REGIMES), 1, figsize=FIGSIZE, constrained_layout=True, sharex=True, ) if len(SEMI_LABELING_REGIMES) == 1: axes = [axes] # Create dimension mapping dim_mapping = _get_dim_mapping(LATENT_DIMS) for ax, regime in zip(axes, SEMI_LABELING_REGIMES): semi_n, semi_a = regime data = {} for net in ["LeNet", "Efficient"]: xs, ys, es = [], [], [] for dim in LATENT_DIMS: key = (ev, net, dim, semi_n, semi_a) if key in agg: xs.append( dim_mapping[dim] ) # Use mapped position instead of actual dim ys.append(agg[key].mean) es.append(agg[key].std) data[net] = (xs, ys, es) for net, color in [("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)]: xs, ys, es = data[net] if not xs: continue # Set evenly spaced ticks with actual dimension labels ax.set_xticks(list(dim_mapping.values())) ax.set_xticklabels(LATENT_DIMS) ax.yaxis.set_major_locator(MaxNLocator(nbins=5)) ax.scatter( xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)" ) x_fit, y_fit = _lin_trend(xs, ys) # Now using mapped positions ax.plot( x_fit, y_fit, color=color, linewidth=TREND_LINEWIDTH, label=f"{net} (trend)", ) if SHOW_STD_BANDS and es and np.any(np.isfinite(es)): ylo = np.clip(np.array(ys) - np.array(es), 0.0, 1.0) yhi = np.clip(np.array(ys) + np.array(es), 0.0, 1.0) ax.fill_between( xs, ylo, yhi, color=color, alpha=BAND_ALPHA, linewidth=0 ) all_vals, all_errs = [], [] for net in ["LeNet", "Efficient"]: _, ys, es = data[net] all_vals.extend(ys) all_errs.extend(es) y0, y1 = _dynamic_ylim(all_vals, all_errs) ax.set_ylim(y0, y1) ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11) ax.grid(True, alpha=0.35) axes[-1].set_xlabel("Latent dimension") for ax in axes: ax.set_ylabel("AP") handles, labels = axes[0].get_legend_handles_labels() fig.legend(handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97)) fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05) fname = f"ap_trends_{ev}.png" fig.savefig(outdir / fname, dpi=150) plt.close(fig) def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path): outdir.mkdir(parents=True, exist_ok=True) for ev in EVALS: plot_eval(ev, agg, outdir) def main(): df = load_results_dataframe(ROOT, allow_cache=True) df = _with_net_label(df) df = _filter_deepsad(df) agg = aggregate_ap(df) OUTPUT_DIR.mkdir(parents=True, exist_ok=True) archive_dir = OUTPUT_DIR / "archive" archive_dir.mkdir(parents=True, exist_ok=True) ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir.mkdir(parents=True, exist_ok=True) plot_all(agg, ts_dir) try: script_path = Path(__file__) shutil.copy2(script_path, ts_dir / script_path.name) except Exception: pass latest = OUTPUT_DIR / "latest" latest.mkdir(parents=True, exist_ok=True) for f in latest.iterdir(): if f.is_file(): f.unlink() for f in ts_dir.iterdir(): if f.is_file(): shutil.copy2(f, latest / f.name) print(f"Saved plots to: {ts_dir}") print(f"Also updated: {latest}") if __name__ == "__main__": main()