# Listing metadata: 365 lines, 11 KiB, Python
from __future__ import annotations
|
|
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import polars as pl
|
|
|
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
|
from load_results import load_results_dataframe
|
|
from matplotlib.lines import Line2D
|
|
|
|
# ----------------------------
# Config
# ----------------------------
# Experiments root handed to the results loader.
ROOT = Path("/home/fedex/mt/results/copy")
# Base directory for the generated plots ("archive/<timestamp>" and "latest").
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_comparisons")

# (semi_normals, semi_anomalous) pairs; main() produces figures for each pair.
SEMI_LABELING_REGIMES = [(0, 0), (50, 10), (500, 100)]

# Single semi-supervised setting.
# NOTE(review): not read by any code visible in this file — confirm whether
# these two constants are still needed.
SEMI_NORMALS, SEMI_ANOMALOUS = 50, 10

# Evaluation columns to plot, mapped to the human-readable names used in titles.
EVALS_LABELS = {
    "exp_based": "Experiment-Label-Based",
    "manual_based": "Manually-Labeled",
}
# Plotting order of the evaluation types (same order as the mapping above).
EVALS = list(EVALS_LABELS)

# Latent dimensions, one subplot each (seven panels).
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]

# Common x-grids onto which the per-run curves are interpolated.
ROC_GRID = np.linspace(0.0, 1.0, 200)
PRC_GRID = np.linspace(0.0, 1.0, 200)
|
|
|
|
|
|
# ----------------------------
|
|
# Helpers
|
|
# ----------------------------
|
|
def canonicalize_network(name: str) -> str:
    """Return a clean plot label for a raw network name.

    Names containing "lenet" (case-insensitive) become "LeNet", names
    containing "efficient" become "Efficient". Any other non-empty name is
    returned unchanged; an empty or None name yields "unknown".
    """
    lowered = name.lower() if name else ""
    # Checked in order: "lenet" wins if a name happens to contain both.
    known_labels = (("lenet", "LeNet"), ("efficient", "Efficient"))
    for needle, label in known_labels:
        if needle in lowered:
            return label
    return name if name else "unknown"
|
|
|
|
|
|
def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray):
|
|
"""
|
|
Interpolate a list of (x, y) curves onto a common grid.
|
|
Returns mean_y, std_y on the grid. Skips empty or invalid curves.
|
|
"""
|
|
if not curves:
|
|
return np.full_like(grid, np.nan, dtype=float), np.full_like(
|
|
grid, np.nan, dtype=float
|
|
)
|
|
|
|
interps = []
|
|
for x, y in curves:
|
|
if x is None or y is None:
|
|
continue
|
|
x = np.asarray(x, dtype=float)
|
|
y = np.asarray(y, dtype=float)
|
|
if x.size == 0 or y.size == 0 or x.size != y.size:
|
|
continue
|
|
# ensure sorted by x and unique
|
|
order = np.argsort(x)
|
|
x = x[order]
|
|
y = y[order]
|
|
# deduplicate x (np.interp requires ascending x)
|
|
uniq_x, uniq_idx = np.unique(x, return_index=True)
|
|
y = y[uniq_idx]
|
|
x = uniq_x
|
|
# bound grid to valid interp range
|
|
gmin = max(grid[0], x[0])
|
|
gmax = min(grid[-1], x[-1])
|
|
g = np.clip(grid, gmin, gmax)
|
|
yi = np.interp(g, x, y)
|
|
interps.append(yi)
|
|
|
|
if not interps:
|
|
return np.full_like(grid, np.nan, dtype=float), np.full_like(
|
|
grid, np.nan, dtype=float
|
|
)
|
|
|
|
A = np.vstack(interps)
|
|
return np.nanmean(A, axis=0), np.nanstd(A, axis=0)
|
|
|
|
|
|
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added 'net_label' column.

    The 'network' column is cast to text and matched case-insensitively:
    values containing "lenet" are labelled "LeNet", values containing
    "efficient" are labelled "Efficient", and every other value keeps its
    text form as the label.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    lowered = network_text.str.to_lowercase()

    label_expr = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
        .alias("net_label")
    )
    return df.with_columns(label_expr)
|
|
|
|
|
|
def _select_rows(
    df: pl.DataFrame,
    *,
    model: str,
    eval_type: str,
    latent_dim: int,
    net_label: str | None,
    semi_normals: int,
    semi_anomalous: int,
) -> pl.DataFrame:
    """Return the rows of ``df`` matching every given criterion.

    Rows must equal ``model``, ``eval_type`` (column 'eval'), ``latent_dim``,
    ``semi_normals`` and ``semi_anomalous``. The 'net_label' column is only
    constrained when ``net_label`` is not None.
    """
    required = {
        "model": model,
        "eval": eval_type,
        "latent_dim": latent_dim,
        "semi_normals": semi_normals,
        "semi_anomalous": semi_anomalous,
    }
    predicates = [pl.col(column) == value for column, value in required.items()]
    if net_label is not None:
        predicates.append(pl.col("net_label") == net_label)

    # A row is kept only when all predicates hold.
    return df.filter(pl.all_horizontal(predicates))
|
|
|
|
|
|
def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
"""
|
|
From a list of rows (Python dicts), return list of (x, y) curves for given kind.
|
|
kind: "roc" or "prc"
|
|
"""
|
|
curves = []
|
|
for r in rows:
|
|
if kind == "roc":
|
|
c = r.get("roc_curve")
|
|
if not c:
|
|
continue
|
|
x, y = c.get("fpr"), c.get("tpr")
|
|
else:
|
|
c = r.get("prc_curve")
|
|
if not c:
|
|
continue
|
|
x, y = c.get("recall"), c.get("precision")
|
|
if x is None or y is None:
|
|
continue
|
|
curves.append((np.asarray(x, dtype=float), np.asarray(y, dtype=float)))
|
|
return curves
|
|
|
|
|
|
def _ensure_dim_axes(fig_title: str):
    """Create the figure and return it with its axes flattened to 1-D.

    The layout is 4 rows x 2 columns (8 axes); callers use the last axis for
    the legend. ``fig_title`` is accepted but currently not drawn: the
    figure gets no suptitle.
    """
    layout = dict(nrows=4, ncols=2, figsize=(12, 16), constrained_layout=True)
    fig, axes_grid = plt.subplots(**layout)
    return fig, axes_grid.ravel()
|
|
|
|
|
|
def _add_legend_to_axis(ax, handles_labels):
|
|
ax.axis("off")
|
|
handles, labels = handles_labels
|
|
ax.legend(
|
|
handles,
|
|
labels,
|
|
loc="center",
|
|
frameon=False,
|
|
ncol=1,
|
|
fontsize=11,
|
|
borderaxespad=0.5,
|
|
)
|
|
|
|
|
|
def plot_grid_from_df(
    df: pl.DataFrame,
    eval_type: str,
    kind: str,
    semi_normals: int,
    semi_anomalous: int,
    out_path: Path,
):
    """
    Plot mean ROC or PRC curves per latent dimension and save the figure.

    The figure is a 4x2 grid: the first seven panels show one entry of
    LATENT_DIMS each, the eighth panel holds the shared legend. Within a
    panel, every series is the mean of its runs' curves interpolated onto a
    common grid, with a shaded band of +/- one standard deviation. Panels
    without any curve show the text "No data".

    df: results frame containing the columns used by _select_rows plus
        'roc_curve' / 'prc_curve'.
    eval_type: key of EVALS_LABELS selecting the evaluation to plot.
    kind: 'roc' or 'prc' (any value other than 'roc' is treated as 'prc').
    semi_normals, semi_anomalous: semi-labeling regime to select.
    out_path: image file to write; parent directories are created.
    """
    fig_title = f"{kind.upper()} — {EVALS_LABELS[eval_type]} (Semi-Labeling Regime = {semi_normals}/{semi_anomalous})"
    fig, axes = _ensure_dim_axes(fig_title)

    # Plotting order & colors: (model, required net_label, legend label, color).
    # The baselines carry None here; they are restricted to Efficient below.
    series = [
        ("isoforest", None, "IsolationForest", "tab:purple"),
        ("ocsvm", None, "OC-SVM", "tab:green"),
        ("deepsad", "LeNet", "DeepSAD (LeNet)", "tab:blue"),
        ("deepsad", "Efficient", "DeepSAD (Efficient)", "tab:orange"),
    ]

    # Everything that depends only on the curve kind is resolved once.
    if kind == "roc":
        curve_col, x_label, y_label, grid = "roc_curve", "FPR", "TPR", ROC_GRID
    else:
        curve_col, x_label, y_label, grid = "prc_curve", "Recall", "Precision", PRC_GRID

    n_panels = 7  # the 8th axis is reserved for the legend
    plotted_labels = set()

    for i, (ax, dim) in enumerate(zip(axes[:n_panels], LATENT_DIMS)):
        ax.set_title(f"({chr(ord('a') + i)}) Latent Dim. = {dim}")
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        plotted_any = False

        for model, net_needed, label, color in series:
            # Baselines: use the rows labelled Efficient only.
            net_filter = "Efficient" if model in ("isoforest", "ocsvm") else net_needed

            sub = _select_rows(
                df,
                model=model,
                eval_type=eval_type,
                latent_dim=dim,
                net_label=net_filter,
                semi_normals=semi_normals,
                semi_anomalous=semi_anomalous,
            )
            if sub.height == 0:
                continue

            curves = _extract_curves(sub.select(curve_col).to_dicts(), kind)
            if not curves:
                continue

            mean_y, std_y = _interp_mean_std(curves, grid)
            # Guard for all-NaN
            if np.all(np.isnan(mean_y)):
                continue

            ax.plot(grid, mean_y, label=label, color=color)
            ax.fill_between(
                grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color
            )
            plotted_any = True
            plotted_labels.add(label)

        if not plotted_any:
            ax.text(
                0.5, 0.5, "No data", ha="center", va="center", fontsize=12, alpha=0.7
            )

    # Hide panels that have no latent dimension assigned to them.
    for ax in axes[len(LATENT_DIMS):n_panels]:
        ax.axis("off")

    # Legend in 8th slot. Entries are gathered over all panels, so a series
    # that is missing from the first panel with data still gets an entry;
    # the order follows ``series``.
    legend_handles = []
    legend_labels = []
    for _, _, label, color in series:
        if label in plotted_labels:
            legend_handles.append(Line2D([0], [0], color=color, lw=2))
            legend_labels.append(label)
    _add_legend_to_axis(axes[7], (legend_handles, legend_labels))

    # Save
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
|
|
def main():
    """Load the results, render all ROC/PRC grids and archive the outputs.

    For every regime in SEMI_LABELING_REGIMES and every evaluation in EVALS,
    one ROC and one PRC figure are written into a fresh timestamped folder
    under OUTPUT_DIR/archive. This script file is copied next to the plots,
    and the files in OUTPUT_DIR/latest are replaced by the files of the new
    folder.
    """
    # Load main results DF (the loader is asked to use its cache), then add
    # the clean network labels.
    results = _net_label_col(load_results_dataframe(ROOT, allow_cache=True))

    # Prepare output dirs
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_dir = OUTPUT_DIR / "archive"
    archive_dir.mkdir(parents=True, exist_ok=True)
    run_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir.mkdir(parents=True, exist_ok=True)

    # Criteria that are the same for every regime.
    is_known_model = pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])
    is_wanted_eval = pl.col("eval").is_in(EVALS)
    is_wanted_dim = pl.col("latent_dim").is_in(LATENT_DIMS)

    for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES:
        # Restrict to the current semi-supervised setting
        regime_df = results.filter(
            (pl.col("semi_normals") == semi_normals)
            & (pl.col("semi_anomalous") == semi_anomalous)
            & (is_known_model)
            & (is_wanted_eval)
            & (is_wanted_dim)
        )

        for eval_type in EVALS:
            # NOTE(review): the ROC file name carries a "semi_" infix while
            # the PRC one does not. Both are kept as they were because other
            # tooling may depend on the names — confirm before unifying.
            targets = (
                ("roc", f"roc_semi_{semi_normals}_{semi_anomalous}_{eval_type}.png"),
                ("prc", f"prc_{semi_normals}_{semi_anomalous}_{eval_type}.png"),
            )
            for kind, filename in targets:
                plot_grid_from_df(
                    regime_df,
                    eval_type=eval_type,
                    kind=kind,
                    semi_normals=semi_normals,
                    semi_anomalous=semi_anomalous,
                    out_path=run_dir / filename,
                )

    # Copy this script to preserve the code used for the outputs
    shutil.copy2(Path(__file__), run_dir)

    # Mirror latest: drop the old files, then copy the new run's files over.
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for stale in latest.iterdir():
        if stale.is_file():
            stale.unlink()
    for produced in run_dir.iterdir():
        if produced.is_file():
            shutil.copy2(produced, latest / produced.name)

    print(f"Saved plots to: {run_dir}")
    print(f"Also updated: {latest}")


if __name__ == "__main__":
    main()
|