# ae_losses_table_from_df.py
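"""Build a LaTeX table of autoencoder pre-training MSE losses, overall and
anomaly-only, per network (LeNet, Efficient) and latent dimension, aggregated
across folds from the dataframe returned by load_pretraining_results_dataframe.
"""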

from __future__ import annotations

import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_pretraining_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/copy")  # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ae_table")

# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
LABEL_FIELD = "labels_exp_based"

# Which architectures to include (labels must match canonicalize_network)
WANTED_NETS = {"LeNet", "Efficient"}

# Formatting
DECIMALS = 4  # how many decimals to display for losses
BOLD_BEST = False  # set True to bold per-group best (lower is better)
LOWER_IS_BETTER = True  # for losses we want the minimum


# ----------------------------
# Helpers (ported/minified from your plotting script)
# ----------------------------
def canonicalize_network(name: str) -> str:
    low = (name or "").lower()
    if "lenet" in low:
        return "LeNet"
    if "efficient" in low:
        return "Efficient"
    return name or "unknown"
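

# Hypothetical example inputs to illustrate the matching (made-up names, not
# taken from real runs):
#   canonicalize_network("lenet_v2")        -> "LeNet"
#   canonicalize_network("EfficientNet-B0") -> "Efficient"
#   canonicalize_network("resnet18")        -> "resnet18"  (passed through)

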
def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Mean of per-batch means; a trailing partial batch counts as one batch."""
    n = len(scores)
    if n == 0:
        return np.nan
    if batch_size <= 0:
        batch_size = n
    n_batches = (n + batch_size - 1) // batch_size  # ceil(n / batch_size)
    acc = 0.0
    for i in range(0, n, batch_size):
        acc += float(np.mean(scores[i : i + batch_size]))
    return acc / n_batches
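

# Worked example (hypothetical numbers): scores = [1, 2, 3, 4, 5], batch_size = 2
# gives batch means 1.5, 3.5, 5.0 and a result of (1.5 + 3.5 + 5.0) / 3 = 3.33...,
# whereas the plain mean is 3.0. The two agree only when the final batch is full;
# the batch-mean convention mirrors a training loop averaging per-batch losses.

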
def extract_batch_size(cfg_json: str | None) -> int:
    import json

    try:
        cfg = json.loads(cfg_json) if cfg_json else {}
    except Exception:
        cfg = {}
    return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
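

# e.g. extract_batch_size('{"ae_batch_size": 128}') -> 128; missing or malformed
# JSON falls back to the assumed default of 256 (adjust if your runs used a
# different batch size).

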
@dataclass(frozen=True)
class Cell:
    mean: float | None
    std: float | None


def _fmt(mean: float | None) -> str:
    # `mean == mean` is False only for NaN, so None and NaN both render as "--"
    return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}"


def _bold_mask_display(
    values: List[float | None], decimals: int, lower_is_better: bool
) -> List[bool]:
    """
    Tie-aware bolding mask based on *displayed* precision.
    For losses, lower is better (min). For metrics where higher is better, set lower_is_better=False.
    """

    def disp(v: float | None) -> float | None:
        if v is None or not (v == v):
            return None
        # use string → float to match display rounding exactly
        return float(f"{v:.{decimals}f}")

    rounded = [disp(v) for v in values]
    finite = [v for v in rounded if v is not None]
    if not finite:
        return [False] * len(values)
    target = min(finite) if lower_is_better else max(finite)
    return [(v is not None and v == target) for v in rounded]
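

# Example (hypothetical values): [0.12341, 0.12344, None] at 4 decimals both
# round to 0.1234, so both real entries get bolded; display-level ties are
# treated as joint best rather than bolding only the exact float minimum.

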
# ----------------------------
# Core
# ----------------------------
def build_losses_table_from_df(
    df: pl.DataFrame, label_field: str
) -> Tuple[str, float | None]:
    """
    Build a LaTeX table showing Overall loss (LeNet, Efficient) and Anomaly loss (LeNet, Efficient)
    with one row per latent dimension. Returns (latex_table_string, max_std_overall).
    """
    # Basic validation
    required_cols = {"scores", "network", "latent_dim"}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns in AE dataframe: {missing}")
    if label_field not in df.columns:
        raise ValueError(f"Expected '{label_field}' column in AE dataframe.")

    # Canonicalize nets, compute per-row overall/anomaly losses
    rows: List[dict] = []
    for row in df.iter_rows(named=True):
        net = canonicalize_network(row["network"])
        if WANTED_NETS and net not in WANTED_NETS:
            continue
        dim = int(row["latent_dim"])
        batch_size = extract_batch_size(row.get("config_json"))
        scores = np.asarray(row["scores"] or [], dtype=float)

        labels = row.get(label_field)
        labels = np.asarray(labels, dtype=int) if labels is not None else None

        overall_loss = calculate_batch_mean_loss(scores, batch_size)

        # Convention: a label of -1 marks an anomalous sample
        anomaly_loss = np.nan
        if labels is not None and labels.size == scores.size:
            anomaly_scores = scores[labels == -1]
            if anomaly_scores.size > 0:
                anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size)

        rows.append(
            {
                "net": net,
                "latent_dim": dim,
                "overall": overall_loss,
                "anomaly": anomaly_loss,
            }
        )

    if not rows:
        raise ValueError(
            "No rows available after filtering; check WANTED_NETS or input data."
        )

    df2 = pl.DataFrame(rows)

    # Aggregate across folds per (net, latent_dim)
    agg = df2.group_by(["net", "latent_dim"]).agg(
        pl.col("overall").mean().alias("overall_mean"),
        pl.col("overall").std().alias("overall_std"),
        pl.col("anomaly").mean().alias("anomaly_mean"),
        pl.col("anomaly").std().alias("anomaly_std"),
    )

    # Collect union of dims across both nets
    dims = sorted(set(agg.get_column("latent_dim").to_list()))

    # Build lookups from (net, latent_dim) to the aggregated cells
    keymap: Dict[Tuple[str, int], Cell] = {}
    keymap_anom: Dict[Tuple[str, int], Cell] = {}

    max_std: float | None = None

    def push_std(v: float | None) -> None:
        nonlocal max_std
        if v is None or not (v == v):  # skip None/NaN
            return
        if max_std is None or v > max_std:
            max_std = v

    for r in agg.iter_rows(named=True):
        k = (r["net"], int(r["latent_dim"]))
        keymap[k] = Cell(r.get("overall_mean"), r.get("overall_std"))
        keymap_anom[k] = Cell(r.get("anomaly_mean"), r.get("anomaly_std"))
        push_std(r.get("overall_std"))
        push_std(r.get("anomaly_std"))

    # Ensure consistent network column order
    nets_order = ["LeNet", "Efficient"]
    nets_present = [n for n in nets_order if any(k[0] == n for k in keymap.keys())]
    if not nets_present:
        nets_present = sorted({k[0] for k in keymap.keys()})

    # Build LaTeX table
    header_left = [r"LeNet", r"Efficient"]
    header_right = [r"LeNet", r"Efficient"]

    lines: List[str] = []
    lines.append(r"\begin{table}[t]")
    lines.append(r"\centering")
    lines.append(r"\setlength{\tabcolsep}{4pt}")
    lines.append(r"\renewcommand{\arraystretch}{1.2}")
    # vertical bar between the two groups
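    # NOTE: tabularx does not define a Y column type itself; the consuming
    # document is assumed to provide it in its preamble, e.g.
    #   \usepackage{tabularx,booktabs}
    #   \newcolumntype{Y}{>{\centering\arraybackslash}X}
    # (booktabs supplies the \toprule/\midrule/\bottomrule used below).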
lines.append(r"\begin{tabularx}{\textwidth}{c*{2}{Y}|*{2}{Y}}")
|
||
|
|
lines.append(r"\toprule")
|
||
|
|
lines.append(
|
||
|
|
r" & \multicolumn{2}{c}{Overall loss} & \multicolumn{2}{c}{Anomaly loss} \\"
|
||
|
|
)
|
||
|
|
lines.append(r"\cmidrule(lr){2-3} \cmidrule(lr){4-5}")
|
||
|
|
lines.append(
|
||
|
|
r"Latent Dim. & "
|
||
|
|
+ " & ".join(header_left)
|
||
|
|
+ " & "
|
||
|
|
+ " & ".join(header_right)
|
||
|
|
+ r" \\"
|
||
|
|
)
|
||
|
|
lines.append(r"\midrule")
|
||
|
|
|
||
|
|
for d in dims:
|
||
|
|
# Gather values in order: Overall (LeNet, Efficient), Anomaly (LeNet, Efficient)
|
||
|
|
overall_vals = [keymap.get((n, d), Cell(None, None)).mean for n in nets_present]
|
||
|
|
anomaly_vals = [
|
||
|
|
keymap_anom.get((n, d), Cell(None, None)).mean for n in nets_present
|
||
|
|
]
|
||
|
|
overall_strs = [_fmt(v) for v in overall_vals]
|
||
|
|
anomaly_strs = [_fmt(v) for v in anomaly_vals]
|
||
|
|
|
||
|
|
if BOLD_BEST:
|
||
|
|
mask_overall = _bold_mask_display(overall_vals, DECIMALS, LOWER_IS_BETTER)
|
||
|
|
mask_anom = _bold_mask_display(anomaly_vals, DECIMALS, LOWER_IS_BETTER)
|
||
|
|
overall_strs = [
|
||
|
|
(r"\textbf{" + s + "}") if (m and s != "--") else s
|
||
|
|
for s, m in zip(overall_strs, mask_overall)
|
||
|
|
]
|
||
|
|
anomaly_strs = [
|
||
|
|
(r"\textbf{" + s + "}") if (m and s != "--") else s
|
||
|
|
for s, m in zip(anomaly_strs, mask_anom)
|
||
|
|
]
|
||
|
|
|
||
|
|
lines.append(
|
||
|
|
f"{d} & "
|
||
|
|
+ " & ".join(overall_strs)
|
||
|
|
+ " & "
|
||
|
|
+ " & ".join(anomaly_strs)
|
||
|
|
+ r" \\"
|
||
|
|
)
|
||
|
|
|
||
|
|
lines.append(r"\bottomrule")
|
||
|
|
lines.append(r"\end{tabularx}")
|
||
|
|
|
||
|
|
max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}"
|
||
|
|
lines.append(
|
||
|
|
rf"\caption{{Autoencoder pre-training MSE losses (test split) across latent dimensions. "
|
||
|
|
rf"Left: overall loss; Right: anomaly-only loss. "
|
||
|
|
rf"Cells show means across folds (no $\pm$std). "
|
||
|
|
rf"Maximum observed standard deviation across all cells (not shown): {max_std_str}.}}"
|
||
|
|
)
|
||
|
|
lines.append(r"\end{table}")
|
||
|
|
|
||
|
|
return "\n".join(lines), max_std
|
||
|
|
|
||
|
|
|
||
|
|
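

# For orientation, a rendered data row looks like (numbers hypothetical):
#   16 & 0.0123 & 0.0098 & 0.0456 & 0.0401 \\
# i.e. latent dim, then overall loss per net, then anomaly-only loss per net.

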
# ----------------------------
# Entry
# ----------------------------
def main():
    df = load_pretraining_results_dataframe(ROOT, allow_cache=True)

    # Build LaTeX table
    tex, max_std = build_losses_table_from_df(df, LABEL_FIELD)

    # Output dirs: a timestamped archive run plus a mirrored "latest"
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    out_name = "ae_pretraining_losses_table.tex"
    out_path = ts_dir / out_name
    out_path.write_text(tex, encoding="utf-8")

    # Save a copy of this script alongside the output for reproducibility
    script_path = Path(__file__)
    try:
        shutil.copy2(script_path, ts_dir / script_path.name)
    except Exception:
        pass

    # Mirror latest: clear stale files, then copy this run's files over
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved table to: {ts_dir}")
    print(f"Also updated: {latest}")
    print(f"  - {out_name}")


if __name__ == "__main__":
    main()