"""Build a LaTeX table of AUC means per latent dim, method and labeling regime.

Both evaluation types (EVALS_BOTH) are rendered side by side in a single table;
the result is written to a timestamped archive directory and mirrored into
OUTPUT_DIR/latest.
"""

from __future__ import annotations

import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/copy")  # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables")

# Semi-labeling regimes (semi_normals, semi_anomalous) in display order
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]

# Both evals are shown side-by-side in one table
EVALS_BOTH: tuple[str, str] = ("exp_based", "manual_based")

# Row order (latent dims)
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Column order (method shown to the user)
# We split DeepSAD into the two network backbones, like your plots.
METHOD_COLUMNS = [
    ("deepsad", "LeNet"),  # DeepSAD (LeNet)
    ("deepsad", "Efficient"),  # DeepSAD (Efficient)
    ("isoforest", "Efficient"),  # IsolationForest (Efficient baseline)
    ("ocsvm", "Efficient"),  # OC-SVM (Efficient baseline)
]

# Formatting
DECIMALS = 3  # cells look like 1.000 or 0.928 (3 decimals)

# Columns that identify one aggregated cell, in the order used for cell keys:
# (eval, latent_dim, model, net_label, semi_normals, semi_anomalous)
_CELL_KEY_COLUMNS = (
    "eval",
    "latent_dim",
    "model",
    "net_label",
    "semi_normals",
    "semi_anomalous",
)


# ----------------------------
# Helpers
# ----------------------------
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Add a canonical 'net_label' column like the plotting script (LeNet/Efficient/fallback).

    Matching is case-insensitive on the ``network`` column: values containing
    "lenet" map to "LeNet", values containing "efficient" map to "Efficient",
    and anything else keeps its original network name (cast to string).
    """
    network_lc = pl.col("network").cast(pl.Utf8).str.to_lowercase()
    return df.with_columns(
        pl.when(network_lc.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lc.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(pl.col("network").cast(pl.Utf8))
        .alias("net_label")
    )


def _filter_base(df: pl.DataFrame) -> pl.DataFrame:
    """Restrict to the latent dims, models and evals shown in the table.

    Keeps rows whose ``latent_dim`` is in LATENT_DIMS, whose ``model`` is one of
    deepsad/isoforest/ocsvm and whose ``eval`` is in EVALS_BOTH, then projects to
    the columns needed downstream. Semi-labeling regimes are NOT filtered here;
    the table builder only looks up the regimes it displays.

    Expects the ``net_label`` column added by ``_with_net_label``.
    """
    return df.filter(
        (pl.col("latent_dim").is_in(LATENT_DIMS))
        & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
        & (pl.col("eval").is_in(list(EVALS_BOTH)))
    ).select(
        "model",
        "net_label",
        "latent_dim",
        "fold",
        "auc",
        "eval",
        "semi_normals",
        "semi_anomalous",
    )


@dataclass(frozen=True)
class Cell:
    """Aggregated AUC statistics for one table cell."""

    mean: float | None  # mean AUC across folds; None when the cell is missing
    std: float | None  # std of AUC across folds as returned by polars `std()`; may be None/NaN


def _compute_cells(df: pl.DataFrame) -> dict[tuple[str, int, str, str, int, int], Cell]:
    """Compute per-cell mean/std of AUC across folds.

    Cells are keyed by (eval, latent_dim, model, net_label, semi_normals,
    semi_anomalous). Baseline models (isoforest/ocsvm) are only kept for the
    Efficient backbone. Groups with a null key component are skipped: they can
    never match a displayed cell, and ``int(None)`` would otherwise raise.

    Returns:
        Mapping from cell key to ``Cell``; empty if ``df`` is empty.
    """
    if df.is_empty():
        return {}

    # For baselines (isoforest/ocsvm) constrain to Efficient backbone
    df = df.filter(
        pl.when(pl.col("model").is_in(["isoforest", "ocsvm"]))
        .then(pl.col("net_label") == "Efficient")
        .otherwise(True)
    )

    agg = (
        df.group_by(list(_CELL_KEY_COLUMNS))
        .agg(
            pl.col("auc").mean().alias("mean_auc"),
            pl.col("auc").std().alias("std_auc"),
        )
        .to_dicts()
    )

    out: dict[tuple[str, int, str, str, int, int], Cell] = {}
    for row in agg:
        raw_key = tuple(row[col] for col in _CELL_KEY_COLUMNS)
        if any(part is None for part in raw_key):
            continue
        ev, dim, model, net, semi_n, semi_a = raw_key
        key = (str(ev), int(dim), str(model), str(net), int(semi_n), int(semi_a))
        out[key] = Cell(mean=row.get("mean_auc"), std=row.get("std_auc"))
    return out


def _fmt_mean(mean: float | None) -> str:
    """Format a mean with DECIMALS places, or "--" for missing/NaN values."""
    # `mean == mean` is False only for NaN.
    return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}"


def _bold_best_mask_display(values: list[float | None], decimals: int) -> list[bool]:
    """Bolding mask based on *displayed* precision.

    Any entries that round (via f-string) to the maximum at 'decimals' places
    are bolded (ties bolded). None/NaN entries are never bolded; if no entry is
    finite, nothing is bolded.
    """

    def disp(v: float | None) -> float | None:
        if v is None or not (v == v):
            return None
        return float(f"{v:.{decimals}f}")

    rounded = [disp(v) for v in values]
    finite = [v for v in rounded if v is not None]
    if not finite:
        return [False] * len(values)
    maxv = max(finite)
    return [(v is not None and v == maxv) for v in rounded]


def _render_eval_group(
    cells: dict[tuple[str, int, str, str, int, int], Cell],
    eval_type: str,
    dim: int,
    semi_n: int,
    semi_a: int,
) -> tuple[list[str], list[float]]:
    """Render the METHOD_COLUMNS cells of one eval group for one table row.

    Returns:
        (rendered, stds): the LaTeX cell strings in METHOD_COLUMNS order, with the
        best displayed mean(s) bolded, and the finite (non-None, non-NaN)
        standard deviations of those cells for the caption's max-std statistic.
    """
    found = [
        cells.get((eval_type, dim, model, net, semi_n, semi_a), Cell(None, None))
        for model, net in METHOD_COLUMNS
    ]
    means = [c.mean for c in found]
    mask = _bold_best_mask_display(means, DECIMALS)

    rendered: list[str] = []
    for mean, do_bold in zip(means, mask):
        text = _fmt_mean(mean)
        rendered.append(r"\textbf{" + text + "}" if (do_bold and text != "--") else text)

    stds = [c.std for c in found if c.std is not None and c.std == c.std]
    return rendered, stds


def _build_single_table(
    cells: dict[tuple[str, int, str, str, int, int], Cell],
    *,
    semi_labeling_regimes: list[tuple[int, int]],
) -> tuple[str, float | None]:
    """Build the LaTeX table string with grouped headers and regime blocks.

    Layout: one block of LATENT_DIMS rows per labeling regime (blocks separated by
    \\midrule). Within each row, the left four columns hold EVALS_BOTH[0]
    ("Experiment-based") results and the right four EVALS_BOTH[1] ("Handlabeled")
    results, in METHOD_COLUMNS order. Missing cells render as "--". Standard
    deviations are not shown in the table; only their maximum goes in the caption.

    Returns:
        (latex, max_std_overall) where max_std_overall is the largest finite std
        over all looked-up cells, or None if there was none.
    """
    # Rotated header labels (90° slanted).
    # NOTE(review): \rotheader and the `Y` column type are not standard LaTeX;
    # presumably defined in the including document's preamble — confirm.
    header_cols = [
        r"\rotheader{DeepSAD\\(LeNet)}",
        r"\rotheader{DeepSAD\\(Efficient)}",
        r"\rotheader{IsoForest}",
        r"\rotheader{OC-SVM}",
    ]

    lines: list[str] = [
        # Table preamble / structure
        r"\begin{table}[t]",
        r"\centering",
        r"\setlength{\tabcolsep}{4pt}",
        r"\renewcommand{\arraystretch}{1.2}",
        # Vertical rule between the two groups for data/header rows:
        r"\begin{tabularx}{\textwidth}{c*{4}{Y}|*{4}{Y}}",
        r"\toprule",
        r" & \multicolumn{4}{c}{Experiment-based eval.} & \multicolumn{4}{c}{Handlabeled eval.} \\",
        r"\cmidrule(lr){2-5} \cmidrule(lr){6-9}",
        r"Latent Dim. & " + " & ".join(header_cols) + " & " + " & ".join(header_cols) + r" \\",
        r"\midrule",
    ]

    stds_seen: list[float] = []

    for idx, (semi_n, semi_a) in enumerate(semi_labeling_regimes):
        # Separator between regime blocks (none before the first one)
        if idx > 0:
            lines.append(r"\midrule")

        # Regime label row (multicolumn suppresses the vertical bar in this row)
        lines.append(
            rf"\multicolumn{{9}}{{l}}{{\textbf{{Labeling regime: }}\(\mathbf{{{semi_n}/{semi_a}}}\) "
            rf"\textit{{(normal/anomalous samples labeled)}}}} \\"
        )
        lines.append(r"\addlinespace[2pt]")

        for dim in LATENT_DIMS:
            # Left group first (EVALS_BOTH[0]), then right group (EVALS_BOTH[1]);
            # bolding is decided independently per group.
            row_cells: list[str] = []
            for eval_type in EVALS_BOTH:
                rendered, stds = _render_eval_group(cells, eval_type, dim, semi_n, semi_a)
                row_cells.extend(rendered)
                stds_seen.extend(stds)
            # The vertical bar between groups comes from the column spec.
            lines.append(f"{dim} & " + " & ".join(row_cells) + r" \\")

    max_std = max(stds_seen, default=None)
    max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}"

    # NOTE(review): the fold count in the caption is hard-coded; confirm it
    # matches the number of folds in the loaded results.
    lines += [
        r"\bottomrule",
        r"\end{tabularx}",
        rf"\caption{{AUC means across 5 folds for both evaluations, grouped by labeling regime. "
        rf"Maximum observed standard deviation across all cells (not shown in table): {max_std_str}.}}",
        r"\end{table}",
    ]
    return "\n".join(lines), max_std


def main() -> None:
    """Load results, build the AUC table and write it to a timestamped archive dir.

    The script itself is copied next to the output (to preserve the code that
    produced it), and the archive directory's files are mirrored into
    OUTPUT_DIR/latest (whose previous files are removed first).
    """
    # Load full results DF (cache behavior handled by your loader)
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _filter_base(_with_net_label(df))

    # Prepare output dirs
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_dir = OUTPUT_DIR / "archive"
    archive_dir.mkdir(parents=True, exist_ok=True)
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)

    # Pre-compute aggregated cells (mean/std) for all evals/regimes
    cells = _compute_cells(df)

    # Build the single big table
    tex, _max_std = _build_single_table(cells, semi_labeling_regimes=SEMI_LABELING_REGIMES)

    out_name = "auc_table_all_evals_all_regimes.tex"
    out_path = ts_dir / out_name
    out_path.write_text(tex, encoding="utf-8")

    # Copy this script to preserve the code used for the outputs
    script_path = Path(__file__)
    shutil.copy2(script_path, ts_dir / script_path.name)

    # Mirror latest: clear old files, then copy everything from this run
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for f in latest.iterdir():
        if f.is_file():
            f.unlink()
    for f in ts_dir.iterdir():
        if f.is_file():
            shutil.copy2(f, latest / f.name)

    print(f"Saved table to: {ts_dir}")
    print(f"Also updated: {latest}")
    print(f" - {out_name}")


if __name__ == "__main__":
    main()