diff --git a/tools/plot_scripts/results_latent_space_tables.py b/tools/plot_scripts/results_latent_space_tables.py index 4fff010..764b5bc 100644 --- a/tools/plot_scripts/results_latent_space_tables.py +++ b/tools/plot_scripts/results_latent_space_tables.py @@ -16,11 +16,11 @@ from load_results import load_results_dataframe ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables") -# Semi-labeling regimes (semi_normals, semi_anomalous) +# Semi-labeling regimes (semi_normals, semi_anomalous) in display order SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)] -# Which evaluation columns to include (one table per eval × semi-regime) -EVALS: list[str] = ["exp_based", "manual_based"] +# Both evals are shown side-by-side in one table +EVALS_BOTH: tuple[str, str] = ("exp_based", "manual_based") # Row order (latent dims) LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] @@ -30,13 +30,12 @@ LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] METHOD_COLUMNS = [ ("deepsad", "LeNet"), # DeepSAD (LeNet) ("deepsad", "Efficient"), # DeepSAD (Efficient) - ("isoforest", "Efficient"), # IsolationForest (Efficient backbone baseline) - ("ocsvm", "Efficient"), # OC-SVM (Efficient backbone baseline) + ("isoforest", "Efficient"), # IsolationForest (Efficient baseline) + ("ocsvm", "Efficient"), # OC-SVM (Efficient baseline) ] # Formatting -DECIMALS = 3 # number of decimals for mean/std -STD_FMT = r"\textpm" # between mean and std in LaTeX +DECIMALS = 3 # cells look like 1.000 or 0.928 (3 decimals) # ---------------------------- @@ -58,140 +57,229 @@ def _with_net_label(df: pl.DataFrame) -> pl.DataFrame: ) -def _filter_base( - df: pl.DataFrame, - *, - eval_type: str, - semi_normals: int, - semi_anomalous: int, -) -> pl.DataFrame: - """Common filtering by regime/eval/valid dims&models.""" +def _filter_base(df: pl.DataFrame) -> pl.DataFrame: + """Restrict to valid dims/models and needed columns (no eval/regime filtering here).""" return df.filter( - (pl.col("semi_normals") == semi_normals) - & (pl.col("semi_anomalous") == semi_anomalous) - & (pl.col("eval") == eval_type) - & (pl.col("latent_dim").is_in(LATENT_DIMS)) + (pl.col("latent_dim").is_in(LATENT_DIMS)) & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])) + & (pl.col("eval").is_in(list(EVALS_BOTH))) ).select( "model", "net_label", "latent_dim", "fold", "auc", + "eval", + "semi_normals", + "semi_anomalous", ) -def _format_mean_std(mean: float | None, std: float | None) -> str: - if mean is None or (mean != mean): # NaN check - return "--" - if std is None or (std != std): - return f"{mean:.{DECIMALS}f}" - return f"{mean:.{DECIMALS}f} {STD_FMT} {std:.{DECIMALS}f}" - - @dataclass(frozen=True) class Cell: mean: float | None std: float | None -def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]: +def _compute_cells(df: pl.DataFrame) -> dict[tuple[str, int, str, str, int, int], Cell]: """ - Compute per-(latent_dim, model, net_label) mean/std for AUC across folds. - Returns a dict keyed by (latent_dim, model, net_label). + Compute per-(eval, latent_dim, model, net_label, semi_normals, semi_anomalous) + mean/std for AUC across folds. """ if df.is_empty(): return {} + + # For baselines (isoforest/ocsvm) constrain to Efficient backbone + df = df.filter( + pl.when(pl.col("model").is_in(["isoforest", "ocsvm"])) + .then(pl.col("net_label") == "Efficient") + .otherwise(True) + ) + agg = ( - df.group_by(["latent_dim", "model", "net_label"]) + df.group_by( + [ + "eval", + "latent_dim", + "model", + "net_label", + "semi_normals", + "semi_anomalous", + ] + ) .agg( pl.col("auc").mean().alias("mean_auc"), pl.col("auc").std().alias("std_auc") ) .to_dicts() ) - out: dict[tuple[int, str, str], Cell] = {} + + out: dict[tuple[str, int, str, str, int, int], Cell] = {} for row in agg: - key = (int(row["latent_dim"]), str(row["model"]), str(row["net_label"])) + key = ( + str(row["eval"]), + int(row["latent_dim"]), + str(row["model"]), + str(row["net_label"]), + int(row["semi_normals"]), + int(row["semi_anomalous"]), + ) out[key] = Cell(mean=row.get("mean_auc"), std=row.get("std_auc")) return out -def _bold_best_in_row(values: list[float | None]) -> list[bool]: - """Return a mask of which entries are the (tied) maximum among non-None values.""" - clean = [(v if (v is not None and v == v) else None) for v in values] - finite_vals = [v for v in clean if v is not None] - if not finite_vals: +def _fmt_mean(mean: float | None) -> str: + return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}" + + +def _bold_best_mask_display(values: list[float | None], decimals: int) -> list[bool]: + """ + Bolding mask based on *displayed* precision. Any entries that round (via f-string) + to the maximum at 'decimals' places are bolded (ties bolded). + """ + + def disp(v: float | None) -> float | None: + if v is None or not (v == v): + return None + return float(f"{v:.{decimals}f}") + + rounded = [disp(v) for v in values] + finite = [v for v in rounded if v is not None] + if not finite: return [False] * len(values) - maxv = max(finite_vals) - return [(v is not None and abs(v - maxv) < 1e-12) for v in clean] + maxv = max(finite) + return [(v is not None and v == maxv) for v in rounded] -def _latex_table( - cells: dict[tuple[int, str, str], Cell], +def _build_single_table( + cells: dict[tuple[str, int, str, str, int, int], Cell], *, - eval_type: str, - semi_normals: int, - semi_anomalous: int, -) -> str: + semi_labeling_regimes: list[tuple[int, int]], +) -> tuple[str, float | None]: """ - Build a LaTeX table with rows=latent dims and columns=METHOD_COLUMNS. - Bold best AUC (mean) per row. + Build the LaTeX table string with grouped headers and regime blocks. + Returns (latex, max_std_overall). """ + + # Rotated header labels (90° slanted) header_cols = [ - r"\textbf{DeepSAD (LeNet)}", - r"\textbf{DeepSAD (Efficient)}", - r"\textbf{IsolationForest}", - r"\textbf{OC\text{-}SVM}", + r"\rotheader{DeepSAD\\(LeNet)}", + r"\rotheader{DeepSAD\\(Efficient)}", + r"\rotheader{IsoForest}", + r"\rotheader{OC-SVM}", ] - eval_type_str = ( - "experiment-based evaluation" - if eval_type == "exp_based" - else "handlabeling-based evaluation" - ) + # Track max std across all cells + max_std: float | None = None + + def push_std(std_val: float | None): + nonlocal max_std + if std_val is None or not (std_val == std_val): + return + if max_std is None or std_val > max_std: + max_std = std_val lines: list[str] = [] + + # Table preamble / structure lines.append(r"\begin{table}[t]") lines.append(r"\centering") - lines.append( - rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, " - rf"semi-labeling regime: {semi_normals} normal samples {semi_anomalous} anomalous samples.}}" - ) - lines.append(rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}") - lines.append(r"\begin{tabularx}{\textwidth}{cYYYY}") + lines.append(r"\setlength{\tabcolsep}{4pt}") + lines.append(r"\renewcommand{\arraystretch}{1.2}") + # Vertical rule between the two groups for data/header rows: + lines.append(r"\begin{tabularx}{\textwidth}{c*{4}{Y}|*{4}{Y}}") lines.append(r"\toprule") - lines.append(r"\textbf{Latent Dim.} & " + " & ".join(header_cols) + r" \\") + lines.append( + r" & \multicolumn{4}{c}{Experiment-based eval.} & \multicolumn{4}{c}{Handlabeled eval.} \\" + ) + lines.append(r"\cmidrule(lr){2-5} \cmidrule(lr){6-9}") + lines.append( + r"Latent Dim. & " + + " & ".join(header_cols) + + " & " + + " & ".join(header_cols) + + r" \\" + ) lines.append(r"\midrule") - for dim in LATENT_DIMS: - # Collect means for bolding - means_for_bold: list[float | None] = [] - cell_strs: list[str] = [] - for model, net in METHOD_COLUMNS: - cell = cells.get((dim, model, net), Cell(None, None)) - means_for_bold.append(cell.mean) - cell_strs.append(_format_mean_std(cell.mean, cell.std)) + # Iterate regimes and rows + for idx, (semi_n, semi_a) in enumerate(semi_labeling_regimes): + # Regime label row (multicolumn suppresses the vertical bar in this row) + lines.append( + rf"\multicolumn{{9}}{{l}}{{\textbf{{Labeling regime: }}\(\mathbf{{{semi_n}/{semi_a}}}\) " + rf"\textit{{(normal/anomalous samples labeled)}}}} \\" + ) + lines.append(r"\addlinespace[2pt]") - bold_mask = _bold_best_in_row(means_for_bold) - pretty_cells: list[str] = [] - for s, do_bold in zip(cell_strs, bold_mask): - if do_bold and s != "--": - pretty_cells.append(r"\textbf{" + s + r"}") - else: - pretty_cells.append(s) + for dim in LATENT_DIMS: + # Values in order: left group (exp_based) 4 cols, right group (manual_based) 4 cols + means_left: list[float | None] = [] + means_right: list[float | None] = [] + cell_strs_left: list[str] = [] + cell_strs_right: list[str] = [] - lines.append(f"{dim} & " + " & ".join(pretty_cells) + r" \\") + # Left group: exp_based + eval_type = EVALS_BOTH[0] + for model, net in METHOD_COLUMNS: + key = (eval_type, dim, model, net, semi_n, semi_a) + cell = cells.get(key, Cell(None, None)) + means_left.append(cell.mean) + cell_strs_left.append(_fmt_mean(cell.mean)) + push_std(cell.std) + + # Right group: manual_based + eval_type = EVALS_BOTH[1] + for model, net in METHOD_COLUMNS: + key = (eval_type, dim, model, net, semi_n, semi_a) + cell = cells.get(key, Cell(None, None)) + means_right.append(cell.mean) + cell_strs_right.append(_fmt_mean(cell.mean)) + push_std(cell.std) + + # Bolding per group based on displayed precision + mask_left = _bold_best_mask_display(means_left, DECIMALS) + mask_right = _bold_best_mask_display(means_right, DECIMALS) + + pretty_left = [ + (r"\textbf{" + s + "}") if (do_bold and s != "--") else s + for s, do_bold in zip(cell_strs_left, mask_left) + ] + pretty_right = [ + (r"\textbf{" + s + "}") if (do_bold and s != "--") else s + for s, do_bold in zip(cell_strs_right, mask_right) + ] + + # Join with the vertical bar between groups automatically handled by column spec + lines.append( + f"{dim} & " + + " & ".join(pretty_left) + + " & " + + " & ".join(pretty_right) + + r" \\" + ) + + # Separator between regime blocks (but not after the last one) + if idx < len(semi_labeling_regimes) - 1: + lines.append(r"\midrule") lines.append(r"\bottomrule") lines.append(r"\end{tabularx}") + + # Caption with max std (not shown in table) + max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}" + lines.append( + rf"\caption{{AUC means across 5 folds for both evaluations, grouped by labeling regime. " + rf"Maximum observed standard deviation across all cells (not shown in table): {max_std_str}.}}" + ) lines.append(r"\end{table}") - return "\n".join(lines) + + return "\n".join(lines), max_std def main(): # Load full results DF (cache behavior handled by your loader) df = load_results_dataframe(ROOT, allow_cache=True) df = _with_net_label(df) + df = _filter_base(df) # Prepare output dirs OUTPUT_DIR.mkdir(parents=True, exist_ok=True) @@ -200,36 +288,17 @@ def main(): ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir.mkdir(parents=True, exist_ok=True) - emitted_files: list[Path] = [] + # Pre-compute aggregated cells (mean/std) for all evals/regimes + cells = _compute_cells(df) - for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES: - for eval_type in EVALS: - sub = _filter_base( - df, - eval_type=eval_type, - semi_normals=semi_normals, - semi_anomalous=semi_anomalous, - ) + # Build the single big table + tex, max_std = _build_single_table( + cells, semi_labeling_regimes=SEMI_LABELING_REGIMES + ) - # For baselines (isoforest/ocsvm) we constrain to Efficient backbone to mirror plots - sub = sub.filter( - pl.when(pl.col("model").is_in(["isoforest", "ocsvm"])) - .then(pl.col("net_label") == "Efficient") - .otherwise(True) - ) - - cells = _compute_cells(sub) - tex = _latex_table( - cells, - eval_type=eval_type, - semi_normals=semi_normals, - semi_anomalous=semi_anomalous, - ) - - out_name = f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex" - out_path = ts_dir / out_name - out_path.write_text(tex, encoding="utf-8") - emitted_files.append(out_path) + out_name = "auc_table_all_evals_all_regimes.tex" + out_path = ts_dir / out_name + out_path.write_text(tex, encoding="utf-8") # Copy this script to preserve the code used for the outputs script_path = Path(__file__) @@ -245,10 +314,9 @@ def main(): if f.is_file(): shutil.copy2(f, latest / f.name) - print(f"Saved tables to: {ts_dir}") + print(f"Saved table to: {ts_dir}") print(f"Also updated: {latest}") - for p in emitted_files: - print(f" - {p.name}") + print(f" - {out_name}") if __name__ == "__main__":