72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import polars as pl
|
||
|
|
|
||
|
|
from load_results import load_pretraining_results_dataframe, load_results_dataframe
|
||
|
|
|
||
|
|
|
||
|
|
# ------------------------------------------------------------
# Example “analysis-ready” queries (Polars idioms)
# ------------------------------------------------------------
|
||
|
|
def demo_queries(df: pl.DataFrame):
    """Run three example queries against a results table.

    Args:
        df: Results frame. The queries read the columns ``network``,
            ``latent_dim``, ``semi_normals``, ``semi_anomalous``, ``eval``,
            ``model``, ``fold``, ``auc`` and ``roc_curve`` (a struct column
            with ``fpr``, ``tpr`` and ``thr`` fields).

    Returns:
        A tuple ``(q1, q2, roc_subset)``:

        * ``q1`` -- mean ``auc`` per ``model`` for the LeNet / 1024-dim /
          ``exp_based`` rows with no semi-supervised labels, sorted by
          ``mean_auc`` in descending order.
        * ``q2`` -- ``auc`` of the ``deepsad`` LeNet ``exp_based`` rows
          pivoted to one row per ``fold`` and one column per
          ``latent_dim``, sorted by ``fold``.
        * ``roc_subset`` -- ``fold`` and ``roc_curve`` for the ``ocsvm`` /
          ``efficient`` / 1024-dim ``manual_based`` rows, with the struct
          fields additionally exposed as ``fpr``, ``tpr`` and ``thr``.
    """
    # Predicates shared by the three queries below.
    is_lenet = pl.col("network") == "LeNet"
    is_dim_1024 = pl.col("latent_dim") == 1024
    no_semi_normals = pl.col("semi_normals") == 0
    no_semi_anomalous = pl.col("semi_anomalous") == 0
    is_exp_based = pl.col("eval") == "exp_based"

    # q1: built lazily and materialised once at the end.
    q1_rows = df.lazy().filter(
        is_lenet
        & is_dim_1024
        & no_semi_normals
        & no_semi_anomalous
        & is_exp_based
    )
    mean_auc = pl.col("auc").mean().alias("mean_auc")
    q1 = (
        q1_rows.group_by("model")
        .agg(mean_auc)
        .sort("mean_auc", descending=True)
        .collect()
    )

    # q2: filtered eagerly, because pivot is applied to a DataFrame here.
    deepsad_rows = df.filter(
        (pl.col("model") == "deepsad")
        & is_exp_based
        & is_lenet
        & no_semi_normals
        & no_semi_anomalous
    )
    auc_by_fold = deepsad_rows.select("fold", "latent_dim", "auc")
    wide = auc_by_fold.pivot(
        values="auc",
        index="fold",
        columns="latent_dim",
        aggregate_function="first",  # or "mean" if duplicates exist
    )
    q2 = wide.sort("fold")

    # roc_subset: eager filter/select, then lift each struct field into its
    # own column (the original roc_curve column is kept).
    ocsvm_rows = df.filter(
        (pl.col("model") == "ocsvm")
        & (pl.col("eval") == "manual_based")
        & (pl.col("network") == "efficient")
        & is_dim_1024
        & no_semi_normals
        & no_semi_anomalous
    )
    roc_fields = [
        pl.col("roc_curve").struct.field(name).alias(name)
        for name in ("fpr", "tpr", "thr")
    ]
    roc_subset = ocsvm_rows.select("fold", "roc_curve").with_columns(roc_fields)

    return q1, q2, roc_subset
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Load the results table from the fixed results directory and run the demo.

    The loader is called with ``allow_cache=True``. The frames returned by
    ``demo_queries`` are not used further here.
    """
    results_dir = Path("/home/fedex/mt/results/done")
    results = load_results_dataframe(results_dir, allow_cache=True)
    demo_queries(results)
|
||
|
|
|
||
|
|
|
||
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|