"""Generate LaTeX AUC tables (rows = latent dims, columns = methods).

For every combination of semi-labeling regime and evaluation type, the script
aggregates the per-fold AUC values from the results DataFrame into
mean/std cells, renders one LaTeX table, and writes it to a timestamped
archive directory.  The script itself is copied next to the tables and the
whole run is mirrored into ``OUTPUT_DIR / "latest"``.
"""

from __future__ import annotations

import math
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

import polars as pl

# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_results_dataframe

# ----------------------------
# Config
# ----------------------------
ROOT = Path("/home/fedex/mt/results/copy")  # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables")

# Semi-labeling regimes (semi_normals, semi_anomalous)
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]

# Which evaluation columns to include (one table per eval × semi-regime)
EVALS: list[str] = ["exp_based", "manual_based"]

# Row order (latent dims)
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Column order (method shown to the user)
# We split DeepSAD into the two network backbones, like your plots.
METHOD_COLUMNS = [
    ("deepsad", "LeNet"),  # DeepSAD (LeNet)
    ("deepsad", "Efficient"),  # DeepSAD (Efficient)
    ("isoforest", "Efficient"),  # IsolationForest (Efficient backbone baseline)
    ("ocsvm", "Efficient"),  # OC-SVM (Efficient backbone baseline)
]

# Formatting
DECIMALS = 3  # number of decimals for mean/std
# NOTE(review): TeX discards the space that follows a control word, so
# "0.912 \textpm 0.005" typesets as "0.912 ±0.005".  Use r"\textpm{}" here if
# symmetric spacing around the sign is wanted.
STD_FMT = r"\textpm"  # between mean and std in LaTeX

# Display names for the baseline models; DeepSAD headers additionally carry
# the backbone name (see _method_header).
_BASELINE_HEADERS: dict[str, str] = {
    "isoforest": "IsolationForest",
    "ocsvm": r"OC\text{-}SVM",
}


# ----------------------------
# Helpers
# ----------------------------
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Add a canonical 'net_label' column like the plotting script.

    The 'network' column is cast to string and matched case-insensitively:
    values containing "lenet" become "LeNet", values containing "efficient"
    become "Efficient", and anything else keeps its original string value.
    """
    network = pl.col("network").cast(pl.Utf8)
    network_lower = network.str.to_lowercase()
    return df.with_columns(
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network)
        .alias("net_label")
    )


def _filter_base(
    df: pl.DataFrame,
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> pl.DataFrame:
    """Common filtering by regime/eval/valid dims & models.

    Keeps only rows that match the requested semi-labeling regime and
    evaluation type, whose latent dimension is listed in LATENT_DIMS and
    whose model is one of deepsad/isoforest/ocsvm.  The result is narrowed to
    the columns needed for aggregation.
    """
    return df.filter(
        (pl.col("semi_normals") == semi_normals)
        & (pl.col("semi_anomalous") == semi_anomalous)
        & (pl.col("eval") == eval_type)
        & (pl.col("latent_dim").is_in(LATENT_DIMS))
        & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
    ).select(
        "model",
        "net_label",
        "latent_dim",
        "fold",
        "auc",
    )


def _is_missing(value: float | None) -> bool:
    """Return True when *value* is None or NaN."""
    return value is None or math.isnan(value)


def _format_mean_std(mean: float | None, std: float | None) -> str:
    """Format one table cell as 'mean STD_FMT std' with DECIMALS decimals.

    Returns "--" when the mean is missing, and the mean alone when only the
    std is missing.
    """
    if _is_missing(mean):
        return "--"
    if _is_missing(std):
        return f"{mean:.{DECIMALS}f}"
    return f"{mean:.{DECIMALS}f} {STD_FMT} {std:.{DECIMALS}f}"


@dataclass(frozen=True)
class Cell:
    """Aggregated AUC statistics for a single table cell."""

    # Mean AUC of the group; None when the aggregation produced no value.
    mean: float | None
    # Standard deviation of the AUC in the group; None when unavailable.
    std: float | None


def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]:
    """
    Compute per-(latent_dim, model, net_label) mean/std for AUC across folds.

    Returns a dict keyed by (latent_dim, model, net_label).  An empty input
    frame yields an empty dict.
    """
    if df.is_empty():
        return {}

    rows = (
        df.group_by(["latent_dim", "model", "net_label"])
        .agg(
            pl.col("auc").mean().alias("mean_auc"),
            pl.col("auc").std().alias("std_auc"),
        )
        .to_dicts()
    )

    return {
        (int(row["latent_dim"]), str(row["model"]), str(row["net_label"])): Cell(
            mean=row["mean_auc"], std=row["std_auc"]
        )
        for row in rows
    }


def _bold_best_in_row(values: list[float | None]) -> list[bool]:
    """Return a mask of which entries are the (tied) maximum among valid values.

    None and NaN entries are never marked.  Ties are detected on the
    unrounded values with an absolute tolerance of 1e-12.
    """
    present = [v for v in values if not _is_missing(v)]
    if not present:
        return [False] * len(values)
    best = max(present)
    return [(not _is_missing(v)) and abs(v - best) < 1e-12 for v in values]


def _method_header(model: str, net: str) -> str:
    """Return the bold LaTeX column header for one METHOD_COLUMNS entry."""
    if model == "deepsad":
        label = f"DeepSAD ({net})"
    else:
        # Unknown models fall back to their raw name instead of failing.
        label = _BASELINE_HEADERS.get(model, model)
    return r"\textbf{" + label + "}"


def _latex_table(
    cells: dict[tuple[int, str, str], Cell],
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> str:
    """
    Build a LaTeX table with rows=latent dims and columns=METHOD_COLUMNS.

    Bold best AUC (mean) per row.  Cells without data are rendered as "--".
    Headers and the column spec are derived from METHOD_COLUMNS so that they
    cannot drift out of sync with the data columns.
    """
    header_cols = [_method_header(model, net) for model, net in METHOD_COLUMNS]
    # One centred label column plus one "Y" column per method; "Y" is a
    # column type the including LaTeX document is expected to define.
    col_spec = "c" + "Y" * len(METHOD_COLUMNS)

    eval_type_str = (
        "experiment-based evaluation"
        if eval_type == "exp_based"
        else "handlabeling-based evaluation"
    )

    lines: list[str] = []
    lines.append(r"\begin{table}[t]")
    lines.append(r"\centering")
    # NOTE(review): the fold count in the caption is hard-coded, not derived
    # from the data.
    lines.append(
        rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, "
        rf"semi-labeling regime: {semi_normals} normal samples {semi_anomalous} anomalous samples.}}"
    )
    lines.append(rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}")
    lines.append(r"\begin{tabularx}{\textwidth}{" + col_spec + "}")
    lines.append(r"\toprule")
    lines.append(r"\textbf{Latent Dim.} & " + " & ".join(header_cols) + r" \\")
    lines.append(r"\midrule")

    missing = Cell(None, None)
    for dim in LATENT_DIMS:
        row_cells = [cells.get((dim, model, net), missing) for model, net in METHOD_COLUMNS]
        cell_strs = [_format_mean_std(c.mean, c.std) for c in row_cells]
        bold_mask = _bold_best_in_row([c.mean for c in row_cells])

        pretty_cells = [
            r"\textbf{" + s + r"}" if do_bold and s != "--" else s
            for s, do_bold in zip(cell_strs, bold_mask)
        ]
        lines.append(f"{dim} & " + " & ".join(pretty_cells) + r" \\")

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabularx}")
    lines.append(r"\end{table}")

    return "\n".join(lines)


def _prepare_run_dir() -> Path:
    """Create OUTPUT_DIR/archive/<timestamp> and return the timestamp dir."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_dir = OUTPUT_DIR / "archive"
    archive_dir.mkdir(parents=True, exist_ok=True)
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    ts_dir.mkdir(parents=True, exist_ok=True)
    return ts_dir


def _build_table(
    df: pl.DataFrame,
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> str:
    """Filter, aggregate and render the table for one regime/eval combination."""
    sub = _filter_base(
        df,
        eval_type=eval_type,
        semi_normals=semi_normals,
        semi_anomalous=semi_anomalous,
    )

    # For baselines (isoforest/ocsvm) we constrain to Efficient backbone to mirror plots
    sub = sub.filter(
        pl.when(pl.col("model").is_in(["isoforest", "ocsvm"]))
        .then(pl.col("net_label") == "Efficient")
        .otherwise(True)
    )

    return _latex_table(
        _compute_cells(sub),
        eval_type=eval_type,
        semi_normals=semi_normals,
        semi_anomalous=semi_anomalous,
    )


def _mirror_to_latest(run_dir: Path) -> Path:
    """Replace the files in OUTPUT_DIR/latest with copies of those in *run_dir*.

    Only regular files are removed and copied; subdirectories are left alone.
    Returns the path of the 'latest' directory.
    """
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for stale in latest.iterdir():
        if stale.is_file():
            stale.unlink()
    for produced in run_dir.iterdir():
        if produced.is_file():
            shutil.copy2(produced, latest / produced.name)
    return latest


def main() -> None:
    """Render every configured table, archive the run and refresh 'latest'."""
    # Load full results DF (cache behavior handled by your loader)
    df = load_results_dataframe(ROOT, allow_cache=True)
    df = _with_net_label(df)

    # Prepare output dirs
    ts_dir = _prepare_run_dir()

    emitted_files: list[Path] = []

    for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES:
        for eval_type in EVALS:
            tex = _build_table(
                df,
                eval_type=eval_type,
                semi_normals=semi_normals,
                semi_anomalous=semi_anomalous,
            )

            out_name = f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex"
            out_path = ts_dir / out_name
            out_path.write_text(tex, encoding="utf-8")
            emitted_files.append(out_path)

    # Copy this script to preserve the code used for the outputs
    script_path = Path(__file__)
    shutil.copy2(script_path, ts_dir / script_path.name)

    # Mirror latest
    latest = _mirror_to_latest(ts_dir)

    print(f"Saved tables to: {ts_dir}")
    print(f"Also updated:    {latest}")
    for p in emitted_files:
        print(f"  - {p.name}")


if __name__ == "__main__":
    main()