# ae_losses_table_from_df.py from __future__ import annotations import shutil from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Dict, List, Tuple import numpy as np import polars as pl # CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY from load_results import load_pretraining_results_dataframe # ---------------------------- # Config # ---------------------------- ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ae_table") # Which label field to use from the DF; "labels_exp_based" or "labels_manual_based" LABEL_FIELD = "labels_exp_based" # Which architectures to include (labels must match canonicalize_network) WANTED_NETS = {"LeNet", "Efficient"} # Formatting DECIMALS = 4 # how many decimals to display for losses BOLD_BEST = False # set True to bold per-group best (lower is better) LOWER_IS_BETTER = True # for losses we want the minimum # ---------------------------- # Helpers (ported/minified from your plotting script) # ---------------------------- def canonicalize_network(name: str) -> str: low = (name or "").lower() if "lenet" in low: return "LeNet" if "efficient" in low: return "Efficient" return name or "unknown" def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float: n = len(scores) if n == 0: return np.nan if batch_size <= 0: batch_size = n n_batches = (n + batch_size - 1) // batch_size acc = 0.0 for i in range(0, n, batch_size): acc += float(np.mean(scores[i : i + batch_size])) return acc / n_batches def extract_batch_size(cfg_json: str) -> int: import json try: cfg = json.loads(cfg_json) if cfg_json else {} except Exception: cfg = {} return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256) @dataclass(frozen=True) class Cell: mean: float | None std: float | None def _fmt(mean: float | None) -> str: return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}" def _bold_mask_display( values: List[float | None], decimals: int, lower_is_better: bool ) -> List[bool]: """ Tie-aware bolding mask based on *displayed* precision. For losses, lower is better (min). For metrics where higher is better, set lower_is_better=False. """ def disp(v: float | None) -> float | None: if v is None or not (v == v): return None # use string → float to match display rounding exactly return float(f"{v:.{decimals}f}") rounded = [disp(v) for v in values] finite = [v for v in rounded if v is not None] if not finite: return [False] * len(values) target = min(finite) if lower_is_better else max(finite) return [(v is not None and v == target) for v in rounded] # ---------------------------- # Core # ---------------------------- def build_losses_table_from_df( df: pl.DataFrame, label_field: str ) -> Tuple[str, float | None]: """ Build a LaTeX table showing Overall loss (LeNet, Efficient) and Anomaly loss (LeNet, Efficient) with one row per latent dimension. Returns (latex_table_string, max_std_overall). """ # Basic validation required_cols = {"scores", "network", "latent_dim"} missing = required_cols - set(df.columns) if missing: raise ValueError(f"Missing required columns in AE dataframe: {missing}") if label_field not in df.columns: raise ValueError(f"Expected '{label_field}' column in AE dataframe.") # Canonicalize nets, compute per-row overall/anomaly losses rows: List[dict] = [] for row in df.iter_rows(named=True): net = canonicalize_network(row["network"]) if WANTED_NETS and net not in WANTED_NETS: continue dim = int(row["latent_dim"]) batch_size = extract_batch_size(row.get("config_json")) scores = np.asarray(row["scores"] or [], dtype=float) labels = row.get(label_field) labels = np.asarray(labels, dtype=int) if labels is not None else None overall_loss = calculate_batch_mean_loss(scores, batch_size) anomaly_loss = np.nan if labels is not None and labels.size == scores.size: anomaly_scores = scores[labels == -1] if anomaly_scores.size > 0: anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size) rows.append( { "net": net, "latent_dim": dim, "overall": overall_loss, "anomaly": anomaly_loss, } ) if not rows: raise ValueError( "No rows available after filtering; check WANTED_NETS or input data." ) df2 = pl.DataFrame(rows) # Aggregate across folds per (net, latent_dim) agg = df2.group_by(["net", "latent_dim"]).agg( pl.col("overall").mean().alias("overall_mean"), pl.col("overall").std().alias("overall_std"), pl.col("anomaly").mean().alias("anomaly_mean"), pl.col("anomaly").std().alias("anomaly_std"), ) # Collect union of dims across both nets dims = sorted(set(agg.get_column("latent_dim").to_list())) # Build lookup keymap: Dict[Tuple[str, int], Cell] = {} keymap_anom: Dict[Tuple[str, int], Cell] = {} max_std: float | None = None def push_std(v: float | None): nonlocal max_std if v is None or not (v == v): return if max_std is None or v > max_std: max_std = v for r in agg.iter_rows(named=True): k = (r["net"], int(r["latent_dim"])) keymap[k] = Cell(r.get("overall_mean"), r.get("overall_std")) keymap_anom[k] = Cell(r.get("anomaly_mean"), r.get("anomaly_std")) push_std(r.get("overall_std")) push_std(r.get("anomaly_std")) # Ensure nets order consistent nets_order = ["LeNet", "Efficient"] nets_present = [n for n in nets_order if any(k[0] == n for k in keymap.keys())] if not nets_present: nets_present = sorted({k[0] for k in keymap.keys()}) # Build LaTeX table header_left = [r"LeNet", r"Efficient"] header_right = [r"LeNet", r"Efficient"] lines: List[str] = [] lines.append(r"\begin{table}[t]") lines.append(r"\centering") lines.append(r"\setlength{\tabcolsep}{4pt}") lines.append(r"\renewcommand{\arraystretch}{1.2}") # vertical bar between the two groups lines.append(r"\begin{tabularx}{\textwidth}{c*{2}{Y}|*{2}{Y}}") lines.append(r"\toprule") lines.append( r" & \multicolumn{2}{c}{Overall loss} & \multicolumn{2}{c}{Anomaly loss} \\" ) lines.append(r"\cmidrule(lr){2-3} \cmidrule(lr){4-5}") lines.append( r"Latent Dim. & " + " & ".join(header_left) + " & " + " & ".join(header_right) + r" \\" ) lines.append(r"\midrule") for d in dims: # Gather values in order: Overall (LeNet, Efficient), Anomaly (LeNet, Efficient) overall_vals = [keymap.get((n, d), Cell(None, None)).mean for n in nets_present] anomaly_vals = [ keymap_anom.get((n, d), Cell(None, None)).mean for n in nets_present ] overall_strs = [_fmt(v) for v in overall_vals] anomaly_strs = [_fmt(v) for v in anomaly_vals] if BOLD_BEST: mask_overall = _bold_mask_display(overall_vals, DECIMALS, LOWER_IS_BETTER) mask_anom = _bold_mask_display(anomaly_vals, DECIMALS, LOWER_IS_BETTER) overall_strs = [ (r"\textbf{" + s + "}") if (m and s != "--") else s for s, m in zip(overall_strs, mask_overall) ] anomaly_strs = [ (r"\textbf{" + s + "}") if (m and s != "--") else s for s, m in zip(anomaly_strs, mask_anom) ] lines.append( f"{d} & " + " & ".join(overall_strs) + " & " + " & ".join(anomaly_strs) + r" \\" ) lines.append(r"\bottomrule") lines.append(r"\end{tabularx}") max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}" lines.append( rf"\caption{{Autoencoder pre-training MSE losses (test split) across latent dimensions. " rf"Left: overall loss; Right: anomaly-only loss. " rf"Cells show means across folds (no $\pm$std). " rf"Maximum observed standard deviation across all cells (not shown): {max_std_str}.}}" ) lines.append(r"\end{table}") return "\n".join(lines), max_std # ---------------------------- # Entry # ---------------------------- def main(): df = load_pretraining_results_dataframe(ROOT, allow_cache=True) # Build LaTeX table tex, max_std = build_losses_table_from_df(df, LABEL_FIELD) # Output dirs OUTPUT_DIR.mkdir(parents=True, exist_ok=True) ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir.mkdir(parents=True, exist_ok=True) out_name = "ae_pretraining_losses_table.tex" out_path = ts_dir / out_name out_path.write_text(tex, encoding="utf-8") # Save a copy of this script script_path = Path(__file__) try: shutil.copy2(script_path, ts_dir / script_path.name) except Exception: pass # Mirror latest latest = OUTPUT_DIR / "latest" latest.mkdir(parents=True, exist_ok=True) # Clear for f in latest.iterdir(): if f.is_file(): f.unlink() # Copy for f in ts_dir.iterdir(): if f.is_file(): shutil.copy2(f, latest / f.name) print(f"Saved table to: {ts_dir}") print(f"Also updated: {latest}") print(f" - {out_name}") if __name__ == "__main__": main()