wip
This commit is contained in:
259
tools/plot_scripts/results_ap_over_latent.py
Normal file
259
tools/plot_scripts/results_ap_over_latent.py
Normal file
@@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Where the result files are read from and where the figures are written to.
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_latent")

# (semi_normals, semi_anomalous) pairs; each pair gets its own subplot.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]

# Evaluation types; one figure is produced per entry.
EVALS: list[str] = ["exp_based", "manual_based"]

# Latent dimensions shown on the x-axis.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# --- Visual style ----------------------------------------------------------
FIGSIZE = (8, 8)  # one tall figure holding the stacked, compact subplots
MARKERSIZE = 7
SCATTER_ALPHA = 0.95
LINEWIDTH = 2.0
TREND_LINEWIDTH = 2.2
BAND_ALPHA = 0.18

# Switch for the +/-1 std bands (k-fold variability); False hides them.
SHOW_STD_BANDS = True

# One colour per DeepSAD backbone.
COLOR_LENET = "#1f77b4"  # blue
COLOR_EFFICIENT = "#ff7f0e"  # orange
|
||||
|
||||
# =========================
|
||||
# Loader
|
||||
# =========================
|
||||
from load_results import load_results_dataframe
|
||||
|
||||
|
||||
# =========================
|
||||
# Helpers
|
||||
# =========================
|
||||
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an extra ``net_label`` column built from ``network``.

    The ``network`` column is cast to a string and lower-cased for matching:
    values containing ``"lenet"`` become ``"LeNet"``, values containing
    ``"efficient"`` become ``"Efficient"``, and everything else keeps its
    (string-cast) network name.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    lowered = network_text.str.to_lowercase()

    # The "lenet" test comes first, so it wins if a name matches both patterns.
    label = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
    )
    return df.with_columns(label.alias("net_label"))
|
||||
|
||||
|
||||
def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame:
    """Keep only the DeepSAD rows relevant to the plots and trim the columns.

    A row is kept when its ``model`` is ``"deepsad"``, its ``eval`` is listed
    in ``EVALS``, its ``latent_dim`` is listed in ``LATENT_DIMS`` and its
    ``net_label`` is ``"LeNet"`` or ``"Efficient"``. Expects the ``net_label``
    column to exist already.
    """
    is_deepsad = pl.col("model") == "deepsad"
    wanted_eval = pl.col("eval").is_in(EVALS)
    wanted_dim = pl.col("latent_dim").is_in(LATENT_DIMS)
    wanted_net = pl.col("net_label").is_in(["LeNet", "Efficient"])

    kept_columns = [
        "eval",
        "net_label",
        "latent_dim",
        "semi_normals",
        "semi_anomalous",
        "fold",
        "ap",
    ]
    return df.filter(is_deepsad & wanted_eval & wanted_dim & wanted_net).select(
        kept_columns
    )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Agg:
    """Immutable pair holding the mean and standard deviation of a group."""

    # Mean of the aggregated values.
    mean: float
    # Standard deviation of the aggregated values.
    std: float
|
||||
|
||||
|
||||
def _float_or_nan(value) -> float:
    """Convert an aggregate cell to ``float``, mapping null (``None``) to NaN."""
    # A null aggregate arrives here as None (e.g. the sample std of a group
    # with a single row); float(None) would raise TypeError.
    if value is None:
        return float(np.nan)
    return float(value)


def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]:
    """Aggregate the ``ap`` column into mean/std per configuration.

    Rows are grouped by ``(eval, net_label, latent_dim, semi_normals,
    semi_anomalous)``.

    Args:
        df: Frame containing the grouping columns and an ``ap`` column.

    Returns:
        Mapping from the grouping key (as a tuple of ``str, str, int, int,
        int``) to an ``Agg`` with the group's mean and std. Aggregates that
        are null or NaN are stored as NaN.
    """
    group_cols = ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"]
    rows = (
        df.group_by(group_cols)
        .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std"))
        .to_dicts()
    )

    out: Dict[Tuple[str, str, int, int, int], Agg] = {}
    for row in rows:
        key = (
            str(row["eval"]),
            str(row["net_label"]),
            int(row["latent_dim"]),
            int(row["semi_normals"]),
            int(row["semi_anomalous"]),
        )
        out[key] = Agg(mean=_float_or_nan(row["mean"]), std=_float_or_nan(row["std"]))
    return out
|
||||
|
||||
|
||||
def _lin_trend(xs: List[int], ys: List[float]) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a straight line through ``(xs, ys)`` and sample it for plotting.

    Args:
        xs: X coordinates of the points.
        ys: Y coordinates of the points (may contain NaN).

    Returns:
        ``(x_fit, y_fit)``: 200 evenly spaced x values between the smallest
        and largest usable x, and the fitted line evaluated there. When fewer
        than two usable points exist no fit is possible and the points
        themselves are returned as float arrays instead.
    """
    x = np.array(xs, dtype=float)
    y = np.array(ys, dtype=float)
    if len(xs) < 2:
        return x, y

    # Non-finite points cannot take part in a least-squares fit; drop them so
    # a single NaN mean does not break the whole trend line.
    usable = np.isfinite(x) & np.isfinite(y)
    x, y = x[usable], y[usable]
    if x.size < 2:
        return x, y

    slope, intercept = np.polyfit(x, y, 1)
    x_fit = np.linspace(x.min(), x.max(), 200)
    return x_fit, slope * x_fit + intercept
|
||||
|
||||
|
||||
def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float, float]:
    """Compute padded y-axis limits, clamped to ``[0, 1]``, for a set of values.

    Args:
        all_vals: Plotted values; non-finite entries are ignored.
        all_errs: Error magnitude for each value (same length). Only used
            when ``SHOW_STD_BANDS`` is true; non-finite errors count as 0.

    Returns:
        ``(y_min, y_max)``. Falls back to ``(0.0, 1.0)`` when no value is
        finite.
    """
    vals = np.array(all_vals, dtype=float)
    if SHOW_STD_BANDS:
        errs = np.array(all_errs, dtype=float)
        # A NaN error on a finite value would otherwise turn the limits into
        # NaN; treat an unknown spread as zero instead.
        errs = np.where(np.isfinite(errs), errs, 0.0)
    else:
        errs = np.zeros_like(vals)

    valid = np.isfinite(vals)
    if not np.any(valid):
        return (0.0, 1.0)

    v, e = vals[valid], errs[valid]
    lo = np.min(v - e)
    hi = np.max(v + e)

    # Pad by 8% of the data span (span floored at 1e-3), then clamp to [0, 1].
    span = max(1e-3, hi - lo)
    pad = 0.08 * span
    y0 = max(0.0, lo - pad)
    y1 = min(1.0, hi + pad)

    # Widen very narrow ranges to +/-0.04 around the midpoint (still clamped).
    if (y1 - y0) < 0.08:
        mid = 0.5 * (y0 + y1)
        y0 = max(0.0, mid - 0.04)
        y1 = min(1.0, mid + 0.04)
    return (float(y0), float(y1))
|
||||
|
||||
|
||||
def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Draw the AP-over-latent-dimension figure for one evaluation type.

    One subplot is stacked per entry of ``SEMI_LABELING_REGIMES``. Each
    subplot shows, for the LeNet and Efficient networks, the mean AP per
    latent dimension as scatter points, a linear trend line and (when
    ``SHOW_STD_BANDS`` is set) a shaded mean +/- std band.

    Args:
        ev: Evaluation name, used as the first component of the ``agg`` key.
        agg: Aggregates keyed by ``(eval, net, latent_dim, semi_normals,
            semi_anomalous)``.
        outdir: Directory that receives ``ap_trends_<ev>.png``.
    """
    networks = [("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)]
    n_rows = len(SEMI_LABELING_REGIMES)

    fig, axes = plt.subplots(
        n_rows,
        1,
        figsize=FIGSIZE,
        constrained_layout=True,
        sharex=True,
    )
    # plt.subplots returns a bare Axes (not a sequence) for a single row.
    if n_rows == 1:
        axes = [axes]

    for ax, (semi_n, semi_a) in zip(axes, SEMI_LABELING_REGIMES):
        # Collect (dims, means, stds) per network for this labeling regime.
        series = {}
        for net, _color in networks:
            found = [
                (dim, agg[(ev, net, dim, semi_n, semi_a)])
                for dim in LATENT_DIMS
                if (ev, net, dim, semi_n, semi_a) in agg
            ]
            series[net] = (
                [dim for dim, _entry in found],
                [entry.mean for _dim, entry in found],
                [entry.std for _dim, entry in found],
            )

        for net, color in networks:
            dims, means, stds = series[net]
            if not dims:
                continue

            ax.set_xticks(LATENT_DIMS)
            ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
            ax.scatter(
                dims,
                means,
                s=35,
                color=color,
                alpha=SCATTER_ALPHA,
                label=f"{net} (points)",
            )

            trend_x, trend_y = _lin_trend(dims, means)
            ax.plot(
                trend_x,
                trend_y,
                color=color,
                linewidth=TREND_LINEWIDTH,
                label=f"{net} (trend)",
            )

            if SHOW_STD_BANDS and stds and np.any(np.isfinite(stds)):
                centre = np.array(means)
                spread = np.array(stds)
                band_lo = np.clip(centre - spread, 0.0, 1.0)
                band_hi = np.clip(centre + spread, 0.0, 1.0)
                ax.fill_between(
                    dims, band_lo, band_hi, color=color, alpha=BAND_ALPHA, linewidth=0
                )

        # The y-range is derived from both networks together.
        pooled_means = [m for net, _color in networks for m in series[net][1]]
        pooled_stds = [s for net, _color in networks for s in series[net][2]]
        y_min, y_max = _dynamic_ylim(pooled_means, pooled_stds)
        ax.set_ylim(y_min, y_max)

        ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11)
        ax.grid(True, alpha=0.35)

    axes[-1].set_xlabel("Latent dimension")
    for ax in axes:
        ax.set_ylabel("AP")

    # A single figure-level legend, built from the first subplot's artists.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97))
    fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05)

    fname = f"ap_trends_{ev}.png"
    fig.savefig(outdir / fname, dpi=150)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Create ``outdir`` if needed and render one figure per entry in ``EVALS``."""
    outdir.mkdir(parents=True, exist_ok=True)
    for eval_name in EVALS:
        plot_eval(eval_name, agg, outdir)
|
||||
|
||||
|
||||
def main():
    """Load the results, aggregate AP and write the figures.

    Figures go into a fresh timestamped directory under
    ``OUTPUT_DIR / "archive"``; a copy of this script is stored next to them
    on a best-effort basis. Afterwards the files in ``OUTPUT_DIR / "latest"``
    are replaced by copies of the files from that timestamped directory.
    """
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _with_net_label(df)
    df = _filter_deepsad(df)
    agg = aggregate_ap(df)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_dir = OUTPUT_DIR / "archive"
    archive_dir.mkdir(parents=True, exist_ok=True)
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    plot_all(agg, ts_dir)

    # Archiving the script is optional: report a failure instead of hiding it,
    # but never abort the run because of it. NameError covers a missing
    # __file__; OSError covers copy failures.
    try:
        script_path = Path(__file__)
        shutil.copy2(script_path, ts_dir / script_path.name)
    except (NameError, OSError) as exc:
        print(f"Warning: could not archive a copy of the script: {exc}")

    # Refresh "latest": remove its files (sub-directories are left alone),
    # then copy over the files of this run.
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()
|
||||
260
tools/plot_scripts/results_ap_over_semi.py
Normal file
260
tools/plot_scripts/results_ap_over_semi.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
# =========================
# Config
# =========================
# Where the result files are read from and where the figures are written to.
ROOT = Path("/home/fedex/mt/results/copy")
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ap_over_semi")

# Labeling regimes (shown as separate subplots), as
# (semi_normals, semi_anomalous) pairs.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]

# Evaluations: separate figure per eval
EVALS: list[str] = ["exp_based", "manual_based"]

# X-axis (latent dims)
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]
# NOTE(review): this holds the same list as LATENT_DIMS (it was annotated as
# `int`); nothing in the visible part of this script reads it. Confirm whether
# a single fixed dimension was intended here.
LATENT_DIM: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Visual style
FIGSIZE = (8, 8)  # one tall figure with 3 compact subplots
MARKERSIZE = 7
SCATTER_ALPHA = 0.95
LINEWIDTH = 2.0
TREND_LINEWIDTH = 2.2
BAND_ALPHA = 0.18

# Toggle: show +/-1 std bands (k-fold variability)
SHOW_STD_BANDS = True  # <<< set to False to hide the bands

# Colors for the two DeepSAD backbones
COLOR_LENET = "#1f77b4"  # blue
COLOR_EFFICIENT = "#ff7f0e"  # orange
|
||||
|
||||
# =========================
|
||||
# Loader
|
||||
# =========================
|
||||
from load_results import load_results_dataframe
|
||||
|
||||
|
||||
# =========================
|
||||
# Helpers
|
||||
# =========================
|
||||
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an extra ``net_label`` column built from ``network``.

    The ``network`` column is cast to a string and lower-cased for matching:
    values containing ``"lenet"`` become ``"LeNet"``, values containing
    ``"efficient"`` become ``"Efficient"``, and everything else keeps its
    (string-cast) network name.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    lowered = network_text.str.to_lowercase()

    # The "lenet" test comes first, so it wins if a name matches both patterns.
    label = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
    )
    return df.with_columns(label.alias("net_label"))
|
||||
|
||||
|
||||
def _filter_deepsad(df: pl.DataFrame) -> pl.DataFrame:
    """Keep only the DeepSAD rows relevant to the plots and trim the columns.

    A row is kept when its ``model`` is ``"deepsad"``, its ``eval`` is listed
    in ``EVALS``, its ``latent_dim`` is listed in ``LATENT_DIMS`` and its
    ``net_label`` is ``"LeNet"`` or ``"Efficient"``. Expects the ``net_label``
    column to exist already.
    """
    is_deepsad = pl.col("model") == "deepsad"
    wanted_eval = pl.col("eval").is_in(EVALS)
    wanted_dim = pl.col("latent_dim").is_in(LATENT_DIMS)
    wanted_net = pl.col("net_label").is_in(["LeNet", "Efficient"])

    kept_columns = [
        "eval",
        "net_label",
        "latent_dim",
        "semi_normals",
        "semi_anomalous",
        "fold",
        "ap",
    ]
    return df.filter(is_deepsad & wanted_eval & wanted_dim & wanted_net).select(
        kept_columns
    )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Agg:
    """Immutable pair holding the mean and standard deviation of a group."""

    # Mean of the aggregated values.
    mean: float
    # Standard deviation of the aggregated values.
    std: float
|
||||
|
||||
|
||||
def _float_or_nan(value) -> float:
    """Convert an aggregate cell to ``float``, mapping null (``None``) to NaN."""
    # A null aggregate arrives here as None (e.g. the sample std of a group
    # with a single row); float(None) would raise TypeError.
    if value is None:
        return float(np.nan)
    return float(value)


def aggregate_ap(df: pl.DataFrame) -> Dict[Tuple[str, str, int, int, int], Agg]:
    """Aggregate the ``ap`` column into mean/std per configuration.

    Rows are grouped by ``(eval, net_label, latent_dim, semi_normals,
    semi_anomalous)``.

    Args:
        df: Frame containing the grouping columns and an ``ap`` column.

    Returns:
        Mapping from the grouping key (as a tuple of ``str, str, int, int,
        int``) to an ``Agg`` with the group's mean and std. Aggregates that
        are null or NaN are stored as NaN.
    """
    group_cols = ["eval", "net_label", "latent_dim", "semi_normals", "semi_anomalous"]
    rows = (
        df.group_by(group_cols)
        .agg(pl.col("ap").mean().alias("mean"), pl.col("ap").std().alias("std"))
        .to_dicts()
    )

    out: Dict[Tuple[str, str, int, int, int], Agg] = {}
    for row in rows:
        key = (
            str(row["eval"]),
            str(row["net_label"]),
            int(row["latent_dim"]),
            int(row["semi_normals"]),
            int(row["semi_anomalous"]),
        )
        out[key] = Agg(mean=_float_or_nan(row["mean"]), std=_float_or_nan(row["std"]))
    return out
|
||||
|
||||
|
||||
def _lin_trend(xs: List[int], ys: List[float]) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a straight line through ``(xs, ys)`` and sample it for plotting.

    Args:
        xs: X coordinates of the points.
        ys: Y coordinates of the points (may contain NaN).

    Returns:
        ``(x_fit, y_fit)``: 200 evenly spaced x values between the smallest
        and largest usable x, and the fitted line evaluated there. When fewer
        than two usable points exist no fit is possible and the points
        themselves are returned as float arrays instead.
    """
    x = np.array(xs, dtype=float)
    y = np.array(ys, dtype=float)
    if len(xs) < 2:
        return x, y

    # Non-finite points cannot take part in a least-squares fit; drop them so
    # a single NaN mean does not break the whole trend line.
    usable = np.isfinite(x) & np.isfinite(y)
    x, y = x[usable], y[usable]
    if x.size < 2:
        return x, y

    slope, intercept = np.polyfit(x, y, 1)
    x_fit = np.linspace(x.min(), x.max(), 200)
    return x_fit, slope * x_fit + intercept
|
||||
|
||||
|
||||
def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float, float]:
    """Compute padded y-axis limits, clamped to ``[0, 1]``, for a set of values.

    Args:
        all_vals: Plotted values; non-finite entries are ignored.
        all_errs: Error magnitude for each value (same length). Only used
            when ``SHOW_STD_BANDS`` is true; non-finite errors count as 0.

    Returns:
        ``(y_min, y_max)``. Falls back to ``(0.0, 1.0)`` when no value is
        finite.
    """
    vals = np.array(all_vals, dtype=float)
    if SHOW_STD_BANDS:
        errs = np.array(all_errs, dtype=float)
        # A NaN error on a finite value would otherwise turn the limits into
        # NaN; treat an unknown spread as zero instead.
        errs = np.where(np.isfinite(errs), errs, 0.0)
    else:
        errs = np.zeros_like(vals)

    valid = np.isfinite(vals)
    if not np.any(valid):
        return (0.0, 1.0)

    v, e = vals[valid], errs[valid]
    lo = np.min(v - e)
    hi = np.max(v + e)

    # Pad by 8% of the data span (span floored at 1e-3), then clamp to [0, 1].
    span = max(1e-3, hi - lo)
    pad = 0.08 * span
    y0 = max(0.0, lo - pad)
    y1 = min(1.0, hi + pad)

    # Widen very narrow ranges to +/-0.04 around the midpoint (still clamped).
    if (y1 - y0) < 0.08:
        mid = 0.5 * (y0 + y1)
        y0 = max(0.0, mid - 0.04)
        y1 = min(1.0, mid + 0.04)
    return (float(y0), float(y1))
|
||||
|
||||
|
||||
def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Draw the AP-over-latent-dimension figure for one evaluation type.

    One subplot is stacked per entry of ``SEMI_LABELING_REGIMES``. Each
    subplot shows, for the LeNet and Efficient networks, the mean AP per
    latent dimension as scatter points, a linear trend line and (when
    ``SHOW_STD_BANDS`` is set) a shaded mean +/- std band.

    Args:
        ev: Evaluation name, used as the first component of the ``agg`` key.
        agg: Aggregates keyed by ``(eval, net, latent_dim, semi_normals,
            semi_anomalous)``.
        outdir: Directory that receives ``ap_trends_<ev>.png``.
    """
    networks = [("LeNet", COLOR_LENET), ("Efficient", COLOR_EFFICIENT)]
    n_rows = len(SEMI_LABELING_REGIMES)

    fig, axes = plt.subplots(
        n_rows,
        1,
        figsize=FIGSIZE,
        constrained_layout=True,
        sharex=True,
    )
    # plt.subplots returns a bare Axes (not a sequence) for a single row.
    if n_rows == 1:
        axes = [axes]

    for ax, (semi_n, semi_a) in zip(axes, SEMI_LABELING_REGIMES):
        # Collect (dims, means, stds) per network for this labeling regime.
        series = {}
        for net, _color in networks:
            found = [
                (dim, agg[(ev, net, dim, semi_n, semi_a)])
                for dim in LATENT_DIMS
                if (ev, net, dim, semi_n, semi_a) in agg
            ]
            series[net] = (
                [dim for dim, _entry in found],
                [entry.mean for _dim, entry in found],
                [entry.std for _dim, entry in found],
            )

        for net, color in networks:
            dims, means, stds = series[net]
            if not dims:
                continue

            ax.set_xticks(LATENT_DIMS)
            ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
            ax.scatter(
                dims,
                means,
                s=35,
                color=color,
                alpha=SCATTER_ALPHA,
                label=f"{net} (points)",
            )

            trend_x, trend_y = _lin_trend(dims, means)
            ax.plot(
                trend_x,
                trend_y,
                color=color,
                linewidth=TREND_LINEWIDTH,
                label=f"{net} (trend)",
            )

            if SHOW_STD_BANDS and stds and np.any(np.isfinite(stds)):
                centre = np.array(means)
                spread = np.array(stds)
                band_lo = np.clip(centre - spread, 0.0, 1.0)
                band_hi = np.clip(centre + spread, 0.0, 1.0)
                ax.fill_between(
                    dims, band_lo, band_hi, color=color, alpha=BAND_ALPHA, linewidth=0
                )

        # The y-range is derived from both networks together.
        pooled_means = [m for net, _color in networks for m in series[net][1]]
        pooled_stds = [s for net, _color in networks for s in series[net][2]]
        y_min, y_max = _dynamic_ylim(pooled_means, pooled_stds)
        ax.set_ylim(y_min, y_max)

        ax.set_title(f"Labeling regime {semi_n}/{semi_a}", fontsize=11)
        ax.grid(True, alpha=0.35)

    axes[-1].set_xlabel("Latent dimension")
    for ax in axes:
        ax.set_ylabel("AP")

    # A single figure-level legend, built from the first subplot's artists.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, ncol=2, loc="upper center", bbox_to_anchor=(0.75, 0.97))
    fig.suptitle(f"AP vs. Latent Dimensionality — {ev.replace('_', ' ')}", y=1.05)

    fname = f"ap_trends_{ev}.png"
    fig.savefig(outdir / fname, dpi=150)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_all(agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
    """Create ``outdir`` if needed and render one figure per entry in ``EVALS``."""
    outdir.mkdir(parents=True, exist_ok=True)
    for eval_name in EVALS:
        plot_eval(eval_name, agg, outdir)
|
||||
|
||||
|
||||
def main():
    """Load the results, aggregate AP and write the figures.

    Figures go into a fresh timestamped directory under
    ``OUTPUT_DIR / "archive"``; a copy of this script is stored next to them
    on a best-effort basis. Afterwards the files in ``OUTPUT_DIR / "latest"``
    are replaced by copies of the files from that timestamped directory.
    """
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _with_net_label(df)
    df = _filter_deepsad(df)
    agg = aggregate_ap(df)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_dir = OUTPUT_DIR / "archive"
    archive_dir.mkdir(parents=True, exist_ok=True)
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    plot_all(agg, ts_dir)

    # Archiving the script is optional: report a failure instead of hiding it,
    # but never abort the run because of it. NameError covers a missing
    # __file__; OSError covers copy failures.
    try:
        script_path = Path(__file__)
        shutil.copy2(script_path, ts_dir / script_path.name)
    except (NameError, OSError) as exc:
        print(f"Warning: could not archive a copy of the script: {exc}")

    # Refresh "latest": remove its files (sub-directories are left alone),
    # then copy over the files of this run.
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved plots to: {ts_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()
|
||||
@@ -8,11 +8,11 @@ from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from matplotlib.lines import Line2D
|
||||
from scipy.stats import sem, t
|
||||
|
||||
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
|
||||
from plot_scripts.load_results import load_results_dataframe
|
||||
from load_results import load_results_dataframe
|
||||
from matplotlib.lines import Line2D
|
||||
from scipy.stats import sem, t
|
||||
|
||||
# ---------------------------------
|
||||
# Config
|
||||
@@ -23,6 +23,10 @@ OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")
|
||||
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
|
||||
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
|
||||
EVALS = ["exp_based", "manual_based"]
|
||||
EVALS_LABELS = {
|
||||
"exp_based": "Experiment-Based Labels",
|
||||
"manual_based": "Manually-Labeled",
|
||||
}
|
||||
|
||||
# Interp grids
|
||||
ROC_GRID = np.linspace(0.0, 1.0, 200)
|
||||
@@ -30,6 +34,10 @@ PRC_GRID = np.linspace(0.0, 1.0, 200)
|
||||
|
||||
# Baselines are duplicated across nets; use Efficient-only to avoid repetition
|
||||
BASELINE_NET = "Efficient"
|
||||
BASELINE_LABELS = {
|
||||
"isoforest": "Isolation Forest",
|
||||
"ocsvm": "One-Class SVM",
|
||||
}
|
||||
|
||||
# Colors/styles
|
||||
COLOR_BASELINES = {
|
||||
@@ -147,12 +155,8 @@ def _select_rows(
|
||||
return df.filter(pl.all_horizontal(exprs))
|
||||
|
||||
|
||||
def _auc_list(sub: pl.DataFrame) -> list[float]:
|
||||
return [x for x in sub.select("auc").to_series().to_list() if x is not None]
|
||||
|
||||
|
||||
def _ap_list(sub: pl.DataFrame) -> list[float]:
|
||||
return [x for x in sub.select("ap").to_series().to_list() if x is not None]
|
||||
def _auc_list(sub: pl.DataFrame, kind: str) -> list[float]:
|
||||
return [x for x in sub.select(f"{kind}_auc").to_series().to_list() if x is not None]
|
||||
|
||||
|
||||
def _plot_panel(
|
||||
@@ -165,7 +169,7 @@ def _plot_panel(
|
||||
kind: str,
|
||||
):
|
||||
"""
|
||||
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).
|
||||
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + Baselines (from Efficient).
|
||||
Legend entries include mean±CI of AUC/AP.
|
||||
"""
|
||||
ax.grid(True, alpha=0.3)
|
||||
@@ -200,9 +204,9 @@ def _plot_panel(
|
||||
continue
|
||||
|
||||
# Metric for legend
|
||||
metric_vals = _auc_list(sub_b) if kind == "roc" else _ap_list(sub_b)
|
||||
metric_vals = _auc_list(sub_b, kind)
|
||||
m, ci = mean_ci(metric_vals)
|
||||
lab = f"{model} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||
lab = f"{BASELINE_LABELS[model]}\n(AUC={m:.3f}±{ci:.3f})"
|
||||
|
||||
color = COLOR_BASELINES[model]
|
||||
h = ax.plot(grid, mean_y, lw=2, color=color, label=lab)[0]
|
||||
@@ -230,9 +234,9 @@ def _plot_panel(
|
||||
if np.all(np.isnan(mean_y)):
|
||||
continue
|
||||
|
||||
metric_vals = _auc_list(sub_d) if kind == "roc" else _ap_list(sub_d)
|
||||
metric_vals = _auc_list(sub_d, kind)
|
||||
m, ci = mean_ci(metric_vals)
|
||||
lab = f"DeepSAD {net_for_deepsad} — semi {sn}/{sa} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||
lab = f"DeepSAD {net_for_deepsad} — {sn}/{sa}\n(AUC={m:.3f}±{ci:.3f})"
|
||||
|
||||
color = COLOR_REGIMES[regime]
|
||||
ls = LINESTYLES[regime]
|
||||
@@ -246,7 +250,7 @@ def _plot_panel(
|
||||
ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")
|
||||
|
||||
# Legend
|
||||
ax.legend(loc="lower right", fontsize=9, frameon=True)
|
||||
ax.legend(loc="upper right", fontsize=9, frameon=True)
|
||||
|
||||
|
||||
def make_figures_for_dim(
|
||||
@@ -254,9 +258,11 @@ def make_figures_for_dim(
|
||||
):
|
||||
# ROC: 2×1
|
||||
fig_roc, axes = plt.subplots(
|
||||
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
|
||||
)
|
||||
fig_roc.suptitle(
|
||||
f"ROC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
|
||||
)
|
||||
fig_roc.suptitle(f"ROC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||
|
||||
_plot_panel(
|
||||
axes[0],
|
||||
@@ -266,7 +272,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="roc",
|
||||
)
|
||||
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||
axes[0].set_title("DeepSAD (LeNet) + Baselines")
|
||||
|
||||
_plot_panel(
|
||||
axes[1],
|
||||
@@ -276,7 +282,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="roc",
|
||||
)
|
||||
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||
axes[1].set_title("DeepSAD (Efficient) + Baselines")
|
||||
|
||||
out_roc = out_dir / f"roc_{latent_dim}_{eval_type}.png"
|
||||
fig_roc.savefig(out_roc, dpi=150, bbox_inches="tight")
|
||||
@@ -284,9 +290,11 @@ def make_figures_for_dim(
|
||||
|
||||
# PRC: 2×1
|
||||
fig_prc, axes = plt.subplots(
|
||||
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
|
||||
)
|
||||
fig_prc.suptitle(
|
||||
f"PRC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
|
||||
)
|
||||
fig_prc.suptitle(f"PRC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||
|
||||
_plot_panel(
|
||||
axes[0],
|
||||
@@ -296,7 +304,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="prc",
|
||||
)
|
||||
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||
axes[0].set_title("DeepSAD (LeNet) + Baselines")
|
||||
|
||||
_plot_panel(
|
||||
axes[1],
|
||||
@@ -306,7 +314,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="prc",
|
||||
)
|
||||
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||
axes[1].set_title("DeepSAD (Efficient) + Baselines")
|
||||
|
||||
out_prc = out_dir / f"prc_{latent_dim}_{eval_type}.png"
|
||||
fig_prc.savefig(out_prc, dpi=150, bbox_inches="tight")
|
||||
|
||||
Reference in New Issue
Block a user