256 lines
8.1 KiB
Python
256 lines
8.1 KiB
Python
from __future__ import annotations
|
||
|
||
import shutil
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import polars as pl
|
||
|
||
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
||
from load_results import load_results_dataframe
|
||
|
||
# ----------------------------
# Config
# ----------------------------
# Experiments root handed to the results loader.
ROOT: Path = Path("/home/fedex/mt/results/copy")
# Base directory under which the generated LaTeX tables are stored.
OUTPUT_DIR: Path = Path("/home/fedex/mt/plots/results_latent_space_tables")

# Semi-labeling regimes as (semi_normals, semi_anomalous) pairs.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]

# Evaluation kinds to report; one table is produced per eval x semi-regime.
EVALS: list[str] = ["exp_based", "manual_based"]

# Latent dimensions, listed in table row order.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Table columns as (model, net_label) pairs, listed in display order.
# DeepSAD appears once per network backbone.
METHOD_COLUMNS: list[tuple[str, str]] = [
    ("deepsad", "LeNet"),  # DeepSAD (LeNet)
    ("deepsad", "Efficient"),  # DeepSAD (Efficient)
    ("isoforest", "Efficient"),  # IsolationForest (Efficient backbone baseline)
    ("ocsvm", "Efficient"),  # OC-SVM (Efficient backbone baseline)
]

# Formatting
DECIMALS: int = 3  # digits after the decimal point for mean and std
STD_FMT: str = r"\textpm"  # LaTeX token placed between mean and std
|
||
|
||
|
||
# ----------------------------
|
||
# Helpers
|
||
# ----------------------------
|
||
def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with an added canonical ``net_label`` column.

    The ``network`` column is cast to a string and matched
    case-insensitively: values containing "lenet" become "LeNet", values
    containing "efficient" become "Efficient", and anything else keeps its
    string-cast ``network`` value unchanged.
    """
    network_text = pl.col("network").cast(pl.Utf8)
    network_lower = network_text.str.to_lowercase()

    # "lenet" is tested first, so it wins if a name contains both substrings.
    label = (
        pl.when(network_lower.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(network_lower.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network_text)
    )
    return df.with_columns(label.alias("net_label"))
|
||
|
||
|
||
def _filter_base(
    df: pl.DataFrame,
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> pl.DataFrame:
    """Select the rows belonging to one (semi-regime, eval) table.

    Keeps rows whose ``semi_normals``, ``semi_anomalous`` and ``eval`` match
    the arguments, whose ``latent_dim`` is listed in ``LATENT_DIMS`` and
    whose ``model`` is one of deepsad / isoforest / ocsvm.  The result is
    reduced to the columns model, net_label, latent_dim, fold and auc.
    """
    regime_matches = (pl.col("semi_normals") == semi_normals) & (
        pl.col("semi_anomalous") == semi_anomalous
    )
    eval_matches = pl.col("eval") == eval_type
    dim_is_reported = pl.col("latent_dim").is_in(LATENT_DIMS)
    model_is_known = pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])

    kept_columns = ["model", "net_label", "latent_dim", "fold", "auc"]
    wanted = regime_matches & eval_matches & dim_is_reported & model_is_known
    return df.filter(wanted).select(kept_columns)
|
||
|
||
|
||
def _format_mean_std(mean: float | None, std: float | None) -> str:
    """Render one table cell as ``mean STD_FMT std`` with ``DECIMALS`` digits.

    Returns "--" when the mean is None or NaN, and the mean alone when only
    the std is None or NaN.
    """

    def _is_missing(value: float | None) -> bool:
        # NaN is the only float that compares unequal to itself.
        return value is None or value != value

    if _is_missing(mean):
        return "--"

    text = f"{mean:.{DECIMALS}f}"
    if not _is_missing(std):
        text = f"{text} {STD_FMT} {std:.{DECIMALS}f}"
    return text
|
||
|
||
|
||
@dataclass(frozen=True)
class Cell:
    """Immutable (mean, std) pair backing a single table cell."""

    # Mean value for the cell; None when nothing is available for it.
    mean: float | None
    # Standard deviation for the cell; None when it is unavailable.
    std: float | None
|
||
|
||
|
||
def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]:
    """Aggregate AUC into one ``Cell`` per (latent_dim, model, net_label).

    Rows are grouped by latent_dim, model and net_label; the mean and the
    standard deviation of ``auc`` are taken over all rows of each group.
    An empty frame yields an empty dict.

    Returns:
        Mapping from ``(latent_dim, model, net_label)`` to its ``Cell``.
    """
    if df.is_empty():
        return {}

    summary = df.group_by(["latent_dim", "model", "net_label"]).agg(
        pl.col("auc").mean().alias("mean_auc"),
        pl.col("auc").std().alias("std_auc"),
    )

    return {
        (int(rec["latent_dim"]), str(rec["model"]), str(rec["net_label"])): Cell(
            mean=rec.get("mean_auc"),
            std=rec.get("std_auc"),
        )
        for rec in summary.to_dicts()
    }
|
||
|
||
|
||
def _bold_best_in_row(values: list[float | None]) -> list[bool]:
|
||
"""Return a mask of which entries are the (tied) maximum among non-None values."""
|
||
clean = [(v if (v is not None and v == v) else None) for v in values]
|
||
finite_vals = [v for v in clean if v is not None]
|
||
if not finite_vals:
|
||
return [False] * len(values)
|
||
maxv = max(finite_vals)
|
||
return [(v is not None and abs(v - maxv) < 1e-12) for v in clean]
|
||
|
||
|
||
def _latex_table(
    cells: dict[tuple[int, str, str], Cell],
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> str:
    """Render one LaTeX ``table`` for a given eval type and semi-regime.

    There is one row per entry of ``LATENT_DIMS`` and one column per entry
    of ``METHOD_COLUMNS``.  A combination missing from ``cells`` is printed
    as "--".  Within each row the cell(s) with the highest mean are wrapped
    in ``\\textbf``.

    Args:
        cells: Mapping from ``(latent_dim, model, net_label)`` to its stats.
        eval_type: Used in the label; "exp_based" selects the
            experiment-based caption text, any other value the
            handlabeling-based one.
        semi_normals: Number of labeled normal samples, shown in caption/label.
        semi_anomalous: Number of labeled anomalous samples, shown likewise.

    Returns:
        The table source as a newline-joined string.
    """
    # NOTE(review): the column headers and the ``cYYYY`` spec are fixed at four
    # methods and are not derived from METHOD_COLUMNS; keep them in sync.
    if eval_type == "exp_based":
        eval_type_str = "experiment-based evaluation"
    else:
        eval_type_str = "handlabeling-based evaluation"

    caption = (
        rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, "
        rf"semi-labeling regime: {semi_normals} normal samples {semi_anomalous} anomalous samples.}}"
    )
    label = rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}"

    header_row = (
        " & ".join(
            [
                r"\textbf{Latent Dim.}",
                r"\textbf{DeepSAD (LeNet)}",
                r"\textbf{DeepSAD (Efficient)}",
                r"\textbf{IsolationForest}",
                r"\textbf{OC\text{-}SVM}",
            ]
        )
        + r" \\"
    )

    body_rows: list[str] = []
    for dim in LATENT_DIMS:
        row_cells = [
            cells.get((dim, model, net), Cell(None, None))
            for model, net in METHOD_COLUMNS
        ]
        # Decide bolding on the raw means, before any rounding for display.
        is_best = _bold_best_in_row([entry.mean for entry in row_cells])

        rendered: list[str] = []
        for entry, best in zip(row_cells, is_best):
            text = _format_mean_std(entry.mean, entry.std)
            if best and text != "--":
                text = r"\textbf{" + text + r"}"
            rendered.append(text)

        body_rows.append(f"{dim} & " + " & ".join(rendered) + r" \\")

    return "\n".join(
        [
            r"\begin{table}[t]",
            r"\centering",
            caption,
            label,
            r"\begin{tabularx}{\textwidth}{cYYYY}",
            r"\toprule",
            header_row,
            r"\midrule",
            *body_rows,
            r"\bottomrule",
            r"\end{tabularx}",
            r"\end{table}",
        ]
    )
|
||
|
||
|
||
def main():
    """Generate every AUC table and store it under ``OUTPUT_DIR``.

    Loads the results frame, writes one ``.tex`` file per
    (semi-regime, eval type) into a fresh timestamped folder below
    ``OUTPUT_DIR/archive``, copies this script next to them, and then
    replaces the files in ``OUTPUT_DIR/latest`` with that folder's files.
    """
    # Caching is left to the loader.
    results = _with_net_label(load_results_dataframe(ROOT, allow_cache=True))

    # Output layout: <OUTPUT_DIR>/archive/<timestamp>/
    archive_dir = OUTPUT_DIR / "archive"
    ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    for directory in (OUTPUT_DIR, archive_dir, ts_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Baselines (isoforest/ocsvm) are kept only for the Efficient backbone;
    # every other model passes through untouched.
    baseline_on_efficient = (
        pl.when(pl.col("model").is_in(["isoforest", "ocsvm"]))
        .then(pl.col("net_label") == "Efficient")
        .otherwise(True)
    )

    emitted_files: list[Path] = []
    for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES:
        for eval_type in EVALS:
            regime = {
                "eval_type": eval_type,
                "semi_normals": semi_normals,
                "semi_anomalous": semi_anomalous,
            }
            subset = _filter_base(results, **regime).filter(baseline_on_efficient)
            tex = _latex_table(_compute_cells(subset), **regime)

            out_path = (
                ts_dir
                / f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex"
            )
            out_path.write_text(tex, encoding="utf-8")
            emitted_files.append(out_path)

    # Keep a copy of the generating script alongside its outputs.
    script_path = Path(__file__)
    shutil.copy2(script_path, ts_dir / script_path.name)

    # Refresh "latest": drop its current files, then mirror the new run.
    latest = OUTPUT_DIR / "latest"
    latest.mkdir(exist_ok=True, parents=True)
    for stale in (entry for entry in latest.iterdir() if entry.is_file()):
        stale.unlink()
    for fresh in (entry for entry in ts_dir.iterdir() if entry.is_file()):
        shutil.copy2(fresh, latest / fresh.name)

    print(f"Saved tables to: {ts_dir}")
    print(f"Also updated: {latest}")
    for written in emitted_files:
        print(f" - {written.name}")


if __name__ == "__main__":
    main()
|