wip
This commit is contained in:
@@ -3,10 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from polars.testing import assert_frame_equal
|
||||
|
||||
from diff_df import recursive_diff_frames
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Config you can tweak
|
||||
@@ -247,6 +250,14 @@ def read_pickle(p: Path) -> Any:
|
||||
# ------------------------------------------------------------
|
||||
# Extractors for each model
|
||||
# ------------------------------------------------------------
|
||||
|
||||
counting = {
|
||||
(label_method, eval_method): []
|
||||
for label_method in ["exp_based", "manual_based"]
|
||||
for eval_method in ["roc", "prc"]
|
||||
}
|
||||
|
||||
|
||||
def rows_from_deepsad(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||
"""
|
||||
deepsad under data['test'][eval], with extra per-eval arrays and AP present.
|
||||
@@ -257,6 +268,8 @@ def rows_from_deepsad(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||
evd = test.get(ev)
|
||||
if not isinstance(evd, dict):
|
||||
continue
|
||||
counting[(ev, "roc")].append(len(evd["roc"][0]))
|
||||
counting[(ev, "prc")].append(len(evd["prc"][0]))
|
||||
out[ev] = {
|
||||
"auc": float(evd["auc"])
|
||||
if "auc" in evd and evd["auc"] is not None
|
||||
@@ -585,12 +598,53 @@ def load_pretraining_results_dataframe(
|
||||
|
||||
|
||||
def main():
|
||||
root = Path("/home/fedex/mt/results/done")
|
||||
df = load_results_dataframe(root, allow_cache=True)
|
||||
print(df.shape, df.head())
|
||||
root = Path("/home/fedex/mt/results/copy")
|
||||
df1 = load_results_dataframe(root, allow_cache=True)
|
||||
exit(0)
|
||||
|
||||
df_pre = load_pretraining_results_dataframe(root, allow_cache=True)
|
||||
print("pretraining:", df_pre.shape, df_pre.head())
|
||||
retest_root = Path("/home/fedex/mt/results/copy/retest_nodrop")
|
||||
df2 = load_results_dataframe(retest_root, allow_cache=False).drop("folder")
|
||||
|
||||
# exact schema & shape first (optional but helpful messages)
|
||||
assert df1.shape == df2.shape, f"Shape differs: {df1.shape} vs {df2.shape}"
|
||||
assert set(df1.columns) == set(df2.columns), (
|
||||
f"Column sets differ: {df1.columns} vs {df2.columns}"
|
||||
)
|
||||
|
||||
# allow small float diffs, ignore column order differences if you want
|
||||
df1_sorted = df1.select(sorted(df1.columns))
|
||||
df2_sorted = df2.select(sorted(df2.columns))
|
||||
|
||||
# Optionally pre-align/sort both frames by a stable key before diffing.
|
||||
summary, leaves = recursive_diff_frames(
|
||||
df1,
|
||||
df2,
|
||||
ignore=["timestamp"], # columns to ignore
|
||||
float_atol=0.1, # absolute tolerance for floats
|
||||
float_rtol=0.0, # relative tolerance for floats
|
||||
max_rows_per_column=20, # limit expansion per column
|
||||
max_leafs_per_row=200, # cap leaves per row
|
||||
)
|
||||
|
||||
pl.Config.set_fmt_table_cell_list_len(100)
|
||||
pl.Config.set_tbl_rows(100)
|
||||
|
||||
print(summary) # which columns differ & how many rows
|
||||
print(leaves) # exact nested paths + scalar diffs
|
||||
|
||||
# check_exact=False lets us use atol/rtol for floats
|
||||
assert_frame_equal(
|
||||
df1_sorted,
|
||||
df2_sorted,
|
||||
check_exact=False,
|
||||
atol=0.1, # absolute tolerance for floats
|
||||
rtol=0.0, # relative tolerance (set if you want % based)
|
||||
check_dtypes=True, # set False if you only care about values
|
||||
)
|
||||
print("DataFrames match within tolerance ✅")
|
||||
|
||||
# df_pre = load_pretraining_results_dataframe(root, allow_cache=True)
|
||||
# print("pretraining:", df_pre.shape, df_pre.head())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user