wip
This commit is contained in:
@@ -26,7 +26,8 @@ SCHEMA_STATIC = {
|
||||
"eval": pl.Utf8, # "exp_based" | "manual_based"
|
||||
"fold": pl.Int32,
|
||||
# metrics
|
||||
"auc": pl.Float64,
|
||||
"roc_auc": pl.Float64, # <-- renamed from 'auc'
|
||||
"prc_auc": pl.Float64, # <-- new
|
||||
"ap": pl.Float64,
|
||||
# per-sample scores: list of (idx, label, score)
|
||||
"scores": pl.List(
|
||||
@@ -114,6 +115,43 @@ SCHEMA_INFERENCE = {
|
||||
# ------------------------------------------------------------
|
||||
# Helpers: curve/scores normalizers (tuples/ndarrays -> dict/list)
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_prc_auc_from_curve(prc_curve: dict | None) -> float | None:
|
||||
"""
|
||||
Compute AUC of the Precision-Recall curve via trapezoidal rule.
|
||||
Expects prc_curve = {"precision": [...], "recall": [...], "thr": [...] (optional)}.
|
||||
Robust to NaNs, unsorted recall, and missing endpoints; returns np.nan if empty.
|
||||
"""
|
||||
if not prc_curve:
|
||||
return np.nan
|
||||
precision = np.asarray(prc_curve.get("precision", []), dtype=float)
|
||||
recall = np.asarray(prc_curve.get("recall", []), dtype=float)
|
||||
if precision.size == 0 or recall.size == 0:
|
||||
return np.nan
|
||||
|
||||
mask = ~(np.isnan(precision) | np.isnan(recall))
|
||||
precision, recall = precision[mask], recall[mask]
|
||||
if recall.size == 0:
|
||||
return np.nan
|
||||
|
||||
# Sort by recall, clip to [0,1]
|
||||
order = np.argsort(recall)
|
||||
recall = np.clip(recall[order], 0.0, 1.0)
|
||||
precision = np.clip(precision[order], 0.0, 1.0)
|
||||
|
||||
# Ensure curve spans [0,1] in recall (hold precision constant at ends)
|
||||
if recall[0] > 0.0:
|
||||
recall = np.insert(recall, 0, 0.0)
|
||||
precision = np.insert(precision, 0, precision[0])
|
||||
if recall[-1] < 1.0:
|
||||
recall = np.append(recall, 1.0)
|
||||
precision = np.append(precision, precision[-1])
|
||||
|
||||
# Trapezoidal AUC
|
||||
return float(np.trapezoid(precision, recall))
|
||||
|
||||
|
||||
def _tolist(x):
|
||||
if x is None:
|
||||
return None
|
||||
@@ -357,23 +395,28 @@ def rows_from_ocsvm_default(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||
# Build the Polars DataFrame
|
||||
# ------------------------------------------------------------
|
||||
def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame:
|
||||
"""
|
||||
Walks experiment subdirs under `root`. For each (model, fold) it adds rows:
|
||||
Columns (SCHEMA_STATIC):
|
||||
network, latent_dim, semi_normals, semi_anomalous,
|
||||
model, eval, fold,
|
||||
roc_auc, prc_auc, ap, scores{sample_idx,orig_label,score},
|
||||
roc_curve{fpr,tpr,thr}, prc_curve{precision,recall,thr},
|
||||
sample_indices, sample_labels, valid_mask,
|
||||
train_time, test_time,
|
||||
folder, k_fold_num
|
||||
"""
|
||||
if allow_cache:
|
||||
cache = root / "results_cache.parquet"
|
||||
if cache.exists():
|
||||
try:
|
||||
df = pl.read_parquet(cache)
|
||||
print(f"[info] loaded cached results frame from {cache}")
|
||||
# Backward-compat: old caches may have 'auc' but no 'roc_auc'/'prc_auc'
|
||||
if "roc_auc" not in df.columns and "auc" in df.columns:
|
||||
df = df.rename({"auc": "roc_auc"})
|
||||
if "prc_auc" not in df.columns and "prc_curve" in df.columns:
|
||||
df = df.with_columns(
|
||||
pl.struct(
|
||||
pl.col("prc_curve").struct.field("precision"),
|
||||
pl.col("prc_curve").struct.field("recall"),
|
||||
)
|
||||
.map_elements(
|
||||
lambda s: compute_prc_auc_from_curve(
|
||||
{"precision": s[0], "recall": s[1]}
|
||||
)
|
||||
)
|
||||
.alias("prc_auc")
|
||||
)
|
||||
return df
|
||||
except Exception as e:
|
||||
print(f"[warn] failed to load cache {cache}: {e}")
|
||||
@@ -408,15 +451,17 @@ def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame
|
||||
continue
|
||||
|
||||
if model == "deepsad":
|
||||
per_eval = rows_from_deepsad(data, EVALS) # eval -> dict
|
||||
per_eval = rows_from_deepsad(data, EVALS)
|
||||
elif model == "isoforest":
|
||||
per_eval = rows_from_isoforest(data, EVALS) # eval -> dict
|
||||
per_eval = rows_from_isoforest(data, EVALS)
|
||||
elif model == "ocsvm":
|
||||
per_eval = rows_from_ocsvm_default(data, EVALS) # eval -> dict
|
||||
per_eval = rows_from_ocsvm_default(data, EVALS)
|
||||
else:
|
||||
per_eval = {}
|
||||
|
||||
for ev, vals in per_eval.items():
|
||||
# compute prc_auc now (fast), rename auc->roc_auc
|
||||
prc_auc_val = compute_prc_auc_from_curve(vals.get("prc"))
|
||||
rows.append(
|
||||
{
|
||||
"network": network,
|
||||
@@ -426,7 +471,8 @@ def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame
|
||||
"model": model,
|
||||
"eval": ev,
|
||||
"fold": fold,
|
||||
"auc": vals["auc"],
|
||||
"roc_auc": vals["auc"], # renamed
|
||||
"prc_auc": prc_auc_val, # new
|
||||
"ap": vals["ap"],
|
||||
"scores": vals["scores"],
|
||||
"roc_curve": vals["roc"],
|
||||
@@ -442,20 +488,19 @@ def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame
|
||||
}
|
||||
)
|
||||
|
||||
# If empty, return a typed empty frame
|
||||
if not rows:
|
||||
# Return a typed empty frame (new schema)
|
||||
return pl.DataFrame(schema=SCHEMA_STATIC)
|
||||
|
||||
df = pl.DataFrame(rows, schema=SCHEMA_STATIC)
|
||||
|
||||
# Cast to efficient dtypes (categoricals etc.) – no extra sanitation
|
||||
# Cast to efficient dtypes (categoricals etc.)
|
||||
df = df.with_columns(
|
||||
pl.col("network", "model", "eval").cast(pl.Categorical),
|
||||
pl.col(
|
||||
"latent_dim", "semi_normals", "semi_anomalous", "fold", "k_fold_num"
|
||||
).cast(pl.Int32),
|
||||
pl.col("auc", "ap", "train_time", "test_time").cast(pl.Float64),
|
||||
# NOTE: no cast on 'scores' here; it's already List(Struct) per schema.
|
||||
pl.col("roc_auc", "prc_auc", "ap", "train_time", "test_time").cast(pl.Float64),
|
||||
)
|
||||
|
||||
if allow_cache:
|
||||
|
||||
Reference in New Issue
Block a user