# mt/tools/plot_scripts/results_ae_table.py
from __future__ import annotations

import json
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_pretraining_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ae_table")
# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
LABEL_FIELD = "labels_exp_based"
# Which architectures to include (names must match what canonicalize_network returns)
WANTED_NETS = {"LeNet", "Efficient"}
# Formatting
DECIMALS = 4 # how many decimals to display for losses
BOLD_BEST = False # set True to bold per-group best (lower is better)
LOWER_IS_BETTER = True # for losses we want the minimum
# ----------------------------
# Helpers (ported/minified from your plotting script)
# ----------------------------
def canonicalize_network(name: str) -> str:
    """Map raw experiment network names onto the canonical labels used in the table."""
    low = (name or "").lower()
if "lenet" in low:
return "LeNet"
if "efficient" in low:
return "Efficient"
return name or "unknown"
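

# Examples (illustrative names): canonicalize_network("LeNet5_run1") -> "LeNet",
# canonicalize_network("efficientnet_b0") -> "Efficient"; anything else, e.g.
# "resnet18", is passed through unchanged.

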
def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Mean of per-batch means; a trailing, smaller batch counts as a full batch."""
    n = len(scores)
if n == 0:
return np.nan
if batch_size <= 0:
batch_size = n
n_batches = (n + batch_size - 1) // batch_size
acc = 0.0
for i in range(0, n, batch_size):
acc += float(np.mean(scores[i : i + batch_size]))
return acc / n_batches
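

# Worked example (illustrative numbers): scores [1, 2, 3, 4, 5] with batch_size=2
# give batch means 1.5, 3.5, 5.0, so the result is 10.0 / 3 ≈ 3.333, slightly above
# the plain mean (3.0) because the short last batch is weighted like a full batch.

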
def extract_batch_size(cfg_json: str | None) -> int:
    """Read the AE batch size from the run's JSON config string; default to 256."""
    try:
        cfg = json.loads(cfg_json) if cfg_json else {}
    except Exception:
        cfg = {}
    return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
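

# Fallback chain (illustrative): '{"ae_batch_size": 128}' -> 128,
# '{"batch_size": 64}' -> 64, and None or malformed JSON -> 256.

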
@dataclass(frozen=True)
class Cell:
mean: float | None
std: float | None


def _fmt(mean: float | None) -> str:
    # "mean == mean" is False only for NaN, so this also maps NaN to "--"
    return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}"
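

# E.g. with DECIMALS = 4: _fmt(0.123456) -> "0.1235"; _fmt(None) and
# _fmt(float("nan")) -> "--".

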
def _bold_mask_display(
values: List[float | None], decimals: int, lower_is_better: bool
) -> List[bool]:
"""
Tie-aware bolding mask based on *displayed* precision.
For losses, lower is better (min). For metrics where higher is better, set lower_is_better=False.
"""
def disp(v: float | None) -> float | None:
if v is None or not (v == v):
return None
# use string → float to match display rounding exactly
return float(f"{v:.{decimals}f}")
rounded = [disp(v) for v in values]
finite = [v for v in rounded if v is not None]
if not finite:
return [False] * len(values)
target = min(finite) if lower_is_better else max(finite)
return [(v is not None and v == target) for v in rounded]
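

# Tie example (illustrative): values [0.12344, 0.12341, None] at 4 decimals both
# display as 0.1234, so the mask is [True, True, False]; ties share the bold.

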
# ----------------------------
# Core
# ----------------------------
def build_losses_table_from_df(
df: pl.DataFrame, label_field: str
) -> Tuple[str, float | None]:
"""
Build a LaTeX table showing Overall loss (LeNet, Efficient) and Anomaly loss (LeNet, Efficient)
with one row per latent dimension. Returns (latex_table_string, max_std_overall).
"""
# Basic validation
required_cols = {"scores", "network", "latent_dim"}
missing = required_cols - set(df.columns)
if missing:
raise ValueError(f"Missing required columns in AE dataframe: {missing}")
if label_field not in df.columns:
raise ValueError(f"Expected '{label_field}' column in AE dataframe.")
# Canonicalize nets, compute per-row overall/anomaly losses
rows: List[dict] = []
for row in df.iter_rows(named=True):
net = canonicalize_network(row["network"])
if WANTED_NETS and net not in WANTED_NETS:
continue
dim = int(row["latent_dim"])
batch_size = extract_batch_size(row.get("config_json"))
scores = np.asarray(row["scores"] or [], dtype=float)
labels = row.get(label_field)
labels = np.asarray(labels, dtype=int) if labels is not None else None
overall_loss = calculate_batch_mean_loss(scores, batch_size)
anomaly_loss = np.nan
if labels is not None and labels.size == scores.size:
anomaly_scores = scores[labels == -1]
if anomaly_scores.size > 0:
anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size)
rows.append(
{
"net": net,
"latent_dim": dim,
"overall": overall_loss,
"anomaly": anomaly_loss,
}
)
if not rows:
raise ValueError(
"No rows available after filtering; check WANTED_NETS or input data."
)
df2 = pl.DataFrame(rows)
# Aggregate across folds per (net, latent_dim)
agg = df2.group_by(["net", "latent_dim"]).agg(
pl.col("overall").mean().alias("overall_mean"),
pl.col("overall").std().alias("overall_std"),
pl.col("anomaly").mean().alias("anomaly_mean"),
pl.col("anomaly").std().alias("anomaly_std"),
)
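    # `agg` now holds one row per (net, latent_dim), e.g. (illustrative values):
    #   net="LeNet", latent_dim=32, overall_mean=0.0123, overall_std=0.0004, ...
    # Note: polars' std() is null for single-row groups (a single fold), which the
    # push_std helper below skips.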
# Collect union of dims across both nets
dims = sorted(set(agg.get_column("latent_dim").to_list()))
# Build lookup
keymap: Dict[Tuple[str, int], Cell] = {}
keymap_anom: Dict[Tuple[str, int], Cell] = {}
max_std: float | None = None
    def push_std(v: float | None):
        nonlocal max_std
        if v is None or not (v == v):  # skip missing and NaN stds
            return
        if max_std is None or v > max_std:
            max_std = v
for r in agg.iter_rows(named=True):
k = (r["net"], int(r["latent_dim"]))
keymap[k] = Cell(r.get("overall_mean"), r.get("overall_std"))
keymap_anom[k] = Cell(r.get("anomaly_mean"), r.get("anomaly_std"))
push_std(r.get("overall_std"))
push_std(r.get("anomaly_std"))
# Ensure nets order consistent
nets_order = ["LeNet", "Efficient"]
nets_present = [n for n in nets_order if any(k[0] == n for k in keymap.keys())]
if not nets_present:
nets_present = sorted({k[0] for k in keymap.keys()})
    # Build LaTeX table; column spec and headers derive from nets_present so the
    # headers stay aligned with the data columns even if a net is missing.
    ncols = len(nets_present)
    lines: List[str] = []
    lines.append(r"\begin{table}[t]")
    lines.append(r"\centering")
    lines.append(r"\setlength{\tabcolsep}{4pt}")
    lines.append(r"\renewcommand{\arraystretch}{1.2}")
    # vertical bar between the two groups; the Y column type is assumed to be
    # defined in the document, e.g. \newcolumntype{Y}{>{\centering\arraybackslash}X}
    lines.append(rf"\begin{{tabularx}}{{\textwidth}}{{c*{{{ncols}}}{{Y}}|*{{{ncols}}}{{Y}}}}")
    lines.append(r"\toprule")
    lines.append(
        rf" & \multicolumn{{{ncols}}}{{c}}{{Overall loss}} & \multicolumn{{{ncols}}}{{c}}{{Anomaly loss}} \\"
    )
    lines.append(
        rf"\cmidrule(lr){{2-{1 + ncols}}} \cmidrule(lr){{{2 + ncols}-{1 + 2 * ncols}}}"
    )
    lines.append(
        r"Latent Dim. & "
        + " & ".join(nets_present)
        + " & "
        + " & ".join(nets_present)
        + r" \\"
    )
    lines.append(r"\midrule")
for d in dims:
# Gather values in order: Overall (LeNet, Efficient), Anomaly (LeNet, Efficient)
overall_vals = [keymap.get((n, d), Cell(None, None)).mean for n in nets_present]
anomaly_vals = [
keymap_anom.get((n, d), Cell(None, None)).mean for n in nets_present
]
overall_strs = [_fmt(v) for v in overall_vals]
anomaly_strs = [_fmt(v) for v in anomaly_vals]
if BOLD_BEST:
mask_overall = _bold_mask_display(overall_vals, DECIMALS, LOWER_IS_BETTER)
mask_anom = _bold_mask_display(anomaly_vals, DECIMALS, LOWER_IS_BETTER)
overall_strs = [
(r"\textbf{" + s + "}") if (m and s != "--") else s
for s, m in zip(overall_strs, mask_overall)
]
anomaly_strs = [
(r"\textbf{" + s + "}") if (m and s != "--") else s
for s, m in zip(anomaly_strs, mask_anom)
]
lines.append(
f"{d} & "
+ " & ".join(overall_strs)
+ " & "
+ " & ".join(anomaly_strs)
+ r" \\"
)
lines.append(r"\bottomrule")
lines.append(r"\end{tabularx}")
max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}"
lines.append(
rf"\caption{{Autoencoder pre-training MSE losses (test split) across latent dimensions. "
rf"Left: overall loss; Right: anomaly-only loss. "
rf"Cells show means across folds (no $\pm$std). "
rf"Maximum observed standard deviation across all cells (not shown): {max_std_str}.}}"
)
lines.append(r"\end{table}")
return "\n".join(lines), max_std
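

# Minimal usage sketch with an in-memory frame (hypothetical data; the real
# dataframe comes from load_pretraining_results_dataframe):
#
#   demo = pl.DataFrame({
#       "network": ["lenet5_fold0", "efficientnet_fold0"],
#       "latent_dim": [32, 32],
#       "scores": [[0.01, 0.02, 0.03], [0.02, 0.02, 0.04]],
#       "labels_exp_based": [[1, -1, 1], [1, 1, -1]],
#       "config_json": ['{"ae_batch_size": 2}', '{"batch_size": 2}'],
#   })
#   tex, max_std = build_losses_table_from_df(demo, "labels_exp_based")

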
# ----------------------------
# Entry
# ----------------------------
def main():
df = load_pretraining_results_dataframe(ROOT, allow_cache=True)
# Build LaTeX table
tex, max_std = build_losses_table_from_df(df, LABEL_FIELD)
# Output dirs
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ts_dir.mkdir(parents=True, exist_ok=True)
out_name = "ae_pretraining_losses_table.tex"
out_path = ts_dir / out_name
out_path.write_text(tex, encoding="utf-8")
# Save a copy of this script
script_path = Path(__file__)
try:
shutil.copy2(script_path, ts_dir / script_path.name)
    except Exception:
        pass  # non-fatal: archiving the script copy is best-effort
# Mirror latest
latest = OUTPUT_DIR / "latest"
latest.mkdir(parents=True, exist_ok=True)
# Clear
for f in latest.iterdir():
if f.is_file():
f.unlink()
# Copy
for f in ts_dir.iterdir():
if f.is_file():
shutil.copy2(f, latest / f.name)
print(f"Saved table to: {ts_dir}")
print(f"Also updated: {latest}")
print(f" - {out_name}")
if __name__ == "__main__":
main()