Files
mt/tools/plot_scripts/results_ap_over_latent.py

274 lines
8.1 KiB
Python
Raw Permalink Normal View History

2025-09-27 16:34:52 +02:00
#!/usr/bin/env python3
from __future__ import annotations
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.ticker import MaxNLocator
# =========================
# Config
# =========================
# Input results root and output directory for the generated plots.
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_latent")

# Labeling regimes as (semi_normals, semi_anomalous); each becomes one subplot.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [
    (0, 0),
    (50, 10),
    (500, 100),
]

# Evaluation schemes; each one is rendered as its own figure.
EVALS: list[str] = ["exp_based", "manual_based"]

# Latent dimensionalities shown on the x-axis.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Visual style
FIGSIZE = (8, 8)  # a single tall figure holding the stacked regime subplots
MARKERSIZE = 7  # NOTE(review): not referenced by the plotting code (scatter uses a fixed size)
SCATTER_ALPHA = 0.95
LINEWIDTH = 2.0  # NOTE(review): not referenced by the plotting code; trends use TREND_LINEWIDTH
TREND_LINEWIDTH = 2.2
BAND_ALPHA = 0.18

# Whether to draw ±1 std bands (spread across k-fold runs); False hides them.
SHOW_STD_BANDS = True

# Colors for the two DeepSAD backbones
COLOR_LENET = "#1f77b4"  # blue
COLOR_EFFICIENT = "#ff7f0e"  # orange
# =========================
# Loader
# =========================
from load_results import load_results_dataframe
# =========================
# Helpers
# =========================
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added ``net_label`` column naming the backbone.

    The ``network`` column is cast to string and matched case-insensitively:
    names containing ``"lenet"`` map to ``"LeNet"``, otherwise names containing
    ``"efficient"`` map to ``"Efficient"``; anything else keeps its original
    (string-cast) network name.
    """
    network = pl.col("network").cast(pl.Utf8)
    lowered = network.str.to_lowercase()
    label = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network)
    )
    return df.with_columns(label.alias("net_label"))
def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame:
    """Keep DeepSAD rows for the configured evals/latent dims and both backbones.

    Requires the ``net_label`` column added by ``_with_net_label``. Only the
    columns needed for AP aggregation are returned.
    """
    keep = (
        (pl.col("model") == "deepsad")
        & pl.col("eval").is_in(EVALS)
        & pl.col("latent_dim").is_in(LATENT_DIMS)
        & pl.col("net_label").is_in(["LeNet", "Efficient"])
    )
    wanted_columns = [
        "eval",
        "net_label",
        "latent_dim",
        "semi_normals",
        "semi_anomalous",
        "fold",
        "ap",
    ]
    return df.filter(keep).select(wanted_columns)
@dataclass(frozen=True)
class Agg:
    """Immutable summary of AP over the folds of one configuration."""

    mean: float  # mean AP across folds (nan when undefined)
    std: float  # standard deviation of AP across folds (nan when undefined)
def _float_or_nan(value) -> float:
    """Convert an aggregate value to float, mapping null (None) and NaN to nan."""
    # Polars reports undefined aggregates as null (e.g. the std of a
    # single-fold group); float(None) would raise, so normalise it here.
    if value is None:
        return np.nan
    value = float(value)
    return value if value == value else np.nan  # NaN != NaN


def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]:
    """Summarise AP over folds for every (eval, network, dim, regime) group.

    Args:
        df: Results with columns ``eval``, ``net_label``, ``latent_dim``,
            ``semi_normals``, ``semi_anomalous`` and ``ap``; rows that differ
            only in the remaining columns (e.g. ``fold``) are pooled.

    Returns:
        Mapping ``(eval, net_label, latent_dim, semi_normals, semi_anomalous)``
        -> :class:`Agg` with the mean and std of ``ap`` within the group.
        Undefined statistics (null or NaN) are stored as ``nan``.
    """
    grouped = (
        df.group_by(
            ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"]
        )
        .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std"))
        .to_dicts()
    )
    out: Dict[Tuple[str, str, int, int, int], Agg] = {}
    for row in grouped:
        key = (
            str(row["eval"]),
            str(row["net_label"]),
            int(row["latent_dim"]),
            int(row["semi_normals"]),
            int(row["semi_anomalous"]),
        )
        out[key] = Agg(mean=_float_or_nan(row["mean"]), std=_float_or_nan(row["std"]))
    return out
def _lin_trend(
    xs: List[int], ys: List[float], n_points: int = 200
) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a least-squares straight line through the points ``(xs, ys)``.

    Points with a non-finite coordinate (e.g. a NaN mean AP) are ignored:
    a single NaN would otherwise make ``np.polyfit`` fail or yield a NaN line.

    Args:
        xs: x positions of the points.
        ys: y values, same length as ``xs``.
        n_points: number of evenly spaced samples drawn along the fitted line.

    Returns:
        ``(x_fit, y_fit)`` sampled over the x-range of the finite points.
        With fewer than two finite points no line can be fitted, so those
        points are returned as float arrays instead.
    """
    x = np.asarray(xs, dtype=float)
    y = np.asarray(ys, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if x.size < 2:
        return x, y
    slope, intercept = np.polyfit(x, y, 1)
    x_fit = np.linspace(x.min(), x.max(), n_points)
    return x_fit, slope * x_fit + intercept
def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float, float]:
vals = np.array(all_vals, dtype=float)
errs = np.array(all_errs, dtype=float) if SHOW_STD_BANDS else np.zeros_like(vals)
valid = np.isfinite(vals)
if not np.any(valid):
return (0.0, 1.0)
v, e = vals[valid], errs[valid]
lo = np.min(v - e)
hi = np.max(v + e)
span = max(1e-3, hi - lo)
pad = 0.08 * span
y0 = max(0.0, lo - pad)
y1 = min(1.0, hi + pad)
if (y1 - y0) < 0.08:
mid = 0.5 * (y0 + y1)
y0 = max(0.0, mid - 0.04)
y1 = min(1.0, mid + 0.04)
return (float(y0), float(y1))
2025-09-27 19:01:59 +02:00
def _get_dim_mapping(dims: list[int]) -> dict[int, int]:
    """Assign each latent dimension its plot position 0, 1, 2, ... in input order.

    Unevenly spaced dimensions are thereby drawn evenly spaced on the x-axis.
    If a dimension repeats, its last position wins.
    """
    positions: dict[int, int] = {}
    for position, dim in enumerate(dims):
        positions[dim] = position
    return positions
def _regime_series(
    ev: str,
    net: str,
    regime: Tuple[int, int],
    agg: Dict[Tuple[str, str, int, int, int], Agg],
    dim_mapping: Dict[int, int],
) -> Tuple[List[int], List[float], List[float]]:
    """Collect one backbone's (x positions, AP means, AP stds) for a regime.

    Latent dims without an aggregated entry are skipped; x values are the
    evenly spaced positions from ``dim_mapping`` rather than the raw dims.
    """
    semi_n, semi_a = regime
    xs: List[int] = []
    ys: List[float] = []
    es: List[float] = []
    for dim in LATENT_DIMS:
        stats = agg.get((ev, net, dim, semi_n, semi_a))
        if stats is None:
            continue
        xs.append(dim_mapping[dim])
        ys.append(stats.mean)
        es.append(stats.std)
    return xs, ys, es


def _plot_regime(
    ax,
    ev: str,
    regime: Tuple[int, int],
    agg: Dict[Tuple[str, str, int, int, int], Agg],
    dim_mapping: Dict[int, int],
) -> None:
    """Draw points, linear trend and optional ±1 std band of both backbones on ``ax``."""
    semi_n, semi_a = regime

    # Evenly spaced ticks labelled with the real latent dims. Set once per
    # axis (not per network) so a subplot is labelled even without data.
    ax.set_xticks(list(dim_mapping.values()))
    ax.set_xticklabels(LATENT_DIMS)
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))

    all_vals: List[float] = []
    all_errs: List[float] = []
    for net, color in (("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)):
        xs, ys, es = _regime_series(ev, net, regime, agg, dim_mapping)
        all_vals.extend(ys)
        all_errs.extend(es)
        if not xs:
            continue

        ax.scatter(
            xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)"
        )
        # The trend is fitted in the mapped (evenly spaced) x positions.
        x_fit, y_fit = _lin_trend(xs, ys)
        ax.plot(
            x_fit,
            y_fit,
            color=color,
            linewidth=TREND_LINEWIDTH,
            label=f"{net} (trend)",
        )

        if SHOW_STD_BANDS and es and np.any(np.isfinite(es)):
            y_arr = np.asarray(ys, dtype=float)
            e_arr = np.asarray(es, dtype=float)
            # Clip the band to [0, 1], matching the clamped y-limits.
            ax.fill_between(
                xs,
                np.clip(y_arr - e_arr, 0.0, 1.0),
                np.clip(y_arr + e_arr, 0.0, 1.0),
                color=color,
                alpha=BAND_ALPHA,
                linewidth=0,
            )

    y0, y1 = _dynamic_ylim(all_vals, all_errs)
    ax.set_ylim(y0, y1)
    ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11)
    ax.grid(True, alpha=0.35)


def _unique_legend_entries(axes) -> Tuple[list, list]:
    """Merge legend handles/labels of all ``axes``, keeping the first per label."""
    handles: list = []
    labels: list = []
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)
    return handles, labels


def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Save ``ap_trends_<ev>.png`` with AP vs. latent dim, one subplot per regime.

    Args:
        ev: evaluation name (one of ``EVALS``); selects entries of ``agg``.
        agg: fold statistics keyed by
            ``(eval, net_label, latent_dim, semi_normals, semi_anomalous)``.
        outdir: existing directory the PNG is written to.
    """
    fig, axes = plt.subplots(
        len(SEMI_LABELING_REGIMES),
        1,
        figsize=FIGSIZE,
        constrained_layout=True,
        sharex=True,
    )
    try:
        if len(SEMI_LABELING_REGIMES) == 1:
            axes = [axes]  # a 1x1 grid comes back as a bare Axes
        dim_mapping = _get_dim_mapping(LATENT_DIMS)

        for ax, regime in zip(axes, SEMI_LABELING_REGIMES):
            _plot_regime(ax, ev, regime, agg, dim_mapping)

        axes[-1].set_xlabel("Latent dimension")
        for ax in axes:
            ax.set_ylabel("AP")

        # Gather entries from every subplot so the legend is not empty when
        # the first regime happens to have no data.
        handles, labels = _unique_legend_entries(axes)
        fig.legend(
            handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97)
        )
        fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05)
        # y=1.05 places the suptitle above the canvas; a plain savefig would
        # crop it, while a tight bbox grows the saved image to include it.
        fig.savefig(outdir / f"ap_trends_{ev}.png", dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Write one figure per evaluation in ``EVALS`` into ``outdir``.

    ``outdir`` and any missing parents are created first.
    """
    outdir.mkdir(exist_ok=True, parents=True)
    for evaluation in EVALS:
        plot_eval(evaluation, agg, outdir)
def _archive_script(dest_dir: Path) -> None:
    """Copy this script into ``dest_dir`` next to its plots (best effort)."""
    try:
        script_path = Path(__file__)
        shutil.copy2(script_path, dest_dir / script_path.name)
    except (NameError, OSError) as exc:
        # NameError: __file__ is undefined (e.g. interactive use).
        # Archiving the source is a convenience only, so report and continue.
        print(f"Warning: could not archive script: {exc}")


def _refresh_latest(src_dir: Path, latest: Path) -> None:
    """Replace the plain files in ``latest`` with copies of those in ``src_dir``."""
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in src_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)


def main():
    """Load results, aggregate DeepSAD AP, and write timestamped + latest plots.

    Plots go to ``OUTPUT_DIR/archive/<timestamp>/`` together with a copy of
    this script; ``OUTPUT_DIR/latest/`` is then refreshed with those files.
    """
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _with_net_label(df)
    df = _filter_deepsad(df)
    agg = aggregate_ap(df)

    archive_dir = OUTPUT_DIR / "archive"
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # parents=True also creates OUTPUT_DIR and archive/ when missing.
    ts_dir.mkdir(parents=True, exist_ok=True)

    plot_all(agg, ts_dir)
    _archive_script(ts_dir)

    latest = OUTPUT_DIR / "latest"
    _refresh_latest(ts_dir, latest)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")
if __name__ == "__main__":  # run only when executed directly, not on import
    main()