From 95867bde7a0e71e0f6b904c3e3228c0e34521519 Mon Sep 17 00:00:00 2001 From: Jan Kowalczyk Date: Wed, 17 Sep 2025 11:07:07 +0200 Subject: [PATCH] table plot --- .../results_latent_space_tables.py | 255 ++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 tools/plot_scripts/results_latent_space_tables.py diff --git a/tools/plot_scripts/results_latent_space_tables.py b/tools/plot_scripts/results_latent_space_tables.py new file mode 100644 index 0000000..4fff010 --- /dev/null +++ b/tools/plot_scripts/results_latent_space_tables.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import shutil +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import polars as pl + +# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY +from load_results import load_results_dataframe + +# ---------------------------- +# Config +# ---------------------------- +ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader +OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables") + +# Semi-labeling regimes (semi_normals, semi_anomalous) +SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)] + +# Which evaluation columns to include (one table per eval × semi-regime) +EVALS: list[str] = ["exp_based", "manual_based"] + +# Row order (latent dims) +LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] + +# Column order (method shown to the user) +# We split DeepSAD into the two network backbones, like your plots. +METHOD_COLUMNS = [ + ("deepsad", "LeNet"), # DeepSAD (LeNet) + ("deepsad", "Efficient"), # DeepSAD (Efficient) + ("isoforest", "Efficient"), # IsolationForest (Efficient backbone baseline) + ("ocsvm", "Efficient"), # OC-SVM (Efficient backbone baseline) +] + +# Formatting +DECIMALS = 3 # number of decimals for mean/std +STD_FMT = r"\textpm" # between mean and std in LaTeX + + +# ---------------------------- +# Helpers +# ---------------------------- +def _with_net_label(df: pl.DataFrame) -> pl.DataFrame: + """Add a canonical 'net_label' column like the plotting script (LeNet/Efficient/fallback).""" + return df.with_columns( + pl.when( + pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet") + ) + .then(pl.lit("LeNet")) + .when( + pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient") + ) + .then(pl.lit("Efficient")) + .otherwise(pl.col("network").cast(pl.Utf8)) + .alias("net_label") + ) + + +def _filter_base( + df: pl.DataFrame, + *, + eval_type: str, + semi_normals: int, + semi_anomalous: int, +) -> pl.DataFrame: + """Common filtering by regime/eval/valid dims&models.""" + return df.filter( + (pl.col("semi_normals") == semi_normals) + & (pl.col("semi_anomalous") == semi_anomalous) + & (pl.col("eval") == eval_type) + & (pl.col("latent_dim").is_in(LATENT_DIMS)) + & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])) + ).select( + "model", + "net_label", + "latent_dim", + "fold", + "auc", + ) + + +def _format_mean_std(mean: float | None, std: float | None) -> str: + if mean is None or (mean != mean): # NaN check + return "--" + if std is None or (std != std): + return f"{mean:.{DECIMALS}f}" + return f"{mean:.{DECIMALS}f} {STD_FMT} {std:.{DECIMALS}f}" + + +@dataclass(frozen=True) +class Cell: + mean: float | None + std: float | None + + +def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]: + """ + Compute per-(latent_dim, model, net_label) mean/std for AUC across folds. + Returns a dict keyed by (latent_dim, model, net_label). + """ + if df.is_empty(): + return {} + agg = ( + df.group_by(["latent_dim", "model", "net_label"]) + .agg( + pl.col("auc").mean().alias("mean_auc"), pl.col("auc").std().alias("std_auc") + ) + .to_dicts() + ) + out: dict[tuple[int, str, str], Cell] = {} + for row in agg: + key = (int(row["latent_dim"]), str(row["model"]), str(row["net_label"])) + out[key] = Cell(mean=row.get("mean_auc"), std=row.get("std_auc")) + return out + + +def _bold_best_in_row(values: list[float | None]) -> list[bool]: + """Return a mask of which entries are the (tied) maximum among non-None values.""" + clean = [(v if (v is not None and v == v) else None) for v in values] + finite_vals = [v for v in clean if v is not None] + if not finite_vals: + return [False] * len(values) + maxv = max(finite_vals) + return [(v is not None and abs(v - maxv) < 1e-12) for v in clean] + + +def _latex_table( + cells: dict[tuple[int, str, str], Cell], + *, + eval_type: str, + semi_normals: int, + semi_anomalous: int, +) -> str: + """ + Build a LaTeX table with rows=latent dims and columns=METHOD_COLUMNS. + Bold best AUC (mean) per row. + """ + header_cols = [ + r"\textbf{DeepSAD (LeNet)}", + r"\textbf{DeepSAD (Efficient)}", + r"\textbf{IsolationForest}", + r"\textbf{OC\text{-}SVM}", + ] + + eval_type_str = ( + "experiment-based evaluation" + if eval_type == "exp_based" + else "handlabeling-based evaluation" + ) + + lines: list[str] = [] + lines.append(r"\begin{table}[t]") + lines.append(r"\centering") + lines.append( + rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, " + rf"semi-labeling regime: {semi_normals} normal samples {semi_anomalous} anomalous samples.}}" + ) + lines.append(rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}") + lines.append(r"\begin{tabularx}{\textwidth}{cYYYY}") + lines.append(r"\toprule") + lines.append(r"\textbf{Latent Dim.} & " + " & ".join(header_cols) + r" \\") + lines.append(r"\midrule") + + for dim in LATENT_DIMS: + # Collect means for bolding + means_for_bold: list[float | None] = [] + cell_strs: list[str] = [] + for model, net in METHOD_COLUMNS: + cell = cells.get((dim, model, net), Cell(None, None)) + means_for_bold.append(cell.mean) + cell_strs.append(_format_mean_std(cell.mean, cell.std)) + + bold_mask = _bold_best_in_row(means_for_bold) + pretty_cells: list[str] = [] + for s, do_bold in zip(cell_strs, bold_mask): + if do_bold and s != "--": + pretty_cells.append(r"\textbf{" + s + r"}") + else: + pretty_cells.append(s) + + lines.append(f"{dim} & " + " & ".join(pretty_cells) + r" \\") + + lines.append(r"\bottomrule") + lines.append(r"\end{tabularx}") + lines.append(r"\end{table}") + return "\n".join(lines) + + +def main(): + # Load full results DF (cache behavior handled by your loader) + df = load_results_dataframe(ROOT, allow_cache=True) + df = _with_net_label(df) + + # Prepare output dirs + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + archive_dir = OUTPUT_DIR / "archive" + archive_dir.mkdir(parents=True, exist_ok=True) + ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ts_dir.mkdir(parents=True, exist_ok=True) + + emitted_files: list[Path] = [] + + for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES: + for eval_type in EVALS: + sub = _filter_base( + df, + eval_type=eval_type, + semi_normals=semi_normals, + semi_anomalous=semi_anomalous, + ) + + # For baselines (isoforest/ocsvm) we constrain to Efficient backbone to mirror plots + sub = sub.filter( + pl.when(pl.col("model").is_in(["isoforest", "ocsvm"])) + .then(pl.col("net_label") == "Efficient") + .otherwise(True) + ) + + cells = _compute_cells(sub) + tex = _latex_table( + cells, + eval_type=eval_type, + semi_normals=semi_normals, + semi_anomalous=semi_anomalous, + ) + + out_name = f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex" + out_path = ts_dir / out_name + out_path.write_text(tex, encoding="utf-8") + emitted_files.append(out_path) + + # Copy this script to preserve the code used for the outputs + script_path = Path(__file__) + shutil.copy2(script_path, ts_dir / script_path.name) + + # Mirror latest + latest = OUTPUT_DIR / "latest" + latest.mkdir(exist_ok=True, parents=True) + for f in latest.iterdir(): + if f.is_file(): + f.unlink() + for f in ts_dir.iterdir(): + if f.is_file(): + shutil.copy2(f, latest / f.name) + + print(f"Saved tables to: {ts_dir}") + print(f"Also updated: {latest}") + for p in emitted_files: + print(f" - {p.name}") + + +if __name__ == "__main__": + main()