from __future__ import annotations import shutil from dataclasses import dataclass from datetime import datetime from pathlib import Path import numpy as np import polars as pl # CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY from load_results import load_results_dataframe # ---------------------------- # Config # ---------------------------- ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables") # Semi-labeling regimes (semi_normals, semi_anomalous) in display order SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)] # Both evals are shown side-by-side in one table EVALS_BOTH: tuple[str, str] = ("exp_based", "manual_based") # Row order (latent dims) LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] # Column order (method shown to the user) # We split DeepSAD into the two network backbones, like your plots. METHOD_COLUMNS = [ ("deepsad", "LeNet"), # DeepSAD (LeNet) ("deepsad", "Efficient"), # DeepSAD (Efficient) ("isoforest", "Efficient"), # IsolationForest (Efficient baseline) ("ocsvm", "Efficient"), # OC-SVM (Efficient baseline) ] # Formatting DECIMALS = 3 # cells look like 1.000 or 0.928 (3 decimals) # ---------------------------- # Helpers # ---------------------------- def _fmt_mean_std(mean: float | None, std: float | None) -> str: """Format mean ± std with 3 decimals (leading zero), or '--' if missing.""" if mean is None or not (mean == mean): # NaN check return "--" if std is None or not (std == std): return f"{mean:.3f}" return f"{mean:.3f}$\\,\\pm\\,{std:.3f}$" def _with_net_label(df: pl.DataFrame) -> pl.DataFrame: """Add a canonical 'net_label' column like the plotting script (LeNet/Efficient/fallback).""" return df.with_columns( pl.when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet") ) .then(pl.lit("LeNet")) .when( pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient") ) .then(pl.lit("Efficient")) .otherwise(pl.col("network").cast(pl.Utf8)) .alias("net_label") ) def _filter_base(df: pl.DataFrame) -> pl.DataFrame: """Restrict to valid dims/models and needed columns (no eval/regime filtering here).""" return df.filter( (pl.col("latent_dim").is_in(LATENT_DIMS)) & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])) & (pl.col("eval").is_in(list(EVALS_BOTH))) ).select( "model", "net_label", "latent_dim", "fold", "ap", "eval", "semi_normals", "semi_anomalous", ) @dataclass(frozen=True) class Cell: mean: float | None std: float | None def _compute_cells(df: pl.DataFrame) -> dict[tuple[str, int, str, str, int, int], Cell]: """ Compute per-(eval, latent_dim, model, net_label, semi_normals, semi_anomalous) mean/std for AP across folds. """ if df.is_empty(): return {} # For baselines (isoforest/ocsvm) constrain to Efficient backbone df = df.filter( pl.when(pl.col("model").is_in(["isoforest", "ocsvm"])) .then(pl.col("net_label") == "Efficient") .otherwise(True) ) agg = ( df.group_by( [ "eval", "latent_dim", "model", "net_label", "semi_normals", "semi_anomalous", ] ) .agg(pl.col("ap").mean().alias("mean_ap"), pl.col("ap").std().alias("std_ap")) .to_dicts() ) out: dict[tuple[str, int, str, str, int, int], Cell] = {} for row in agg: key = ( str(row["eval"]), int(row["latent_dim"]), str(row["model"]), str(row["net_label"]), int(row["semi_normals"]), int(row["semi_anomalous"]), ) out[key] = Cell(mean=row.get("mean_ap"), std=row.get("std_ap")) return out def method_label(model: str, net_label: str) -> str: """Map (model, net_label) to the four method names used in headers/caption.""" if model == "deepsad" and net_label == "LeNet": return "DeepSAD (LeNet)" if model == "deepsad" and net_label == "Efficient": return "DeepSAD (Efficient)" if model == "isoforest": return "IsoForest" if model == "ocsvm": return "OC-SVM" # ignore anything else (e.g., other backbones) return "" def per_method_median_std_from_cells( cells: dict[tuple[str, int, str, str, int, int], Cell], ) -> dict[str, float]: """Compute the median std across all cells, per method.""" stds_by_method: dict[str, list[float]] = { "DeepSAD (LeNet)": [], "DeepSAD (Efficient)": [], "IsoForest": [], "OC-SVM": [], } for key, cell in cells.items(): (ev, dim, model, net, semi_n, semi_a) = key name = method_label(model, net) if name and (cell.std is not None) and (cell.std == cell.std): # not NaN stds_by_method[name].append(cell.std) return { name: float(np.median(vals)) if vals else float("nan") for name, vals in stds_by_method.items() } def per_method_max_std_from_cells( cells: dict[tuple[str, int, str, str, int, int], Cell], ) -> tuple[dict[str, float], dict[str, tuple]]: """ Scan the aggregated 'cells' and return: - max_std_by_method: dict {"DeepSAD (LeNet)": 0.037, ...} - argmax_key_by_method: which cell (eval, dim, model, net, semi_n, semi_a) produced that max Only considers the four methods shown in the table. """ max_std_by_method: dict[str, float] = { "DeepSAD (LeNet)": float("nan"), "DeepSAD (Efficient)": float("nan"), "IsoForest": float("nan"), "OC-SVM": float("nan"), } argmax_key_by_method: dict[str, tuple] = {} for key, cell in cells.items(): (ev, dim, model, net, semi_n, semi_a) = key name = method_label(model, net) if name == "" or cell.std is None or not (cell.std == cell.std): # empty/NaN continue cur = max_std_by_method.get(name, float("nan")) if (cur != cur) or (cell.std > cur): # handle NaN initial max_std_by_method[name] = cell.std argmax_key_by_method[name] = key # Replace remaining NaNs with 0.0 for nice formatting for k, v in list(max_std_by_method.items()): if not (v == v): # NaN max_std_by_method[k] = 0.0 return max_std_by_method, argmax_key_by_method def _fmt_val(val: float | None) -> str: """ Format value as: - '--' if None/NaN - '1.0' if exactly 1 (within 1e-9) - '.xx' otherwise (2 decimals, no leading 0) """ if val is None or not (val == val): # None or NaN return "--" if abs(val - 1.0) < 1e-9: return "1.0" return f"{val:.2f}".lstrip("0") def _fmt_mean(mean: float | None) -> str: return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}" def _bold_best_mask_display(values: list[float | None], decimals: int) -> list[bool]: """ Bolding mask based on *displayed* precision. Any entries that round (via f-string) to the maximum at 'decimals' places are bolded (ties bolded). """ def disp(v: float | None) -> float | None: if v is None or not (v == v): return None return float(f"{v:.{decimals}f}") rounded = [disp(v) for v in values] finite = [v for v in rounded if v is not None] if not finite: return [False] * len(values) maxv = max(finite) return [(v is not None and v == maxv) for v in rounded] def _build_exp_based_table( cells: dict[tuple[str, int, str, str, int, int], Cell], *, semi_labeling_regimes: list[tuple[int, int]], ) -> str: """ Build LaTeX table with mean ± std values for experiment-based evaluation only. """ header_cols = [ r"\rotheader{DeepSAD\\(LeNet)}", r"\rotheader{DeepSAD\\(Efficient)}", r"\rotheader{IsoForest}", r"\rotheader{OC-SVM}", ] lines: list[str] = [] lines.append(r"\begin{table}[t]") lines.append(r"\centering") lines.append(r"\setlength{\tabcolsep}{4pt}") lines.append(r"\renewcommand{\arraystretch}{1.2}") lines.append(r"\begin{tabularx}{\textwidth}{c*{4}{Y}}") lines.append(r"\toprule") lines.append(r"Latent Dim. & " + " & ".join(header_cols) + r" \\") lines.append(r"\midrule") for idx, (semi_n, semi_a) in enumerate(semi_labeling_regimes): # regime label row lines.append( rf"\multicolumn{{5}}{{l}}{{\textbf{{Labeling regime: }}\(\mathbf{{{semi_n}/{semi_a}}}\)}} \\" ) lines.append(r"\addlinespace[2pt]") for dim in LATENT_DIMS: row_vals = [] for model, net in METHOD_COLUMNS: key = ("exp_based", dim, model, net, semi_n, semi_a) cell = cells.get(key, Cell(None, None)) row_vals.append(_fmt_mean_std(cell.mean, cell.std)) lines.append(f"{dim} & " + " & ".join(row_vals) + r" \\") if idx < len(semi_labeling_regimes) - 1: lines.append(r"\midrule") lines.append(r"\bottomrule") lines.append(r"\end{tabularx}") lines.append( r"\caption{AP means $\pm$ std across 5 folds for experiment-based evaluation only, grouped by labeling regime.}" ) lines.append(r"\end{table}") return "\n".join(lines) def _build_single_table( cells: dict[tuple[str, int, str, str, int, int], Cell], *, semi_labeling_regimes: list[tuple[int, int]], ) -> tuple[str, float | None]: """ Build the LaTeX table string with grouped headers and regime blocks. Returns (latex, max_std_overall). """ # Rotated header labels (90° slanted) header_cols = [ r"\rotheader{DeepSAD\\(LeNet)}", r"\rotheader{DeepSAD\\(Efficient)}", r"\rotheader{IsoForest}", r"\rotheader{OC-SVM}", ] # Track max std across all cells max_std: float | None = None def push_std(std_val: float | None): nonlocal max_std if std_val is None or not (std_val == std_val): return if max_std is None or std_val > max_std: max_std = std_val lines: list[str] = [] # Table preamble / structure lines.append(r"\begin{table}[t]") lines.append(r"\centering") lines.append(r"\setlength{\tabcolsep}{4pt}") lines.append(r"\renewcommand{\arraystretch}{1.2}") # Vertical rule between the two groups for data/header rows: lines.append(r"\begin{tabularx}{\textwidth}{c*{4}{Y}|*{4}{Y}}") lines.append(r"\toprule") lines.append( r" & \multicolumn{4}{c}{Experiment-based eval.} & \multicolumn{4}{c}{Handlabeled eval.} \\" ) lines.append(r"\cmidrule(lr){2-5} \cmidrule(lr){6-9}") lines.append( r"Latent Dim. & " + " & ".join(header_cols) + " & " + " & ".join(header_cols) + r" \\" ) lines.append(r"\midrule") # Iterate regimes and rows for idx, (semi_n, semi_a) in enumerate(semi_labeling_regimes): # Regime label row (multicolumn suppresses the vertical bar in this row) lines.append( rf"\multicolumn{{9}}{{l}}{{\textbf{{Labeling regime: }}\(\mathbf{{{semi_n}/{semi_a}}}\) " rf"\textit{{(normal/anomalous samples labeled)}}}} \\" ) lines.append(r"\addlinespace[2pt]") for dim in LATENT_DIMS: # Values in order: left group (exp_based) 4 cols, right group (manual_based) 4 cols means_left: list[float | None] = [] means_right: list[float | None] = [] cell_strs_left: list[str] = [] cell_strs_right: list[str] = [] # Left group: exp_based eval_type = EVALS_BOTH[0] for model, net in METHOD_COLUMNS: key = (eval_type, dim, model, net, semi_n, semi_a) cell = cells.get(key, Cell(None, None)) means_left.append(cell.mean) cell_strs_left.append(_fmt_mean(cell.mean)) # mean_str = _fmt_val(cell.mean) # std_str = _fmt_val(cell.std) # if mean_str == "--": # cell_strs_left.append("--") # else: # cell_strs_left.append(f"{mean_str} $\\textpm$ {std_str}") push_std(cell.std) # Right group: manual_based eval_type = EVALS_BOTH[1] for model, net in METHOD_COLUMNS: key = (eval_type, dim, model, net, semi_n, semi_a) cell = cells.get(key, Cell(None, None)) means_right.append(cell.mean) cell_strs_right.append(_fmt_mean(cell.mean)) # mean_str = _fmt_val(cell.mean) # std_str = _fmt_val(cell.std) # if mean_str == "--": # cell_strs_right.append("--") # else: # cell_strs_right.append(f"{mean_str} $\\textpm$ {std_str}") push_std(cell.std) # Bolding per group based on displayed precision mask_left = _bold_best_mask_display(means_left, DECIMALS) mask_right = _bold_best_mask_display(means_right, DECIMALS) pretty_left = [ (r"\textbf{" + s + "}") if (do_bold and s != "--") else s for s, do_bold in zip(cell_strs_left, mask_left) ] pretty_right = [ (r"\textbf{" + s + "}") if (do_bold and s != "--") else s for s, do_bold in zip(cell_strs_right, mask_right) ] # Join with the vertical bar between groups automatically handled by column spec lines.append( f"{dim} & " + " & ".join(pretty_left) + " & " + " & ".join(pretty_right) + r" \\" ) # Separator between regime blocks (but not after the last one) if idx < len(semi_labeling_regimes) - 1: lines.append(r"\midrule") lines.append(r"\bottomrule") lines.append(r"\end{tabularx}") # Compute per-method max std across everything included in the table # max_std_by_method, argmax_key = per_method_max_std_from_cells(cells) median_std_by_method = per_method_median_std_from_cells(cells) # Optional: print where each max came from (helps verify) for name, v in median_std_by_method.items(): print(f"[max-std] {name}: {v:.3f}") cap_parts = [] for name in ["DeepSAD (LeNet)", "DeepSAD (Efficient)", "IsoForest", "OC-SVM"]: v = median_std_by_method.get(name, 0.0) cap_parts.append(f"{name} {v:.3f}") cap_str = "; ".join(cap_parts) lines.append( rf"\caption{{AP means across 5 folds for both evaluations, grouped by labeling regime. " rf"Maximum observed standard deviation per method (not shown in table): {cap_str}.}}" ) lines.append(r"\end{table}") return "\n".join(lines), max_std def main(): # Load full results DF (cache behavior handled by your loader) df = load_results_dataframe(ROOT, allow_cache=True) df = _with_net_label(df) df = _filter_base(df) # Prepare output dirs OUTPUT_DIR.mkdir(parents=True, exist_ok=True) archive_dir = OUTPUT_DIR / "archive" archive_dir.mkdir(parents=True, exist_ok=True) ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir.mkdir(parents=True, exist_ok=True) # Pre-compute aggregated cells (mean/std) for all evals/regimes cells = _compute_cells(df) # Build the single big table tex, max_std = _build_single_table( cells, semi_labeling_regimes=SEMI_LABELING_REGIMES ) out_name = "ap_table_all_evals_all_regimes.tex" out_path = ts_dir / out_name out_path.write_text(tex, encoding="utf-8") # Build experiment-based table with mean ± std tex_exp = _build_exp_based_table(cells, semi_labeling_regimes=SEMI_LABELING_REGIMES) out_name_exp = "ap_table_exp_based_mean_std.tex" out_path_exp = ts_dir / out_name_exp out_path_exp.write_text(tex_exp, encoding="utf-8") # Copy this script to preserve the code used for the outputs script_path = Path(__file__) shutil.copy2(script_path, ts_dir / script_path.name) # Mirror latest latest = OUTPUT_DIR / "latest" latest.mkdir(exist_ok=True, parents=True) for f in latest.iterdir(): if f.is_file(): f.unlink() for f in ts_dir.iterdir(): if f.is_file(): shutil.copy2(f, latest / f.name) print(f"Saved table to: {ts_dir}") print(f"Also updated: {latest}") print(f" - {out_name}") if __name__ == "__main__": main()