Files
mt/tools/plot_scripts/results_semi_labels_comparison.py
Jan Kowalczyk 86d9d96ca4 wip
2025-09-09 14:15:16 +02:00

364 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# curves_2x1_by_net_with_regimes_from_df.py
from __future__ import annotations
import shutil
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.lines import Line2D
from scipy.stats import sem, t
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
from load_results import load_results_dataframe
# ---------------------------------
# Config
# ---------------------------------
# Input results tree and output plot directory.
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")

# Experiment axes swept by main(): one ROC + one PRC figure per (eval, dim).
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
# Each regime is a (semi_normals, semi_anomalous) pair, matched against the
# dataframe columns of the same names.
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
EVALS = ["exp_based", "manual_based"]

# Interp grids: common x-axes onto which every per-row curve is resampled.
ROC_GRID = np.linspace(0.0, 1.0, 200)
PRC_GRID = np.linspace(0.0, 1.0, 200)

# Baselines are duplicated across nets; use Efficient-only to avoid repetition
BASELINE_NET = "Efficient"

# Colors/styles
COLOR_BASELINES = {"isoforest": "tab:purple", "ocsvm": "tab:green"}
# Per-regime color and line style, paired positionally with SEMI_REGIMES.
COLOR_REGIMES = dict(zip(SEMI_REGIMES, ("tab:blue", "tab:orange", "tab:red")))
LINESTYLES = dict(zip(SEMI_REGIMES, ("-", "--", "-.")))
# ---------------------------------
# Helpers
# ---------------------------------
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added ``net_label`` column holding a canonical net name.

    Values of ``network`` containing "lenet" (case-insensitive) become "LeNet";
    otherwise values containing "efficient" become "Efficient". The LeNet test
    is applied first. Anything else is kept as its string form.
    """
    network = pl.col("network").cast(pl.Utf8)
    lowered = network.str.to_lowercase()
    label = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network)
    )
    return df.with_columns(label.alias("net_label"))
def mean_ci(values: list[float], confidence: float = 0.95) -> tuple[float, float]:
    """Return the mean and the half-width of a Student-t confidence interval.

    ``None`` and NaN entries are ignored, so the mean, the standard error and
    the degrees of freedom are all computed over the same valid sample.

    Args:
        values: Metric values (e.g. the AUC or AP lists built by the plotting
            code); may contain ``None`` or NaN.
        confidence: Two-sided confidence level of the interval.

    Returns:
        ``(mean, half_width)``. ``(nan, nan)`` when no valid value remains, and
        ``(value, 0.0)`` for a single value, where no spread can be estimated.
    """
    arr = np.asarray([v for v in values if v is not None], dtype=float)
    # Drop NaNs up front: otherwise the mean would become NaN while sem()
    # omits them and the t quantile would use the unfiltered count as dof.
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return np.nan, np.nan
    if arr.size == 1:
        return float(arr[0]), 0.0
    m = float(arr.mean())
    s = sem(arr)  # sample standard error (ddof=1)
    h = s * t.ppf((1 + confidence) / 2.0, arr.size - 1)
    return m, float(h)
def _interp_mean_std(
    curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Interpolate many ``(x, y)`` curves onto ``grid`` and return pointwise mean/std.

    Each curve is cleaned before interpolation:

    * curves with a ``None`` side, no points, or mismatched x/y lengths are skipped;
    * non-finite ``(x, y)`` pairs are dropped;
    * repeated x values are collapsed to the largest y at that x.

    Outside a curve's own x-range, its end values are held constant (the
    default behaviour of ``np.interp``).

    Returns:
        ``(mean, std)`` arrays shaped like ``grid``; all-NaN when no usable
        curve is given. ``std`` is the population (ddof=0) standard deviation.
    """
    grid = np.asarray(grid, dtype=float)
    interps = []
    for x, y in curves:
        if x is None or y is None:
            continue
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.size == 0 or x.size != y.size:
            continue
        finite = np.isfinite(x) & np.isfinite(y)
        x, y = x[finite], y[finite]
        if x.size == 0:
            continue
        # Collapse ties to the max y so the result does not depend on input
        # order or sort stability; for staircase ROC curves this keeps the
        # top of each vertical step instead of an arbitrary point on it.
        ux, inverse = np.unique(x, return_inverse=True)
        uy = np.full(ux.shape, -np.inf)
        np.maximum.at(uy, inverse.ravel(), y)
        # ux is strictly increasing, as np.interp requires.
        interps.append(np.interp(grid, ux, uy))
    if not interps:
        return np.full_like(grid, np.nan, dtype=float), np.full_like(grid, np.nan, dtype=float)
    stacked = np.vstack(interps)
    return np.nanmean(stacked, axis=0), np.nanstd(stacked, axis=0)
def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
    """Pull ``(x, y)`` float arrays out of result rows.

    For ``kind == "roc"`` each row's ``roc_curve`` dict supplies ``fpr``/``tpr``;
    for any other kind, ``prc_curve`` supplies ``recall``/``precision``. Rows
    whose curve entry is missing or falsy, or that lack either coordinate
    list, are skipped.
    """
    if kind == "roc":
        column, x_key, y_key = "roc_curve", "fpr", "tpr"
    else:
        column, x_key, y_key = "prc_curve", "recall", "precision"

    extracted = []
    for row in rows:
        curve = row.get(column)
        if not curve:
            continue
        xs = curve.get(x_key)
        ys = curve.get(y_key)
        if xs is None or ys is None:
            continue
        extracted.append((np.asarray(xs, float), np.asarray(ys, float)))
    return extracted
def _select_rows(
    df: pl.DataFrame,
    *,
    model: str,
    eval_type: str,
    latent_dim: int,
    semi_normals: int | None = None,
    semi_anomalous: int | None = None,
    net_label: str | None = None,
) -> pl.DataFrame:
    """Filter ``df`` to the rows that match every given selector.

    ``model``, ``eval_type`` (column ``eval``) and ``latent_dim`` are always
    applied; ``semi_normals``, ``semi_anomalous`` and ``net_label`` are applied
    only when they are not ``None``. All predicates are AND-combined.
    """
    optional = {
        "semi_normals": semi_normals,
        "semi_anomalous": semi_anomalous,
        "net_label": net_label,
    }
    predicates = [
        pl.col("model") == model,
        pl.col("eval") == eval_type,
        pl.col("latent_dim") == latent_dim,
    ]
    predicates.extend(
        pl.col(column) == value
        for column, value in optional.items()
        if value is not None
    )
    return df.filter(pl.all_horizontal(predicates))
def _non_null_column(sub: pl.DataFrame, column: str) -> list[float]:
    """Return the non-null values of ``column`` in ``sub`` as a Python list."""
    return sub.get_column(column).drop_nulls().to_list()


def _auc_list(sub: pl.DataFrame) -> list[float]:
    """Return the non-null ``auc`` values of ``sub``."""
    return _non_null_column(sub, "auc")


def _ap_list(sub: pl.DataFrame) -> list[float]:
    """Return the non-null ``ap`` values of ``sub``."""
    return _non_null_column(sub, "ap")
def _plot_panel(
    ax,
    df: pl.DataFrame,
    *,
    eval_type: str,
    net_for_deepsad: str,
    latent_dim: int,
    kind: str,
):
    """
    Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).

    Baselines (isoforest, ocsvm) come from ``BASELINE_NET`` rows only; DeepSAD
    is drawn once per entry of ``SEMI_REGIMES``. Each series is the mean of its
    interpolated per-row curves with a +/-1 std band, and its legend entry
    reports mean+/-CI of AUC (``kind == "roc"``) or AP (any other kind).

    Args:
        ax: Matplotlib axes to draw on.
        df: Results dataframe; must already carry the ``net_label`` column.
        eval_type: Value matched against the ``eval`` column.
        net_for_deepsad: ``net_label`` whose DeepSAD rows are plotted.
        latent_dim: Value matched against the ``latent_dim`` column.
        kind: ``"roc"`` for ROC curves; any other value plots PR curves.
    """
    is_roc = kind == "roc"
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    if is_roc:
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        grid = ROC_GRID
    else:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        grid = PRC_GRID

    # ----- Baselines (Efficient)
    for model in ("isoforest", "ocsvm"):
        sub_b = _select_rows(
            df,
            model=model,
            eval_type=eval_type,
            latent_dim=latent_dim,
            net_label=BASELINE_NET,
        )
        _plot_mean_curve(
            ax, sub_b, kind=kind, grid=grid, name=model, color=COLOR_BASELINES[model]
        )

    # ----- DeepSAD (this panel's net) across semi-regimes
    for regime in SEMI_REGIMES:
        sn, sa = regime
        sub_d = _select_rows(
            df,
            model="deepsad",
            eval_type=eval_type,
            latent_dim=latent_dim,
            semi_normals=sn,
            semi_anomalous=sa,
            net_label=net_for_deepsad,
        )
        _plot_mean_curve(
            ax,
            sub_d,
            kind=kind,
            grid=grid,
            name=f"DeepSAD {net_for_deepsad} — semi {sn}/{sa}",
            color=COLOR_REGIMES[regime],
            linestyle=LINESTYLES[regime],
        )

    # Chance line for ROC
    if is_roc:
        ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")

    # Legend only when something labelled was drawn; an empty panel would
    # otherwise trigger matplotlib's "No artists with labels found" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="lower right", fontsize=9, frameon=True)


def _plot_mean_curve(ax, sub, *, kind, grid, name, color, linestyle=None):
    """Draw the mean curve (+/-1 std band) of ``sub``'s rows on ``ax``, if any.

    Nothing is drawn when ``sub`` is empty or none of its curves are usable.
    The legend label is ``name`` followed by mean+/-CI of AUC (ROC) or AP.
    ``linestyle=None`` leaves matplotlib's default line style in effect.
    """
    if sub.height == 0:
        return
    is_roc = kind == "roc"
    rows = sub.select("roc_curve" if is_roc else "prc_curve").to_dicts()
    mean_y, std_y = _interp_mean_std(_extract_curves(rows, kind), grid)
    if np.all(np.isnan(mean_y)):
        return
    m, ci = mean_ci(_auc_list(sub) if is_roc else _ap_list(sub))
    label = f"{name} ({'AUC' if is_roc else 'AP'}={m:.3f}±{ci:.3f})"
    # Only pass linestyle when one was requested, so baselines keep the default.
    style = {} if linestyle is None else {"linestyle": linestyle}
    ax.plot(grid, mean_y, lw=2, color=color, label=label, **style)
    ax.fill_between(grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color)
def make_figures_for_dim(
    df: pl.DataFrame, eval_type: str, latent_dim: int, out_dir: Path
):
    """Save the ROC and PRC comparison figures for one ``(eval_type, latent_dim)``.

    Each figure has two side-by-side panels (1 row x 2 columns): DeepSAD with
    the LeNet net on the left and with the Efficient net on the right, both
    overlaid with the baselines. Files are written as
    ``roc_{latent_dim}_{eval_type}.png`` and ``prc_{latent_dim}_{eval_type}.png``
    inside ``out_dir``, which is not created here.
    """
    for kind, kind_title in (("roc", "ROC"), ("prc", "PRC")):
        fig, axes = plt.subplots(
            nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
        )
        try:
            fig.suptitle(
                f"{kind_title} — {eval_type} — latent_dim={latent_dim}", fontsize=14
            )
            for ax, net in zip(axes, ("LeNet", "Efficient")):
                _plot_panel(
                    ax,
                    df,
                    eval_type=eval_type,
                    net_for_deepsad=net,
                    latent_dim=latent_dim,
                    kind=kind,
                )
                ax.set_title(f"DeepSAD ({net}) + baselines")
            out_path = out_dir / f"{kind}_{latent_dim}_{eval_type}.png"
            fig.savefig(out_path, dpi=150, bbox_inches="tight")
        finally:
            # Release the figure even if plotting/saving raised, so a failure
            # inside main()'s loop over EVALS x LATENT_DIMS cannot leak figures.
            plt.close(fig)
def main():
    """Build every ROC/PRC comparison figure and archive the results.

    Figures are written to a fresh timestamped directory under
    ``OUTPUT_DIR / "archive"`` together with a copy of this script (best
    effort). The top-level files are then mirrored into
    ``OUTPUT_DIR / "latest"``, after its previous top-level files are removed.
    """
    # Load dataframe and prep
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _net_label_col(df)
    # Filter to relevant models/evals only once
    df = df.filter(
        (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
        & (pl.col("eval").is_in(EVALS))
    )

    # Output/archiving like your AE script
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive = OUTPUT_DIR / "archive"
    archive.mkdir(parents=True, exist_ok=True)
    ts_dir = archive / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    # Generate figures
    for eval_type in EVALS:
        for dim in LATENT_DIMS:
            make_figures_for_dim(
                df, eval_type=eval_type, latent_dim=dim, out_dir=ts_dir
            )

    # Copy this script for provenance (best effort). ``__file__`` is looked up
    # inside the try because it is undefined in some environments (e.g.
    # interactive sessions); a failure is reported instead of silently ignored.
    try:
        shutil.copy2(Path(__file__), ts_dir)
    except (NameError, OSError) as err:
        print(f"Warning: could not copy script for provenance: {err}")

    # Update "latest"
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")
# Run the figure-generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()