Files
mt/tools/plot_scripts/results_ap_over_latent.py

260 lines
7.6 KiB
Python
Raw Normal View History

2025-09-27 16:34:52 +02:00
#!/usr/bin/env python3
from __future__ import annotations
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.ticker import MaxNLocator
# =========================
# Config
# =========================
# Results directory handed to load_results_dataframe() in main().
ROOT = Path("/home/fedex/mt/results/copy")
# Output root; main() writes into "archive/<timestamp>" and "latest" below it.
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_latent")
# Labeling regimes as (semi_normals, semi_anomalous) pairs; each regime is
# drawn in its own subplot (one row per entry).
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]
# Evaluations: a separate figure (and output file) is produced per entry.
EVALS: list[str] = ["exp_based", "manual_based"]
# X-axis values (latent dims); also used as the explicit x tick positions.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]
# Visual style
FIGSIZE = (8, 8) # one tall figure with 3 compact subplots
# NOTE(review): MARKERSIZE and LINEWIDTH are not referenced by the code visible
# in this file (the scatter call uses a literal s=35) - confirm before removing.
MARKERSIZE = 7
SCATTER_ALPHA = 0.95
LINEWIDTH = 2.0
TREND_LINEWIDTH = 2.2
BAND_ALPHA = 0.18
# Toggle: show ±1 std bands (k-fold variability). Also controls whether the
# std values widen the automatically chosen y-limits.
SHOW_STD_BANDS = True # <<< set to False to hide the bands
# Colors for the two DeepSAD backbones
COLOR_LENET = "#1f77b4" # blue
COLOR_EFFICIENT = "#ff7f0e" # orange
# =========================
# Loader
# =========================
from load_results import load_results_dataframe
# =========================
# Helpers
# =========================
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Add a ``net_label`` column derived from the ``network`` column.

    The network name is matched case-insensitively: names containing
    "lenet" become "LeNet", names containing "efficient" become
    "Efficient" (checked in that order), and anything else keeps its
    original name cast to a string.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    network_lower = network_text.str.to_lowercase()

    label_expr = (
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
    )
    return df.with_columns(label_expr.alias("net_label"))
def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame:
    """Keep only the DeepSAD rows and the columns needed for plotting.

    A row is kept when its model is "deepsad", its eval is listed in
    ``EVALS``, its latent dimension is listed in ``LATENT_DIMS`` and its
    ``net_label`` is "LeNet" or "Efficient".
    """
    kept_columns = (
        "eval",
        "net_label",
        "latent_dim",
        "semi_normals",
        "semi_anomalous",
        "fold",
        "ap",
    )

    is_deepsad = pl.col("model") == "deepsad"
    eval_wanted = pl.col("eval").is_in(EVALS)
    dim_wanted = pl.col("latent_dim").is_in(LATENT_DIMS)
    net_wanted = pl.col("net_label").is_in(["LeNet", "Efficient"])

    selected_rows = df.filter(is_deepsad & eval_wanted & dim_wanted & net_wanted)
    return selected_rows.select(*kept_columns)
@dataclass(frozen=True)
class Agg:
    """Immutable pair of summary statistics (mean and std) for one group."""

    # Mean of the aggregated values.
    mean: float
    # Standard deviation of the aggregated values.
    std: float
def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]:
    """Aggregate the ``ap`` column into mean/std per configuration.

    Rows are grouped by (eval, net_label, latent_dim, semi_normals,
    semi_anomalous).

    Returns:
        A dict keyed by that 5-tuple (strings for the first two entries,
        ints for the rest) whose values are ``Agg(mean, std)``. A statistic
        that is missing (null) or NaN is stored as NaN.
    """

    def _as_float(value) -> float:
        # Polars reports an undefined statistic (e.g. the sample std of a
        # group with a single row) as null, which arrives here as None.
        # float(None) would raise, so map it to NaN; NaN itself passes
        # through float() unchanged.
        return np.nan if value is None else float(value)

    group_cols = ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"]
    rows = (
        df.group_by(group_cols)
        .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std"))
        .to_dicts()
    )

    out: Dict[Tuple[str, str, int, int, int], Agg] = {}
    for row in rows:
        key = (
            str(row["eval"]),
            str(row["net_label"]),
            int(row["latent_dim"]),
            int(row["semi_normals"]),
            int(row["semi_anomalous"]),
        )
        out[key] = Agg(mean=_as_float(row["mean"]), std=_as_float(row["std"]))
    return out
def _lin_trend(xs: List[int], ys: List[float]) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a straight line through the points and sample it for plotting.

    Points whose x or y value is not finite (NaN/inf) are ignored, so a
    single missing value cannot poison the least-squares fit.

    Args:
        xs: x coordinates.
        ys: y coordinates, same length as ``xs``.

    Returns:
        ``(x_fit, y_fit)``: 200 evenly spaced x values spanning the finite
        x range and the fitted line evaluated at them. When fewer than two
        finite points are available no fit is possible and the finite
        points themselves are returned as float arrays.
    """
    x = np.asarray(xs, dtype=float)
    y = np.asarray(ys, dtype=float)

    # Drop non-finite pairs before fitting; np.polyfit does not skip NaNs.
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if x.size < 2:
        return x, y

    slope, intercept = np.polyfit(x, y, 1)
    x_fit = np.linspace(x.min(), x.max(), 200)
    return x_fit, slope * x_fit + intercept
def _dynamic_ylim(
    all_vals: List[float],
    all_errs: List[float],
    include_errs: bool | None = None,
) -> Tuple[float, float]:
    """Choose y-axis limits that tightly frame the plotted values.

    Args:
        all_vals: values shown on the axis; non-finite entries are ignored.
        all_errs: one error (std) per value, same length as ``all_vals``.
            A non-finite error is treated as zero so that one undefined std
            does not blow the limits up to the full range.
        include_errs: whether the errors widen the framed range. ``None``
            (the default) follows the module-level ``SHOW_STD_BANDS`` toggle.

    Returns:
        ``(y0, y1)`` clamped to [0, 1]: the value range (plus/minus errors
        if included) padded by 8% of its span. Ranges narrower than 0.08
        are replaced by a window of +/-0.04 around the midpoint, again
        clamped to [0, 1]. ``(0.0, 1.0)`` when no finite value exists.
    """
    if include_errs is None:
        include_errs = SHOW_STD_BANDS

    vals = np.array(all_vals, dtype=float)
    if include_errs:
        errs = np.array(all_errs, dtype=float)
        # NaN/inf errors would turn min/max below into NaN.
        errs = np.where(np.isfinite(errs), errs, 0.0)
    else:
        errs = np.zeros_like(vals)

    valid = np.isfinite(vals)
    if not np.any(valid):
        return (0.0, 1.0)

    v, e = vals[valid], errs[valid]
    lo = np.min(v - e)
    hi = np.max(v + e)
    span = max(1e-3, hi - lo)
    pad = 0.08 * span
    y0 = max(0.0, lo - pad)
    y1 = min(1.0, hi + pad)
    if (y1 - y0) < 0.08:
        # Enforce a minimum visible height so flat curves stay readable.
        mid = 0.5 * (y0 + y1)
        y0 = max(0.0, mid - 0.04)
        y1 = min(1.0, mid + 0.04)
    return (float(y0), float(y1))
def plot_eval(
    ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path
) -> None:
    """Render the AP-vs-latent-dimension figure for one evaluation.

    One subplot row is drawn per entry of ``SEMI_LABELING_REGIMES``. Each
    subplot shows, for the "LeNet" and "Efficient" networks, the mean AP at
    every latent dimension present in ``agg``, a linear trend line and
    (when ``SHOW_STD_BANDS`` is set) a +/-1 std band clipped to [0, 1].

    Args:
        ev: evaluation name; first component of the ``agg`` keys and part
            of the output file name.
        agg: aggregated statistics keyed by
            (eval, net_label, latent_dim, semi_normals, semi_anomalous).
        outdir: existing directory that receives ``ap_trends_<ev>.png``.
    """
    networks = [("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)]

    # squeeze=False always yields a 2-D array, so a single regime needs no
    # special casing.
    fig, axes_grid = plt.subplots(
        len(SEMI_LABELING_REGIMES),
        1,
        figsize=FIGSIZE,
        constrained_layout=True,
        sharex=True,
        squeeze=False,
    )
    axes = axes_grid[:, 0]

    for ax, (semi_n, semi_a) in zip(axes, SEMI_LABELING_REGIMES):
        # Axis-level tick setup does not depend on the network: do it once.
        ax.set_xticks(LATENT_DIMS)
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))  # e.g., always 5 ticks

        all_vals: List[float] = []
        all_errs: List[float] = []
        for net, color in networks:
            xs, ys, es = [], [], []
            for dim in LATENT_DIMS:
                stats = agg.get((ev, net, dim, semi_n, semi_a))
                if stats is None:
                    continue
                xs.append(dim)
                ys.append(stats.mean)
                es.append(stats.std)
            if not xs:
                continue
            all_vals.extend(ys)
            all_errs.extend(es)

            ax.scatter(
                xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)"
            )

            # Fit the trend on finite means only; a NaN mean would otherwise
            # make the least-squares fit fail or return NaN.
            fit_xs = [x for x, y in zip(xs, ys) if np.isfinite(y)]
            fit_ys = [y for y in ys if np.isfinite(y)]
            x_fit, y_fit = _lin_trend(fit_xs, fit_ys)
            ax.plot(
                x_fit,
                y_fit,
                color=color,
                linewidth=TREND_LINEWIDTH,
                label=f"{net} (trend)",
            )

            if SHOW_STD_BANDS and np.any(np.isfinite(es)):
                means = np.array(ys, dtype=float)
                stds = np.array(es, dtype=float)
                ax.fill_between(
                    xs,
                    np.clip(means - stds, 0.0, 1.0),
                    np.clip(means + stds, 0.0, 1.0),
                    color=color,
                    alpha=BAND_ALPHA,
                    linewidth=0,
                )

        y0, y1 = _dynamic_ylim(all_vals, all_errs)
        ax.set_ylim(y0, y1)
        ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11)
        ax.grid(True, alpha=0.35)
        ax.set_ylabel("AP")

    axes[-1].set_xlabel("Latent dimension")
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97))
    fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05)

    fname = f"ap_trends_{ev}.png"
    # The suptitle is pinned at y=1.05, i.e. above the figure canvas, so a
    # plain savefig would crop it away; a tight bounding box grows the saved
    # image to include every artist.
    fig.savefig(outdir / fname, dpi=150, bbox_inches="tight")
    plt.close(fig)
def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Ensure ``outdir`` exists, then draw one figure per entry of ``EVALS``."""
    outdir.mkdir(exist_ok=True, parents=True)
    for eval_name in EVALS:
        plot_eval(eval_name, agg, outdir)
def main():
    """Load results, aggregate AP and write the plots.

    Output goes to a fresh ``archive/<timestamp>`` directory below
    ``OUTPUT_DIR`` together with a copy of this script (best effort); the
    files of that directory are then mirrored into ``OUTPUT_DIR / "latest"``
    after the files previously stored there have been removed.
    """
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _with_net_label(df)
    df = _filter_deepsad(df)
    agg = aggregate_ap(df)

    # parents=True also creates OUTPUT_DIR and the "archive" directory.
    ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)
    plot_all(agg, ts_dir)

    # Archiving the script next to its plots is best effort: report the
    # failure instead of hiding it, but keep going.
    try:
        script_path = Path(__file__)
        shutil.copy2(script_path, ts_dir / script_path.name)
    except (NameError, OSError) as exc:
        # NameError: __file__ is undefined (e.g. interactive session).
        print(f"Warning: could not archive a copy of the script: {exc}")

    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for stale in latest.iterdir():
        if stale.is_file():
            stale.unlink()
    for produced in ts_dir.iterdir():
        if produced.is_file():
            shutil.copy2(produced, latest / produced.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()