260 lines
7.6 KiB
Python
260 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import shutil
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import polars as pl
|
|
from matplotlib.ticker import MaxNLocator
|
|
|
|
# =========================
# Config
# =========================

# Where the copied experiment results live, and where figures are written
# (a timestamped archive folder plus a "latest" mirror).
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_latent")

# Semi-supervised labeling regimes as (semi_normals, semi_anomalous) pairs.
# Each regime is drawn as its own subplot row.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [
    (0, 0),
    (50, 10),
    (500, 100),
]

# Evaluation variants; every entry produces a separate figure.
EVALS: list[str] = ["exp_based", "manual_based"]

# Latent dimensionalities shown on the shared x-axis.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# ---- Visual style -----------------------------------------------------------
FIGSIZE = (8, 8)  # a single tall figure, one compact subplot per regime
MARKERSIZE = 7
SCATTER_ALPHA = 0.95
LINEWIDTH = 2.0
TREND_LINEWIDTH = 2.2
BAND_ALPHA = 0.18

# Whether to shade a +/- 1 std band (k-fold variability) around the points.
SHOW_STD_BANDS = True  # switch to False to hide the bands

# One colour per DeepSAD backbone.
COLOR_LENET = "#1f77b4"  # blue
COLOR_EFFICIENT = "#ff7f0e"  # orange
|
|
|
|
# =========================
|
|
# Loader
|
|
# =========================
|
|
from load_results import load_results_dataframe
|
|
|
|
|
|
# =========================
|
|
# Helpers
|
|
# =========================
|
|
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added ``net_label`` column naming the backbone.

    The ``network`` column is cast to string and matched case-insensitively:
    values containing "lenet" become "LeNet", otherwise values containing
    "efficient" become "Efficient", and anything else keeps its string form.
    """
    network_str = pl.col("network").cast(pl.Utf8)
    network_lower = network_str.str.to_lowercase()

    # First matching branch wins, so "lenet" takes priority over "efficient".
    label_expr = (
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_str)
        .alias("net_label")
    )
    return df.with_columns(label_expr)
|
|
|
|
|
|
def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame:
    """Keep only DeepSAD rows within the configured evals, dims and backbones.

    Requires the ``net_label`` column (added by ``_with_net_label``). The
    result is projected to the columns used by the aggregation step.
    """
    wanted_columns = [
        "eval",
        "net_label",
        "latent_dim",
        "semi_normals",
        "semi_anomalous",
        "fold",
        "ap",
    ]

    is_deepsad = pl.col("model") == "deepsad"
    in_configured_scope = (
        pl.col("eval").is_in(EVALS)
        & pl.col("latent_dim").is_in(LATENT_DIMS)
        & pl.col("net_label").is_in(["LeNet", "Efficient"])
    )

    filtered = df.filter(is_deepsad & in_configured_scope)
    return filtered.select(wanted_columns)
|
|
|
|
|
|
@dataclass(frozen=True)
class Agg:
    """Immutable pair of summary statistics for AP over a configuration's folds."""

    # Mean AP across the folds of one configuration.
    mean: float
    # Standard deviation of AP across those folds.
    std: float
|
|
|
|
|
|
def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]:
    """Aggregate per-fold AP into mean/std for every configuration.

    Rows are grouped by (eval, net_label, latent_dim, semi_normals,
    semi_anomalous) and each group's ``ap`` values are reduced to their mean
    and standard deviation.

    Args:
        df: filtered results with at least the grouping columns and ``ap``.

    Returns:
        Mapping from ``(eval, net_label, latent_dim, semi_normals,
        semi_anomalous)`` to an ``Agg``. A statistic that polars reports as
        null (e.g. ``std`` of a single-fold group) is stored as NaN.
    """
    group_cols = ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"]

    def _as_float(value) -> float:
        # polars nulls arrive as Python None. The old `x == x` NaN test let
        # None through to float(), which raised TypeError. NaN already
        # survives float() unchanged, so only None needs mapping.
        return np.nan if value is None else float(value)

    rows = (
        df.group_by(group_cols)
        .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std"))
        .to_dicts()
    )

    out: Dict[Tuple[str, str, int, int, int], Agg] = {}
    for row in rows:
        key = (
            str(row["eval"]),
            str(row["net_label"]),
            int(row["latent_dim"]),
            int(row["semi_normals"]),
            int(row["semi_anomalous"]),
        )
        out[key] = Agg(mean=_as_float(row["mean"]), std=_as_float(row["std"]))
    return out
|
|
|
|
|
|
def _lin_trend(xs: List[int], ys: List[float]) -> Tuple[np.ndarray, np.ndarray]:
|
|
if len(xs) < 2:
|
|
return np.array(xs, dtype=float), np.array(ys, dtype=float)
|
|
x = np.array(xs, dtype=float)
|
|
y = np.array(ys, dtype=float)
|
|
a, b = np.polyfit(x, y, 1)
|
|
x_fit = np.linspace(x.min(), x.max(), 200)
|
|
y_fit = a * x_fit + b
|
|
return x_fit, y_fit
|
|
|
|
|
|
def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float, float]:
|
|
vals = np.array(all_vals, dtype=float)
|
|
errs = np.array(all_errs, dtype=float) if SHOW_STD_BANDS else np.zeros_like(vals)
|
|
valid = np.isfinite(vals)
|
|
if not np.any(valid):
|
|
return (0.0, 1.0)
|
|
v, e = vals[valid], errs[valid]
|
|
lo = np.min(v - e)
|
|
hi = np.max(v + e)
|
|
span = max(1e-3, hi - lo)
|
|
pad = 0.08 * span
|
|
y0 = max(0.0, lo - pad)
|
|
y1 = min(1.0, hi + pad)
|
|
if (y1 - y0) < 0.08:
|
|
mid = 0.5 * (y0 + y1)
|
|
y0 = max(0.0, mid - 0.04)
|
|
y1 = min(1.0, mid + 0.04)
|
|
return (float(y0), float(y1))
|
|
|
|
|
|
def _series_for(
    agg: Dict[Tuple[str, str, int, int, int], Agg],
    ev: str,
    net: str,
    semi_n: int,
    semi_a: int,
) -> Tuple[List[int], List[float], List[float]]:
    """Collect (dims, means, stds) present in ``agg`` for one backbone/regime."""
    xs: List[int] = []
    ys: List[float] = []
    es: List[float] = []
    for dim in LATENT_DIMS:
        stats = agg.get((ev, net, dim, semi_n, semi_a))
        if stats is not None:
            xs.append(dim)
            ys.append(stats.mean)
            es.append(stats.std)
    return xs, ys, es


def _draw_backbone(ax, net: str, color: str, xs, ys, es) -> None:
    """Scatter mean AP, overlay its linear trend and (optionally) a +/-1 std band."""
    ax.scatter(xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)")

    x_fit, y_fit = _lin_trend(xs, ys)
    ax.plot(x_fit, y_fit, color=color, linewidth=TREND_LINEWIDTH, label=f"{net} (trend)")

    if SHOW_STD_BANDS and es and np.any(np.isfinite(es)):
        mean = np.asarray(ys, dtype=float)
        std = np.asarray(es, dtype=float)
        ax.fill_between(
            xs,
            np.clip(mean - std, 0.0, 1.0),
            np.clip(mean + std, 0.0, 1.0),
            color=color,
            alpha=BAND_ALPHA,
            linewidth=0,
        )


def _merged_legend_entries(axes) -> Tuple[list, list]:
    """Union of legend entries over all axes, de-duplicated by label in first-seen order."""
    by_label: Dict[str, object] = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            by_label.setdefault(label, handle)
    return list(by_label.values()), list(by_label.keys())


def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Render one figure of AP vs. latent dimension for evaluation ``ev``.

    There is one subplot row per entry of ``SEMI_LABELING_REGIMES``. In each
    row, both DeepSAD backbones are drawn as mean-AP points with a
    least-squares trend line and, if ``SHOW_STD_BANDS``, a +/-1 std band. The
    figure is saved as ``outdir / f"ap_trends_{ev}.png"`` and then closed.
    """
    fig, axes_grid = plt.subplots(
        len(SEMI_LABELING_REGIMES),
        1,
        figsize=FIGSIZE,
        constrained_layout=True,
        sharex=True,
        squeeze=False,  # always a 2-D grid, even with a single regime
    )
    try:
        axes = list(axes_grid[:, 0])
        nets = [("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)]

        for ax, (semi_n, semi_a) in zip(axes, SEMI_LABELING_REGIMES):
            # Axis setup is data-independent: do it once per subplot, even
            # when no backbone has data for this regime.
            ax.set_xticks(LATENT_DIMS)
            ax.yaxis.set_major_locator(MaxNLocator(nbins=5))

            all_vals: List[float] = []
            all_errs: List[float] = []
            for net, color in nets:
                xs, ys, es = _series_for(agg, ev, net, semi_n, semi_a)
                all_vals.extend(ys)
                all_errs.extend(es)
                if xs:
                    _draw_backbone(ax, net, color, xs, ys, es)

            y0, y1 = _dynamic_ylim(all_vals, all_errs)
            ax.set_ylim(y0, y1)
            ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11)
            ax.grid(True, alpha=0.35)
            ax.set_ylabel("AP")

        axes[-1].set_xlabel("Latent dimension")

        # Merge entries from every subplot so a backbone missing from the
        # top regime still appears in the legend.
        handles, labels = _merged_legend_entries(axes)
        fig.legend(handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97))
        fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05)

        # The suptitle sits at y=1.05, above the figure canvas. Without
        # bbox_inches="tight" it would be clipped from the saved PNG.
        fig.savefig(outdir / f"ap_trends_{ev}.png", dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
|
|
|
|
|
|
def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Write one AP-vs-latent-dimension figure per entry of ``EVALS``.

    ``outdir`` and any missing parents are created first. The figures are
    produced by ``plot_eval`` in ``EVALS`` order.
    """
    outdir.mkdir(exist_ok=True, parents=True)
    for evaluation in EVALS:
        plot_eval(evaluation, agg, outdir)
|
|
|
|
|
|
def _publish_latest(src_dir: Path, latest: Path) -> None:
    """Replace the regular files directly in ``latest`` with copies of ``src_dir``'s files."""
    latest.mkdir(parents=True, exist_ok=True)
    for stale in latest.iterdir():
        if stale.is_file():
            stale.unlink()
    for produced in src_dir.iterdir():
        if produced.is_file():
            shutil.copy2(produced, latest / produced.name)


def main():
    """Plot AP trends into a timestamped archive folder and refresh ``latest``.

    Steps:
      1. Load results from ``ROOT``, label backbones, filter to DeepSAD, and
         aggregate AP per configuration.
      2. Write the figures to ``OUTPUT_DIR/archive/<timestamp>/``.
      3. Archive a copy of this script next to the figures.
      4. Mirror the archive folder's files into ``OUTPUT_DIR/latest``.
    """
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _filter_deepsad(_with_net_label(df))
    agg = aggregate_ap(df)

    # parents=True also creates OUTPUT_DIR and OUTPUT_DIR/archive.
    ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    plot_all(agg, ts_dir)

    # Archiving the script is best-effort, but report failures instead of
    # silently swallowing every exception. NameError covers a missing
    # __file__ (e.g. interactive runs). OSError covers copy problems.
    try:
        script_path = Path(__file__)
        shutil.copy2(script_path, ts_dir / script_path.name)
    except (NameError, OSError) as exc:
        print(f"Warning: could not archive script copy: {exc}")

    latest = OUTPUT_DIR / "latest"
    _publish_latest(ts_dir, latest)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()
|