correct auc table scrip

This commit is contained in:
Jan Kowalczyk
2025-09-17 11:43:26 +02:00
parent 95867bde7a
commit 936d2ecb6e

View File

@@ -16,11 +16,11 @@ from load_results import load_results_dataframe
ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader ROOT = Path("/home/fedex/mt/results/copy") # experiments root you pass to the loader
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables") OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables")
# Semi-labeling regimes (semi_normals, semi_anomalous) # Semi-labeling regimes (semi_normals, semi_anomalous) in display order
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)] SEMI_LABELING_REGIMES: list[tuple[int, int]] = [(0, 0), (50, 10), (500, 100)]
# Which evaluation columns to include (one table per eval × semi-regime) # Both evals are shown side-by-side in one table
EVALS: list[str] = ["exp_based", "manual_based"] EVALS_BOTH: tuple[str, str] = ("exp_based", "manual_based")
# Row order (latent dims) # Row order (latent dims)
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024] LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]
@@ -30,13 +30,12 @@ LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]
METHOD_COLUMNS = [ METHOD_COLUMNS = [
("deepsad", "LeNet"), # DeepSAD (LeNet) ("deepsad", "LeNet"), # DeepSAD (LeNet)
("deepsad", "Efficient"), # DeepSAD (Efficient) ("deepsad", "Efficient"), # DeepSAD (Efficient)
("isoforest", "Efficient"), # IsolationForest (Efficient backbone baseline) ("isoforest", "Efficient"), # IsolationForest (Efficient baseline)
("ocsvm", "Efficient"), # OC-SVM (Efficient backbone baseline) ("ocsvm", "Efficient"), # OC-SVM (Efficient baseline)
] ]
# Formatting # Formatting
DECIMALS = 3 # number of decimals for mean/std DECIMALS = 3 # cells look like 1.000 or 0.928 (3 decimals)
STD_FMT = r"\textpm" # between mean and std in LaTeX
# ---------------------------- # ----------------------------
@@ -58,140 +57,229 @@ def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
) )
def _filter_base( def _filter_base(df: pl.DataFrame) -> pl.DataFrame:
df: pl.DataFrame, """Restrict to valid dims/models and needed columns (no eval/regime filtering here)."""
*,
eval_type: str,
semi_normals: int,
semi_anomalous: int,
) -> pl.DataFrame:
"""Common filtering by regime/eval/valid dims&models."""
return df.filter( return df.filter(
(pl.col("semi_normals") == semi_normals) (pl.col("latent_dim").is_in(LATENT_DIMS))
& (pl.col("semi_anomalous") == semi_anomalous)
& (pl.col("eval") == eval_type)
& (pl.col("latent_dim").is_in(LATENT_DIMS))
& (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"])) & (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
& (pl.col("eval").is_in(list(EVALS_BOTH)))
).select( ).select(
"model", "model",
"net_label", "net_label",
"latent_dim", "latent_dim",
"fold", "fold",
"auc", "auc",
"eval",
"semi_normals",
"semi_anomalous",
) )
def _format_mean_std(mean: float | None, std: float | None) -> str:
if mean is None or (mean != mean): # NaN check
return "--"
if std is None or (std != std):
return f"{mean:.{DECIMALS}f}"
return f"{mean:.{DECIMALS}f} {STD_FMT} {std:.{DECIMALS}f}"
@dataclass(frozen=True) @dataclass(frozen=True)
class Cell: class Cell:
mean: float | None mean: float | None
std: float | None std: float | None
def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]: def _compute_cells(df: pl.DataFrame) -> dict[tuple[str, int, str, str, int, int], Cell]:
""" """
Compute per-(latent_dim, model, net_label) mean/std for AUC across folds. Compute per-(eval, latent_dim, model, net_label, semi_normals, semi_anomalous)
Returns a dict keyed by (latent_dim, model, net_label). mean/std for AUC across folds.
""" """
if df.is_empty(): if df.is_empty():
return {} return {}
# For baselines (isoforest/ocsvm) constrain to Efficient backbone
df = df.filter(
pl.when(pl.col("model").is_in(["isoforest", "ocsvm"]))
.then(pl.col("net_label") == "Efficient")
.otherwise(True)
)
agg = ( agg = (
df.group_by(["latent_dim", "model", "net_label"]) df.group_by(
[
"eval",
"latent_dim",
"model",
"net_label",
"semi_normals",
"semi_anomalous",
]
)
.agg( .agg(
pl.col("auc").mean().alias("mean_auc"), pl.col("auc").std().alias("std_auc") pl.col("auc").mean().alias("mean_auc"), pl.col("auc").std().alias("std_auc")
) )
.to_dicts() .to_dicts()
) )
out: dict[tuple[int, str, str], Cell] = {}
out: dict[tuple[str, int, str, str, int, int], Cell] = {}
for row in agg: for row in agg:
key = (int(row["latent_dim"]), str(row["model"]), str(row["net_label"])) key = (
str(row["eval"]),
int(row["latent_dim"]),
str(row["model"]),
str(row["net_label"]),
int(row["semi_normals"]),
int(row["semi_anomalous"]),
)
out[key] = Cell(mean=row.get("mean_auc"), std=row.get("std_auc")) out[key] = Cell(mean=row.get("mean_auc"), std=row.get("std_auc"))
return out return out
def _bold_best_in_row(values: list[float | None]) -> list[bool]: def _fmt_mean(mean: float | None) -> str:
"""Return a mask of which entries are the (tied) maximum among non-None values.""" return "--" if (mean is None or not (mean == mean)) else f"{mean:.{DECIMALS}f}"
clean = [(v if (v is not None and v == v) else None) for v in values]
finite_vals = [v for v in clean if v is not None]
if not finite_vals: def _bold_best_mask_display(values: list[float | None], decimals: int) -> list[bool]:
"""
Bolding mask based on *displayed* precision. Any entries that round (via f-string)
to the maximum at 'decimals' places are bolded (ties bolded).
"""
def disp(v: float | None) -> float | None:
if v is None or not (v == v):
return None
return float(f"{v:.{decimals}f}")
rounded = [disp(v) for v in values]
finite = [v for v in rounded if v is not None]
if not finite:
return [False] * len(values) return [False] * len(values)
maxv = max(finite_vals) maxv = max(finite)
return [(v is not None and abs(v - maxv) < 1e-12) for v in clean] return [(v is not None and v == maxv) for v in rounded]
def _latex_table( def _build_single_table(
cells: dict[tuple[int, str, str], Cell], cells: dict[tuple[str, int, str, str, int, int], Cell],
*, *,
eval_type: str, semi_labeling_regimes: list[tuple[int, int]],
semi_normals: int, ) -> tuple[str, float | None]:
semi_anomalous: int,
) -> str:
""" """
Build a LaTeX table with rows=latent dims and columns=METHOD_COLUMNS. Build the LaTeX table string with grouped headers and regime blocks.
Bold best AUC (mean) per row. Returns (latex, max_std_overall).
""" """
# Rotated header labels (90° slanted)
header_cols = [ header_cols = [
r"\textbf{DeepSAD (LeNet)}", r"\rotheader{DeepSAD\\(LeNet)}",
r"\textbf{DeepSAD (Efficient)}", r"\rotheader{DeepSAD\\(Efficient)}",
r"\textbf{IsolationForest}", r"\rotheader{IsoForest}",
r"\textbf{OC\text{-}SVM}", r"\rotheader{OC-SVM}",
] ]
eval_type_str = ( # Track max std across all cells
"experiment-based evaluation" max_std: float | None = None
if eval_type == "exp_based"
else "handlabeling-based evaluation" def push_std(std_val: float | None):
) nonlocal max_std
if std_val is None or not (std_val == std_val):
return
if max_std is None or std_val > max_std:
max_std = std_val
lines: list[str] = [] lines: list[str] = []
# Table preamble / structure
lines.append(r"\begin{table}[t]") lines.append(r"\begin{table}[t]")
lines.append(r"\centering") lines.append(r"\centering")
lines.append( lines.append(r"\setlength{\tabcolsep}{4pt}")
rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, " lines.append(r"\renewcommand{\arraystretch}{1.2}")
rf"semi-labeling regime: {semi_normals} normal samples {semi_anomalous} anomalous samples.}}" # Vertical rule between the two groups for data/header rows:
) lines.append(r"\begin{tabularx}{\textwidth}{c*{4}{Y}|*{4}{Y}}")
lines.append(rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}")
lines.append(r"\begin{tabularx}{\textwidth}{cYYYY}")
lines.append(r"\toprule") lines.append(r"\toprule")
lines.append(r"\textbf{Latent Dim.} & " + " & ".join(header_cols) + r" \\") lines.append(
r" & \multicolumn{4}{c}{Experiment-based eval.} & \multicolumn{4}{c}{Handlabeled eval.} \\"
)
lines.append(r"\cmidrule(lr){2-5} \cmidrule(lr){6-9}")
lines.append(
r"Latent Dim. & "
+ " & ".join(header_cols)
+ " & "
+ " & ".join(header_cols)
+ r" \\"
)
lines.append(r"\midrule") lines.append(r"\midrule")
for dim in LATENT_DIMS: # Iterate regimes and rows
# Collect means for bolding for idx, (semi_n, semi_a) in enumerate(semi_labeling_regimes):
means_for_bold: list[float | None] = [] # Regime label row (multicolumn suppresses the vertical bar in this row)
cell_strs: list[str] = [] lines.append(
for model, net in METHOD_COLUMNS: rf"\multicolumn{{9}}{{l}}{{\textbf{{Labeling regime: }}\(\mathbf{{{semi_n}/{semi_a}}}\) "
cell = cells.get((dim, model, net), Cell(None, None)) rf"\textit{{(normal/anomalous samples labeled)}}}} \\"
means_for_bold.append(cell.mean) )
cell_strs.append(_format_mean_std(cell.mean, cell.std)) lines.append(r"\addlinespace[2pt]")
bold_mask = _bold_best_in_row(means_for_bold) for dim in LATENT_DIMS:
pretty_cells: list[str] = [] # Values in order: left group (exp_based) 4 cols, right group (manual_based) 4 cols
for s, do_bold in zip(cell_strs, bold_mask): means_left: list[float | None] = []
if do_bold and s != "--": means_right: list[float | None] = []
pretty_cells.append(r"\textbf{" + s + r"}") cell_strs_left: list[str] = []
else: cell_strs_right: list[str] = []
pretty_cells.append(s)
lines.append(f"{dim} & " + " & ".join(pretty_cells) + r" \\") # Left group: exp_based
eval_type = EVALS_BOTH[0]
for model, net in METHOD_COLUMNS:
key = (eval_type, dim, model, net, semi_n, semi_a)
cell = cells.get(key, Cell(None, None))
means_left.append(cell.mean)
cell_strs_left.append(_fmt_mean(cell.mean))
push_std(cell.std)
# Right group: manual_based
eval_type = EVALS_BOTH[1]
for model, net in METHOD_COLUMNS:
key = (eval_type, dim, model, net, semi_n, semi_a)
cell = cells.get(key, Cell(None, None))
means_right.append(cell.mean)
cell_strs_right.append(_fmt_mean(cell.mean))
push_std(cell.std)
# Bolding per group based on displayed precision
mask_left = _bold_best_mask_display(means_left, DECIMALS)
mask_right = _bold_best_mask_display(means_right, DECIMALS)
pretty_left = [
(r"\textbf{" + s + "}") if (do_bold and s != "--") else s
for s, do_bold in zip(cell_strs_left, mask_left)
]
pretty_right = [
(r"\textbf{" + s + "}") if (do_bold and s != "--") else s
for s, do_bold in zip(cell_strs_right, mask_right)
]
# Join with the vertical bar between groups automatically handled by column spec
lines.append(
f"{dim} & "
+ " & ".join(pretty_left)
+ " & "
+ " & ".join(pretty_right)
+ r" \\"
)
# Separator between regime blocks (but not after the last one)
if idx < len(semi_labeling_regimes) - 1:
lines.append(r"\midrule")
lines.append(r"\bottomrule") lines.append(r"\bottomrule")
lines.append(r"\end{tabularx}") lines.append(r"\end{tabularx}")
# Caption with max std (not shown in table)
max_std_str = "n/a" if max_std is None else f"{max_std:.{DECIMALS}f}"
lines.append(
rf"\caption{{AUC means across 5 folds for both evaluations, grouped by labeling regime. "
rf"Maximum observed standard deviation across all cells (not shown in table): {max_std_str}.}}"
)
lines.append(r"\end{table}") lines.append(r"\end{table}")
return "\n".join(lines)
return "\n".join(lines), max_std
def main(): def main():
# Load full results DF (cache behavior handled by your loader) # Load full results DF (cache behavior handled by your loader)
df = load_results_dataframe(ROOT, allow_cache=True) df = load_results_dataframe(ROOT, allow_cache=True)
df = _with_net_label(df) df = _with_net_label(df)
df = _filter_base(df)
# Prepare output dirs # Prepare output dirs
OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
@@ -200,36 +288,17 @@ def main():
ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ts_dir.mkdir(parents=True, exist_ok=True) ts_dir.mkdir(parents=True, exist_ok=True)
emitted_files: list[Path] = [] # Pre-compute aggregated cells (mean/std) for all evals/regimes
cells = _compute_cells(df)
for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES: # Build the single big table
for eval_type in EVALS: tex, max_std = _build_single_table(
sub = _filter_base( cells, semi_labeling_regimes=SEMI_LABELING_REGIMES
df, )
eval_type=eval_type,
semi_normals=semi_normals,
semi_anomalous=semi_anomalous,
)
# For baselines (isoforest/ocsvm) we constrain to Efficient backbone to mirror plots out_name = "auc_table_all_evals_all_regimes.tex"
sub = sub.filter( out_path = ts_dir / out_name
pl.when(pl.col("model").is_in(["isoforest", "ocsvm"])) out_path.write_text(tex, encoding="utf-8")
.then(pl.col("net_label") == "Efficient")
.otherwise(True)
)
cells = _compute_cells(sub)
tex = _latex_table(
cells,
eval_type=eval_type,
semi_normals=semi_normals,
semi_anomalous=semi_anomalous,
)
out_name = f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex"
out_path = ts_dir / out_name
out_path.write_text(tex, encoding="utf-8")
emitted_files.append(out_path)
# Copy this script to preserve the code used for the outputs # Copy this script to preserve the code used for the outputs
script_path = Path(__file__) script_path = Path(__file__)
@@ -245,10 +314,9 @@ def main():
if f.is_file(): if f.is_file():
shutil.copy2(f, latest / f.name) shutil.copy2(f, latest / f.name)
print(f"Saved tables to: {ts_dir}") print(f"Saved table to: {ts_dir}")
print(f"Also updated: {latest}") print(f"Also updated: {latest}")
for p in emitted_files: print(f" - {out_name}")
print(f" - {p.name}")
if __name__ == "__main__": if __name__ == "__main__":