2025-09-03 14:55:54 +02:00
|
|
|
|
# curves_2x1_by_net_with_regimes_from_df.py
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
|
|
|
|
|
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
|
2025-09-27 16:34:52 +02:00
|
|
|
|
from load_results import load_results_dataframe
|
|
|
|
|
|
from matplotlib.lines import Line2D
|
|
|
|
|
|
from scipy.stats import sem, t
|
2025-09-03 14:55:54 +02:00
|
|
|
|
|
|
|
|
|
|
# ---------------------------------
# Config
# ---------------------------------
# Directory handed to the results loader.
ROOT = Path("/home/fedex/mt/results/copy")
# Base directory for generated plots (main() creates archive/ and latest/ below it).
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")

# One ROC figure and one PRC figure is produced per (evaluation type, latent dim).
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
# Each regime is a (semi_normals, semi_anomalous) pair used to filter DeepSAD rows.
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
EVALS = ["exp_based", "manual_based"]
# Human-readable names of the evaluation types, used in figure titles.
EVALS_LABELS = {"exp_based": "Experiment-Based Labels", "manual_based": "Manually-Labeled"}

# Interp grids: common x-axes onto which every curve is resampled before averaging.
ROC_GRID = np.linspace(0.0, 1.0, num=200)
PRC_GRID = np.linspace(0.0, 1.0, num=200)

# Baselines are duplicated across nets; use Efficient-only to avoid repetition
BASELINE_NET = "Efficient"
# Legend names of the baseline models.
BASELINE_LABELS = {"isoforest": "Isolation Forest", "ocsvm": "One-Class SVM"}

# Colors/styles
COLOR_BASELINES = {"isoforest": "tab:purple", "ocsvm": "tab:green"}
# Regime-keyed styles are paired positionally with SEMI_REGIMES.
COLOR_REGIMES = dict(zip(SEMI_REGIMES, ("tab:blue", "tab:orange", "tab:red")))
LINESTYLES = dict(zip(SEMI_REGIMES, ("-", "--", "-.")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------
|
|
|
|
|
|
# Helpers
|
|
|
|
|
|
# ---------------------------------
|
|
|
|
|
|
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added ``net_label`` column derived from ``network``.

    The ``network`` value is cast to a string and matched case-insensitively:
    values containing "lenet" become "LeNet", values containing "efficient"
    become "Efficient", and anything else keeps its string form.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    network_lower = network_text.str.to_lowercase()
    # The "lenet" branch is checked first, so it wins if both substrings occur.
    label = (
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
    )
    return df.with_columns(label.alias("net_label"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mean_ci(values: list[float], confidence: float = 0.95) -> tuple[float, float]:
    """Return the mean and the half-width of its Student-t confidence interval.

    ``None`` and NaN entries are discarded before anything is computed, so the
    mean, the standard error and the degrees of freedom are all based on the
    same sample.

    Args:
        values: Observations; may contain ``None`` or NaN placeholders.
        confidence: Two-sided confidence level of the interval.

    Returns:
        ``(mean, half_width)``. ``(nan, nan)`` when no usable value remains and
        ``(value, 0.0)`` when exactly one remains.
    """
    arr = np.asarray([v for v in values if v is not None], dtype=float)
    # Drop NaN here instead of leaving it to sem(): otherwise arr.mean() turns
    # NaN and the t quantile is evaluated with too many degrees of freedom.
    arr = arr[~np.isnan(arr)]
    n = arr.size
    if n == 0:
        return np.nan, np.nan
    if n == 1:
        return float(arr[0]), 0.0
    m = float(arr.mean())
    s = sem(arr)
    h = s * t.ppf((1 + confidence) / 2.0, n - 1)
    return m, float(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray):
|
|
|
|
|
|
"""Interpolate many (x,y) onto grid and return mean±std; robust to duplicates/empty."""
|
|
|
|
|
|
if not curves:
|
|
|
|
|
|
return np.full_like(grid, np.nan, float), np.full_like(grid, np.nan, float)
|
|
|
|
|
|
interps = []
|
|
|
|
|
|
for x, y in curves:
|
|
|
|
|
|
if x is None or y is None:
|
|
|
|
|
|
continue
|
|
|
|
|
|
x = np.asarray(x, float)
|
|
|
|
|
|
y = np.asarray(y, float)
|
|
|
|
|
|
if x.size == 0 or y.size == 0 or x.size != y.size:
|
|
|
|
|
|
continue
|
|
|
|
|
|
order = np.argsort(x)
|
|
|
|
|
|
x, y = x[order], y[order]
|
|
|
|
|
|
x, uniq_idx = np.unique(x, return_index=True)
|
|
|
|
|
|
y = y[uniq_idx]
|
|
|
|
|
|
g = np.clip(grid, x[0], x[-1])
|
|
|
|
|
|
yi = np.interp(g, x, y)
|
|
|
|
|
|
interps.append(yi)
|
|
|
|
|
|
if not interps:
|
|
|
|
|
|
return np.full_like(grid, np.nan, float), np.full_like(grid, np.nan, float)
|
|
|
|
|
|
A = np.vstack(interps)
|
|
|
|
|
|
return np.nanmean(A, axis=0), np.nanstd(A, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
|
|
|
|
curves = []
|
|
|
|
|
|
for r in rows:
|
|
|
|
|
|
if kind == "roc":
|
|
|
|
|
|
c = r.get("roc_curve")
|
|
|
|
|
|
if not c:
|
|
|
|
|
|
continue
|
|
|
|
|
|
x, y = c.get("fpr"), c.get("tpr")
|
|
|
|
|
|
else:
|
|
|
|
|
|
c = r.get("prc_curve")
|
|
|
|
|
|
if not c:
|
|
|
|
|
|
continue
|
|
|
|
|
|
x, y = c.get("recall"), c.get("precision")
|
|
|
|
|
|
if x is None or y is None:
|
|
|
|
|
|
continue
|
|
|
|
|
|
curves.append((np.asarray(x, float), np.asarray(y, float)))
|
|
|
|
|
|
return curves
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _select_rows(
    df: pl.DataFrame,
    *,
    model: str,
    eval_type: str,
    latent_dim: int,
    semi_normals: int | None = None,
    semi_anomalous: int | None = None,
    net_label: str | None = None,
) -> pl.DataFrame:
    """Return the rows of ``df`` that match every given criterion.

    ``model``, ``eval_type`` (column ``eval``) and ``latent_dim`` are always
    applied. ``semi_normals``, ``semi_anomalous`` and ``net_label`` are only
    applied when they are not ``None``.
    """
    always = {"model": model, "eval": eval_type, "latent_dim": latent_dim}
    when_given = {
        "semi_normals": semi_normals,
        "semi_anomalous": semi_anomalous,
        "net_label": net_label,
    }
    conditions = [pl.col(column) == wanted for column, wanted in always.items()]
    conditions += [
        pl.col(column) == wanted
        for column, wanted in when_given.items()
        if wanted is not None
    ]
    return df.filter(pl.all_horizontal(conditions))
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-09-27 16:34:52 +02:00
|
|
|
|
def _auc_list(sub: pl.DataFrame, kind: str) -> list[float]:
    """Return the non-null values of the ``{kind}_auc`` column of ``sub``."""
    column = f"{kind}_auc"
    values = sub.select(column).to_series().to_list()
    return list(filter(lambda value: value is not None, values))
|
2025-09-03 14:55:54 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _plot_mean_curve(ax, sub: pl.DataFrame, *, kind: str, grid: np.ndarray, name: str, **line_style) -> None:
    """Draw the mean curve of ``sub`` with a ±1 std band; skip if nothing is drawable.

    ``line_style`` is forwarded to ``ax.plot`` and must contain ``color``, which
    is reused for the shaded band.
    """
    if sub.height == 0:
        return
    rows = sub.select("roc_curve" if kind == "roc" else "prc_curve").to_dicts()
    mean_y, std_y = _interp_mean_std(_extract_curves(rows, kind), grid)
    if np.all(np.isnan(mean_y)):
        return
    # Metric for legend: mean and CI half-width of the "{kind}_auc" column.
    m, ci = mean_ci(_auc_list(sub, kind))
    ax.plot(grid, mean_y, lw=2, label=f"{name}\n(AUC={m:.3f}±{ci:.3f})", **line_style)
    ax.fill_between(
        grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=line_style["color"]
    )


def _plot_panel(
    ax,
    df: pl.DataFrame,
    *,
    eval_type: str,
    net_for_deepsad: str,
    latent_dim: int,
    kind: str,
):
    """Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + Baselines (from Efficient).

    Every curve is the point-wise mean over the matching rows, with a shaded
    ±1 standard deviation band. Legend entries show the mean ± CI half-width of
    the ``{kind}_auc`` column; the label reads "AUC" for both curve kinds.

    Args:
        ax: Matplotlib axes to draw on.
        df: Results frame; must already contain the ``net_label`` column.
        eval_type: Value of the ``eval`` column to select.
        net_for_deepsad: ``net_label`` of the DeepSAD rows shown in this panel.
        latent_dim: Value of the ``latent_dim`` column to select.
        kind: "roc" for an ROC panel; any other value produces a PRC panel.
    """
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    if kind == "roc":
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        grid = ROC_GRID
    else:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        grid = PRC_GRID

    # ----- Baselines (Efficient)
    for model in ("isoforest", "ocsvm"):
        sub_b = _select_rows(
            df,
            model=model,
            eval_type=eval_type,
            latent_dim=latent_dim,
            net_label=BASELINE_NET,
        )
        _plot_mean_curve(
            ax,
            sub_b,
            kind=kind,
            grid=grid,
            name=BASELINE_LABELS[model],
            color=COLOR_BASELINES[model],
        )

    # ----- DeepSAD (this panel's net) across semi-regimes
    for regime in SEMI_REGIMES:
        sn, sa = regime
        sub_d = _select_rows(
            df,
            model="deepsad",
            eval_type=eval_type,
            latent_dim=latent_dim,
            semi_normals=sn,
            semi_anomalous=sa,
            net_label=net_for_deepsad,
        )
        _plot_mean_curve(
            ax,
            sub_d,
            kind=kind,
            grid=grid,
            name=f"DeepSAD {net_for_deepsad} — {sn}/{sa}",
            color=COLOR_REGIMES[regime],
            linestyle=LINESTYLES[regime],
        )

    # Chance line for ROC
    if kind == "roc":
        ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")

    # Legend is built from the labels attached to the plotted lines.
    ax.legend(loc="upper right", fontsize=9, frameon=True)
|
2025-09-03 14:55:54 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_figures_for_dim(
    df: pl.DataFrame, eval_type: str, latent_dim: int, out_dir: Path
):
    """Render and save the ROC and PRC figures for one evaluation type and latent dim.

    Each figure is a 2×1 grid: the top panel shows DeepSAD with the LeNet
    network, the bottom panel DeepSAD with the Efficient network, both together
    with the baselines. Files are written to ``out_dir`` as
    ``roc_{latent_dim}_{eval_type}.png`` and ``prc_{latent_dim}_{eval_type}.png``.

    Args:
        df: Results frame passed through to ``_plot_panel``.
        eval_type: Key of ``EVALS_LABELS``; also used to filter the rows.
        latent_dim: Latent dimension to plot.
        out_dir: Existing directory that receives the PNG files.
    """
    for kind in ("roc", "prc"):
        fig, axes = plt.subplots(
            nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
        )
        try:
            fig.suptitle(
                f"{kind.upper()} — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}",
                fontsize=14,
            )
            for ax, net in zip(axes, ("LeNet", "Efficient")):
                _plot_panel(
                    ax,
                    df,
                    eval_type=eval_type,
                    net_for_deepsad=net,
                    latent_dim=latent_dim,
                    kind=kind,
                )
                ax.set_title(f"DeepSAD ({net}) + Baselines")
            fig.savefig(
                out_dir / f"{kind}_{latent_dim}_{eval_type}.png",
                dpi=150,
                bbox_inches="tight",
            )
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Load the results, render all figures and publish them.

    Figures are written to ``OUTPUT_DIR/archive/<timestamp>/``; a copy of this
    script is stored next to them on a best-effort basis, and the files in
    ``OUTPUT_DIR/latest/`` are replaced with the new ones.
    """
    # Load dataframe and prep
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _net_label_col(df)

    # Filter to relevant models/evals only once
    df = df.filter(
        (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
        & (pl.col("eval").is_in(EVALS))
    )

    # Output/archiving like your AE script
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive = OUTPUT_DIR / "archive"
    archive.mkdir(parents=True, exist_ok=True)
    ts_dir = archive / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    # Generate figures
    for eval_type in EVALS:
        for dim in LATENT_DIMS:
            make_figures_for_dim(
                df, eval_type=eval_type, latent_dim=dim, out_dir=ts_dir
            )

    # Copy this script for provenance. Best effort: __file__ may be undefined
    # (NameError) in some environments, and the copy itself may fail (OSError);
    # neither should discard the figures that were just produced.
    try:
        shutil.copy2(Path(__file__), ts_dir)
    except (NameError, OSError) as exc:
        print(f"Warning: could not copy script for provenance: {exc}")

    # Update "latest": remove the previous files, then copy the new ones in.
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the plotting pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|