results ae section

tools/plot_scripts/results_ae_table.py | 306 lines (new file)
@@ -0,0 +1,306 @@
# ae_losses_table_from_df.py

from __future__ import annotations

import json
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_pretraining_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/copy")  # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_ae_table")

# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
LABEL_FIELD = "labels_exp_based"

# Which architectures to include (labels must match canonicalize_network)
WANTED_NETS = {"LeNet", "Efficient"}

# Formatting
DECIMALS = 4  # how many decimals to display for losses
BOLD_BEST = False  # set True to bold the per-group best (lower is better)
LOWER_IS_BETTER = True  # for losses we want the minimum
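
# The loader is assumed to return one row per (network, latent_dim, fold) with
# per-sample reconstruction errors in "scores"; the exact required columns are
# validated in build_losses_table_from_df below.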


# ----------------------------
# Helpers (ported/minified from your plotting script)
# ----------------------------
def canonicalize_network(name: str) -> str:
    low = (name or "").lower()
    if "lenet" in low:
        return "LeNet"
    if "efficient" in low:
        return "Efficient"
    return name or "unknown"
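
# Illustrative examples (hypothetical names): "LeNet5_bn" -> "LeNet",
# "efficientnet_b0" -> "Efficient"; anything else passes through unchanged,
# or "unknown" for an empty/None name.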


def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    n = len(scores)
    if n == 0:
        return np.nan
    if batch_size <= 0:
        batch_size = n
    n_batches = (n + batch_size - 1) // batch_size
    acc = 0.0
    for i in range(0, n, batch_size):
        acc += float(np.mean(scores[i : i + batch_size]))
    return acc / n_batches
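
# Note: the mean of per-batch means differs from the plain mean whenever the
# last batch is short (e.g. 5 scores with batch_size=2 give batches of 2, 2, 1,
# so the final score carries weight 1/3 rather than 1/5). This presumably
# mirrors how the training loop averaged its loss over batches.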


def extract_batch_size(cfg_json: str | None) -> int:
    try:
        cfg = json.loads(cfg_json) if cfg_json else {}
    except Exception:
        cfg = {}
    return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
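
# Example: extract_batch_size('{"ae_batch_size": 128}') -> 128; None, missing
# keys, or malformed JSON fall back to the default of 256 assumed above.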


@dataclass(frozen=True)
class Cell:
    mean: float | None
    std: float | None


def _fmt(mean: float | None) -> str:
    # `mean == mean` is False only for NaN
    return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}"


def _bold_mask_display(
    values: List[float | None], decimals: int, lower_is_better: bool
) -> List[bool]:
    """
    Tie-aware bolding mask based on *displayed* precision.
    For losses, lower is better (min). For metrics where higher is better,
    set lower_is_better=False.
    """

    def disp(v: float | None) -> float | None:
        if v is None or not (v == v):
            return None
        # use string → float to match display rounding exactly
        return float(f"{v:.{decimals}f}")

    rounded = [disp(v) for v in values]
    finite = [v for v in rounded if v is not None]
    if not finite:
        return [False] * len(values)
    target = min(finite) if lower_is_better else max(finite)
    return [(v is not None and v == target) for v in rounded]
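
# Illustrative example: _bold_mask_display([0.12344, 0.12341, None], 4, True)
# rounds both finite values to 0.1234 at display precision, so they tie for
# best and the mask is [True, True, False].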


# ----------------------------
# Core
# ----------------------------
def build_losses_table_from_df(
    df: pl.DataFrame, label_field: str
) -> Tuple[str, float | None]:
    """
    Build a LaTeX table showing Overall loss (LeNet, Efficient) and Anomaly loss
    (LeNet, Efficient) with one row per latent dimension.
    Returns (latex_table_string, max_std_overall).
    """
    # Basic validation
    required_cols = {"scores", "network", "latent_dim"}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns in AE dataframe: {missing}")
    if label_field not in df.columns:
        raise ValueError(f"Expected '{label_field}' column in AE dataframe.")

    # Canonicalize nets, compute per-row overall/anomaly losses
    rows: List[dict] = []
    for row in df.iter_rows(named=True):
        net = canonicalize_network(row["network"])
        if WANTED_NETS and net not in WANTED_NETS:
            continue
        dim = int(row["latent_dim"])
        batch_size = extract_batch_size(row.get("config_json"))
        scores = np.asarray(row["scores"] or [], dtype=float)

        labels = row.get(label_field)
        labels = np.asarray(labels, dtype=int) if labels is not None else None

        overall_loss = calculate_batch_mean_loss(scores, batch_size)

        # Anomaly-only loss: restrict to samples labeled -1 (anomalies)
        anomaly_loss = np.nan
        if labels is not None and labels.size == scores.size:
            anomaly_scores = scores[labels == -1]
            if anomaly_scores.size > 0:
                anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size)

        rows.append(
            {
                "net": net,
                "latent_dim": dim,
                "overall": overall_loss,
                "anomaly": anomaly_loss,
            }
        )

    if not rows:
        raise ValueError(
            "No rows available after filtering; check WANTED_NETS or input data."
        )

    df2 = pl.DataFrame(rows)

    # Aggregate across folds per (net, latent_dim)
    agg = df2.group_by(["net", "latent_dim"]).agg(
        pl.col("overall").mean().alias("overall_mean"),
        pl.col("overall").std().alias("overall_std"),
        pl.col("anomaly").mean().alias("anomaly_mean"),
        pl.col("anomaly").std().alias("anomaly_std"),
    )

    # Collect the union of dims across both nets
    dims = sorted(set(agg.get_column("latent_dim").to_list()))

    # Build lookups keyed by (net, latent_dim)
    keymap: Dict[Tuple[str, int], Cell] = {}
    keymap_anom: Dict[Tuple[str, int], Cell] = {}

    max_std: float | None = None

    def push_std(v: float | None):
        nonlocal max_std
        if v is None or not (v == v):  # skip None and NaN
            return
        if max_std is None or v > max_std:
            max_std = v

    for r in agg.iter_rows(named=True):
        k = (r["net"], int(r["latent_dim"]))
        keymap[k] = Cell(r.get("overall_mean"), r.get("overall_std"))
        keymap_anom[k] = Cell(r.get("anomaly_mean"), r.get("anomaly_std"))
        push_std(r.get("overall_std"))
        push_std(r.get("anomaly_std"))

    # Ensure a consistent network order
    nets_order = ["LeNet", "Efficient"]
    nets_present = [n for n in nets_order if any(k[0] == n for k in keymap.keys())]
    if not nets_present:
        nets_present = sorted({k[0] for k in keymap.keys()})

    # Build the LaTeX table; headers and column counts are derived from
    # nets_present so they stay in sync with the values below.
    n_cols = len(nets_present)

    lines: List[str] = []
    lines.append(r"\begin{table}[t]")
    lines.append(r"\centering")
    lines.append(r"\setlength{\tabcolsep}{4pt}")
    lines.append(r"\renewcommand{\arraystretch}{1.2}")
    # vertical bar between the two column groups
    lines.append(rf"\begin{{tabularx}}{{\textwidth}}{{c*{{{n_cols}}}{{Y}}|*{{{n_cols}}}{{Y}}}}")
    lines.append(r"\toprule")
    lines.append(
        rf" & \multicolumn{{{n_cols}}}{{c}}{{Overall loss}}"
        rf" & \multicolumn{{{n_cols}}}{{c}}{{Anomaly loss}} \\"
    )
    lines.append(
        rf"\cmidrule(lr){{2-{n_cols + 1}}} \cmidrule(lr){{{n_cols + 2}-{2 * n_cols + 1}}}"
    )
    lines.append(
        r"Latent Dim. & "
        + " & ".join(nets_present)
        + " & "
        + " & ".join(nets_present)
        + r" \\"
    )
    lines.append(r"\midrule")

    for d in dims:
        # Gather values in order: Overall (per net), then Anomaly (per net)
        overall_vals = [keymap.get((n, d), Cell(None, None)).mean for n in nets_present]
        anomaly_vals = [
            keymap_anom.get((n, d), Cell(None, None)).mean for n in nets_present
        ]
        overall_strs = [_fmt(v) for v in overall_vals]
        anomaly_strs = [_fmt(v) for v in anomaly_vals]

        if BOLD_BEST:
            mask_overall = _bold_mask_display(overall_vals, DECIMALS, LOWER_IS_BETTER)
            mask_anom = _bold_mask_display(anomaly_vals, DECIMALS, LOWER_IS_BETTER)
            overall_strs = [
                (r"\textbf{" + s + "}") if (m and s != "--") else s
                for s, m in zip(overall_strs, mask_overall)
            ]
            anomaly_strs = [
                (r"\textbf{" + s + "}") if (m and s != "--") else s
                for s, m in zip(anomaly_strs, mask_anom)
            ]

        lines.append(
            f"{d} & "
            + " & ".join(overall_strs)
            + " & "
            + " & ".join(anomaly_strs)
            + r" \\"
        )

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabularx}")

    max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}"
    lines.append(
        rf"\caption{{Autoencoder pre-training MSE losses (test split) across latent dimensions. "
        rf"Left: overall loss; right: anomaly-only loss. "
        rf"Cells show means across folds (no $\pm$std). "
        rf"Maximum observed standard deviation across all cells (not shown): {max_std_str}.}}"
    )
    lines.append(r"\end{table}")

    return "\n".join(lines), max_std


# ----------------------------
# Entry
# ----------------------------
def main():
    df = load_pretraining_results_dataframe(ROOT, allow_cache=True)

    # Build LaTeX table
    tex, max_std = build_losses_table_from_df(df, LABEL_FIELD)

    # Output dirs: a timestamped archive folder plus a "latest" mirror
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    out_name = "ae_pretraining_losses_table.tex"
    out_path = ts_dir / out_name
    out_path.write_text(tex, encoding="utf-8")

    # Save a copy of this script next to the output for reproducibility
    script_path = Path(__file__)
    try:
        shutil.copy2(script_path, ts_dir / script_path.name)
    except Exception:
        pass

    # Mirror to latest: clear stale files, then copy the fresh ones
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(parents=True, exist_ok=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved table to: {ts_dir}")
    print(f"Also updated: {latest}")
    print(f"  - {out_name}")


if __name__ == "__main__":
    main()
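
# Usage (assuming load_results.py is importable, e.g. run from tools/plot_scripts):
#   python results_ae_table.py
# Output lands in OUTPUT_DIR/archive/<timestamp>/ and is mirrored to OUTPUT_DIR/latest/.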