update
This commit is contained in:
544
tools/plot_scripts/diff_df.py
Normal file
544
tools/plot_scripts/diff_df.py
Normal file
@@ -0,0 +1,544 @@
|
||||
import json
|
||||
import math
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import polars as pl
|
||||
|
||||
Number = (int, float)
|
||||
|
||||
FLOAT_DTYPES = {pl.Float32, pl.Float64}
|
||||
SIMPLE_CASTABLE_DTYPES = (
|
||||
pl.Int8,
|
||||
pl.Int16,
|
||||
pl.Int32,
|
||||
pl.Int64,
|
||||
pl.UInt8,
|
||||
pl.UInt16,
|
||||
pl.UInt32,
|
||||
pl.UInt64,
|
||||
pl.Float32,
|
||||
pl.Float64,
|
||||
pl.Utf8,
|
||||
pl.Boolean,
|
||||
pl.Date,
|
||||
pl.Datetime,
|
||||
pl.Time,
|
||||
pl.Duration,
|
||||
)
|
||||
|
||||
|
||||
def _is_nan(x):
|
||||
try:
|
||||
return isinstance(x, float) and math.isnan(x)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _repr_safe(v):
|
||||
try:
|
||||
return json.dumps(v, default=str, ensure_ascii=False)
|
||||
except Exception:
|
||||
return repr(v)
|
||||
|
||||
|
||||
def _to_python(v):
|
||||
"""
|
||||
Convert any leaf-ish object to plain Python types:
|
||||
- pl.Series -> list (or scalar if length==1)
|
||||
- objects with .to_list()/.tolist() -> list
|
||||
- dict stays dict; list/tuple become list
|
||||
"""
|
||||
# Polars Series
|
||||
if isinstance(v, pl.Series):
|
||||
seq = v.to_list()
|
||||
return seq[0] if len(seq) == 1 else seq
|
||||
# Numpy scalars/arrays or anything with tolist()
|
||||
if hasattr(v, "tolist"):
|
||||
try:
|
||||
return v.tolist()
|
||||
except Exception:
|
||||
pass
|
||||
# Polars expressions should not appear; stringify them
|
||||
# Anything iterable that isn't list/dict/str -> convert carefully
|
||||
if isinstance(v, tuple):
|
||||
return [_to_python(x) for x in v]
|
||||
if isinstance(v, list):
|
||||
return [_to_python(x) for x in v]
|
||||
if isinstance(v, dict):
|
||||
return {k: _to_python(val) for k, val in v.items()}
|
||||
return v
|
||||
|
||||
|
||||
def _safe_equal(a, b):
|
||||
"""
|
||||
Return a plain bool saying whether a and b are equal,
|
||||
without ever producing a vector/Series.
|
||||
"""
|
||||
# exact same object
|
||||
if a is b:
|
||||
return True
|
||||
# normalize
|
||||
a_n = _to_python(a)
|
||||
b_n = _to_python(b)
|
||||
# handle NaNs
|
||||
if _is_nan(a_n) and _is_nan(b_n):
|
||||
return True
|
||||
# plain scalars/containers
|
||||
try:
|
||||
eq = a_n == b_n
|
||||
if isinstance(eq, bool):
|
||||
return eq
|
||||
except Exception:
|
||||
pass
|
||||
# fallback: compare stable JSON-ish reprs
|
||||
return _repr_safe(a_n) == _repr_safe(b_n)
|
||||
|
||||
|
||||
def _num_close(a: float, b: float, atol: float, rtol: float) -> bool:
|
||||
# NaN==NaN treated equal
|
||||
if _is_nan(a) and _is_nan(b):
|
||||
return True
|
||||
return abs(a - b) <= (atol + rtol * abs(b))
|
||||
|
||||
|
||||
def _to_python(v: Any) -> Any:
|
||||
"""
|
||||
Convert Polars value to a Python object. Struct -> dict, List -> list, scalars stay scalars.
|
||||
Values coming from Series[i] / .to_list() are already Python, so this usually no-ops.
|
||||
"""
|
||||
return v
|
||||
|
||||
|
||||
def _repr_safe(v: Any) -> str:
|
||||
try:
|
||||
return json.dumps(v, default=str, ensure_ascii=False)
|
||||
except Exception:
|
||||
return repr(v)
|
||||
|
||||
|
||||
def _iter_dict_keys(d: Dict[str, Any]) -> Iterable[str]:
|
||||
# stable order, useful for predictable output
|
||||
return sorted(d.keys())
|
||||
|
||||
|
||||
def _recursive_leaf_diffs(a, b, path, out, float_atol, float_rtol):
|
||||
# treat None==None
|
||||
if a is None and b is None:
|
||||
return
|
||||
|
||||
# normalize early
|
||||
a = _to_python(a)
|
||||
b = _to_python(b)
|
||||
|
||||
# tuples -> lists
|
||||
if isinstance(a, tuple):
|
||||
a = list(a)
|
||||
if isinstance(b, tuple):
|
||||
b = list(b)
|
||||
|
||||
# numbers
|
||||
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
||||
if _is_nan(a) and _is_nan(b):
|
||||
return
|
||||
# |a-b| <= atol + rtol*|b|
|
||||
if abs(float(a) - float(b)) > (float_atol + float_rtol * abs(float(b))):
|
||||
out.append(
|
||||
{
|
||||
"path": path or "$",
|
||||
"left": a,
|
||||
"right": b,
|
||||
"abs_delta": abs(float(a) - float(b)),
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# exact types for strings/bools
|
||||
if type(a) is type(b) and isinstance(a, (str, bool)):
|
||||
if not _safe_equal(a, b):
|
||||
out.append({"path": path or "$", "left": a, "right": b, "abs_delta": None})
|
||||
return
|
||||
|
||||
# lists
|
||||
if isinstance(a, list) and isinstance(b, list):
|
||||
if len(a) != len(b):
|
||||
out.append(
|
||||
{
|
||||
"path": f"{path or '$'}.length",
|
||||
"left": len(a),
|
||||
"right": len(b),
|
||||
"abs_delta": None,
|
||||
}
|
||||
)
|
||||
n = min(len(a), len(b))
|
||||
for i in range(n):
|
||||
_recursive_leaf_diffs(
|
||||
a[i], b[i], f"{path or '$'}[{i}]", out, float_atol, float_rtol
|
||||
)
|
||||
for i in range(n, len(a)):
|
||||
out.append(
|
||||
{
|
||||
"path": f"{path or '$'}[{i}]",
|
||||
"left": a[i],
|
||||
"right": None,
|
||||
"abs_delta": None,
|
||||
}
|
||||
)
|
||||
for i in range(n, len(b)):
|
||||
out.append(
|
||||
{
|
||||
"path": f"{path or '$'}[{i}]",
|
||||
"left": None,
|
||||
"right": b[i],
|
||||
"abs_delta": None,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# dicts
|
||||
if isinstance(a, dict) and isinstance(b, dict):
|
||||
keys = sorted(set(a.keys()) | set(b.keys()))
|
||||
for k in keys:
|
||||
ak = a.get(k, None)
|
||||
bk = b.get(k, None)
|
||||
if k not in a:
|
||||
out.append(
|
||||
{
|
||||
"path": f"{path or '$'}.{k}",
|
||||
"left": None,
|
||||
"right": bk,
|
||||
"abs_delta": None,
|
||||
}
|
||||
)
|
||||
elif k not in b:
|
||||
out.append(
|
||||
{
|
||||
"path": f"{path or '$'}.{k}",
|
||||
"left": ak,
|
||||
"right": None,
|
||||
"abs_delta": None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
_recursive_leaf_diffs(
|
||||
ak, bk, f"{path or '$'}.{k}", out, float_atol, float_rtol
|
||||
)
|
||||
return
|
||||
|
||||
# fallback (type mismatch / opaque objects)
|
||||
if not _safe_equal(a, b):
|
||||
out.append({"path": path or "$", "left": a, "right": b, "abs_delta": None})
|
||||
|
||||
|
||||
def _boolean_mask_simple_equals(s1: pl.Series, s2: pl.Series) -> pl.Series:
|
||||
both_null = s1.is_null() & s2.is_null()
|
||||
return ((s1 == s2) | both_null).fill_null(True)
|
||||
|
||||
|
||||
def _boolean_mask_float_close(
|
||||
s1: pl.Series, s2: pl.Series, atol: float, rtol: float
|
||||
) -> pl.Series:
|
||||
both_null = s1.is_null() & s2.is_null()
|
||||
both_nan = s1.is_nan() & s2.is_nan()
|
||||
abs_diff = (s1 - s2).abs()
|
||||
near = abs_diff <= (atol + rtol * s2.abs())
|
||||
return (near | both_null | both_nan).fill_null(False)
|
||||
|
||||
|
||||
def _candidate_rows_for_nested(col_left: pl.Series, col_right: pl.Series) -> List[int]:
|
||||
"""
|
||||
Cheap way to find rows that might differ for nested types:
|
||||
compare JSON dumps of values. This is only a prefilter.
|
||||
"""
|
||||
a = col_left.to_list()
|
||||
b = col_right.to_list()
|
||||
cand = []
|
||||
for i, (x, y) in enumerate(zip(a, b)):
|
||||
if _repr_safe(x) != _repr_safe(y):
|
||||
cand.append(i)
|
||||
return cand
|
||||
|
||||
|
||||
def recursive_diff_frames(
|
||||
left: pl.DataFrame,
|
||||
right: pl.DataFrame,
|
||||
ignore: Optional[List[str]] = None,
|
||||
float_atol: float = 0.0,
|
||||
float_rtol: float = 0.0,
|
||||
max_rows_per_column: int = 20,
|
||||
max_leafs_per_row: int = 200,
|
||||
) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||
"""
|
||||
Deep diff DataFrames, recursing into List/Struct/dict-like values.
|
||||
Returns (diff_summary, diff_leaves).
|
||||
- diff_summary: [column, n_rows_with_diffs]
|
||||
- diff_leaves: [column, row, path, left, right, abs_delta]
|
||||
left/right are Python values (JSON-serializable where possible).
|
||||
"""
|
||||
ignore = set(ignore or [])
|
||||
|
||||
# basic guards
|
||||
if left.height != right.height:
|
||||
raise ValueError(f"Row count differs: {left.height} vs {right.height}")
|
||||
|
||||
lcols = set(left.columns) - ignore
|
||||
rcols = set(right.columns) - ignore
|
||||
if lcols != rcols:
|
||||
raise ValueError(
|
||||
f"Column sets differ after ignoring.\nleft_only={sorted(lcols - rcols)}\nright_only={sorted(rcols - lcols)}"
|
||||
)
|
||||
|
||||
cols = sorted(lcols)
|
||||
|
||||
summary_rows: List[Tuple[str, int]] = []
|
||||
leaves_rows: List[Dict[str, Any]] = []
|
||||
|
||||
for c in cols:
|
||||
s1, s2 = left[c], right[c]
|
||||
|
||||
# Fast path for simple, non-nested types with vectorized comparison
|
||||
simple_dtype = (
|
||||
s1.dtype in SIMPLE_CASTABLE_DTYPES and s2.dtype in SIMPLE_CASTABLE_DTYPES
|
||||
)
|
||||
is_floaty = s1.dtype in FLOAT_DTYPES and s2.dtype in FLOAT_DTYPES
|
||||
|
||||
if simple_dtype and not is_floaty:
|
||||
equal_mask = _boolean_mask_simple_equals(s1, s2)
|
||||
diff_idx = [i for i, ok in enumerate(equal_mask) if not ok]
|
||||
elif simple_dtype and is_floaty:
|
||||
close_mask = _boolean_mask_float_close(s1, s2, float_atol, float_rtol)
|
||||
diff_idx = [i for i, ok in enumerate(close_mask) if not ok]
|
||||
else:
|
||||
# nested or exotic dtype → candidate rows via JSON compare
|
||||
diff_idx = _candidate_rows_for_nested(s1, s2)
|
||||
|
||||
if not diff_idx:
|
||||
continue
|
||||
|
||||
summary_rows.append((c, len(diff_idx)))
|
||||
|
||||
# limit how many rows per column we fully expand
|
||||
for row in diff_idx[:max_rows_per_column]:
|
||||
a = s1[row]
|
||||
b = s2[row]
|
||||
leaf_diffs: List[Dict[str, Any]] = []
|
||||
_recursive_leaf_diffs(
|
||||
a,
|
||||
b,
|
||||
path="",
|
||||
out=leaf_diffs,
|
||||
float_atol=float_atol,
|
||||
float_rtol=float_rtol,
|
||||
)
|
||||
|
||||
# If all leaf_diffs are only float-close (within tol), suppress (can happen for nested)
|
||||
# The recursive function already filters by tolerance for numbers, so we keep what's left.
|
||||
|
||||
# cap the number of leaf diffs to avoid explosion
|
||||
for d in leaf_diffs[:max_leafs_per_row]:
|
||||
left_norm = _repr_safe(_to_python(d["left"])) # -> str
|
||||
right_norm = _repr_safe(_to_python(d["right"])) # -> str
|
||||
|
||||
abs_delta_val = d.get("abs_delta", None)
|
||||
try:
|
||||
abs_delta_norm = (
|
||||
float(abs_delta_val) if abs_delta_val is not None else None
|
||||
)
|
||||
except Exception:
|
||||
abs_delta_norm = None # just in case something weird sneaks in
|
||||
|
||||
leaves_rows.append(
|
||||
{
|
||||
"column": str(c),
|
||||
"row": int(row),
|
||||
"path": str(d["path"] or "$"),
|
||||
"left": left_norm, # str
|
||||
"right": right_norm, # str
|
||||
"abs_delta": abs_delta_norm, # float or None
|
||||
}
|
||||
)
|
||||
|
||||
diff_summary = (
|
||||
pl.DataFrame(summary_rows, schema=["column", "n_rows_with_diffs"]).sort(
|
||||
"n_rows_with_diffs", descending=True
|
||||
)
|
||||
if summary_rows
|
||||
else pl.DataFrame(
|
||||
{
|
||||
"column": pl.Series([], pl.Utf8),
|
||||
"n_rows_with_diffs": pl.Series([], pl.Int64),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Build diff_leaves with stable schema; stringify complex left/right to avoid concat issues
|
||||
if leaves_rows:
|
||||
diff_leaves = pl.DataFrame(
|
||||
{
|
||||
"column": [r["column"] for r in leaves_rows],
|
||||
"row": pl.Series([r["row"] for r in leaves_rows], dtype=pl.Int64),
|
||||
"path": [r["path"] for r in leaves_rows],
|
||||
"left": [r["left"] for r in leaves_rows], # Utf8
|
||||
"right": [r["right"] for r in leaves_rows], # Utf8
|
||||
"abs_delta": pl.Series(
|
||||
[r["abs_delta"] for r in leaves_rows], dtype=pl.Float64
|
||||
),
|
||||
},
|
||||
schema={
|
||||
"column": pl.Utf8,
|
||||
"row": pl.Int64,
|
||||
"path": pl.Utf8,
|
||||
"left": pl.Utf8,
|
||||
"right": pl.Utf8,
|
||||
"abs_delta": pl.Float64,
|
||||
},
|
||||
)
|
||||
else:
|
||||
diff_leaves = pl.DataFrame(
|
||||
{
|
||||
"column": [],
|
||||
"row": [],
|
||||
"path": [],
|
||||
"left": [],
|
||||
"right": [],
|
||||
"abs_delta": [],
|
||||
}
|
||||
)
|
||||
|
||||
return diff_summary, diff_leaves
|
||||
|
||||
# FLOAT_DTYPES = {pl.Float32, pl.Float64}
|
||||
|
||||
# def diff_frames(
|
||||
# left: pl.DataFrame,
|
||||
# right: pl.DataFrame,
|
||||
# ignore: Optional[List[str]] = None,
|
||||
# float_atol: float = 0.0,
|
||||
# float_rtol: float = 0.0,
|
||||
# sample: int = 20,
|
||||
# ) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||
# ignore = set(ignore or [])
|
||||
|
||||
# if left.height != right.height:
|
||||
# raise ValueError(f"Row count differs: {left.height} vs {right.height}")
|
||||
|
||||
# lcols = set(left.columns) - ignore
|
||||
# rcols = set(right.columns) - ignore
|
||||
# if lcols != rcols:
|
||||
# raise ValueError(
|
||||
# f"Column sets differ after ignoring.\nleft_only={sorted(lcols - rcols)}\nright_only={sorted(rcols - lcols)}"
|
||||
# )
|
||||
|
||||
# cols = sorted(lcols)
|
||||
# row_idx = pl.Series("row", range(left.height), dtype=pl.Int64)
|
||||
|
||||
# def _float_diff_mask(s1: pl.Series, s2: pl.Series) -> pl.Series:
|
||||
# both_null = s1.is_null() & s2.is_null()
|
||||
# both_nan = s1.is_nan() & s2.is_nan()
|
||||
# abs_diff = (s1 - s2).abs()
|
||||
# near = abs_diff <= (float_atol + float_rtol * s2.abs())
|
||||
# return ~(near | both_null | both_nan)
|
||||
|
||||
# def _nonfloat_diff_mask(s1: pl.Series, s2: pl.Series) -> pl.Series:
|
||||
# both_null = s1.is_null() & s2.is_null()
|
||||
# return ~((s1 == s2) | both_null).fill_null(True)
|
||||
|
||||
# examples_frames = []
|
||||
# summary_rows = []
|
||||
|
||||
# for c in cols:
|
||||
# s1, s2 = left[c], right[c]
|
||||
# if s1.dtype in FLOAT_DTYPES and s2.dtype in FLOAT_DTYPES:
|
||||
# diff_mask = _float_diff_mask(s1, s2)
|
||||
# abs_delta = (s1 - s2).abs()
|
||||
# else:
|
||||
# diff_mask = _nonfloat_diff_mask(s1, s2)
|
||||
# abs_delta = None
|
||||
|
||||
# diff_mask = diff_mask.cast(pl.Boolean)
|
||||
# n_diff = int(diff_mask.sum())
|
||||
# if n_diff == 0:
|
||||
# continue
|
||||
|
||||
# summary_rows.append((c, n_diff))
|
||||
# k = min(sample, n_diff)
|
||||
|
||||
# idx = row_idx.filter(diff_mask)[:k]
|
||||
|
||||
# def to_utf8_safe(s: pl.Series) -> pl.Series:
|
||||
# # Fast path for simple scalars
|
||||
# if s.dtype in (
|
||||
# pl.Int8,
|
||||
# pl.Int16,
|
||||
# pl.Int32,
|
||||
# pl.Int64,
|
||||
# pl.UInt8,
|
||||
# pl.UInt16,
|
||||
# pl.UInt32,
|
||||
# pl.UInt64,
|
||||
# pl.Float32,
|
||||
# pl.Float64,
|
||||
# pl.Utf8,
|
||||
# pl.Boolean,
|
||||
# pl.Date,
|
||||
# pl.Datetime,
|
||||
# pl.Time,
|
||||
# pl.Duration,
|
||||
# ):
|
||||
# return s.cast(pl.Utf8)
|
||||
# # Fallback for nested/complex types: List, Struct, etc.
|
||||
# return s.map_elements(
|
||||
# lambda v: json.dumps(v, default=str, allow_nan=True),
|
||||
# return_dtype=pl.Utf8,
|
||||
# )
|
||||
|
||||
# ex_left = to_utf8_safe(s1.filter(diff_mask)[:k])
|
||||
# ex_right = to_utf8_safe(s2.filter(diff_mask)[:k])
|
||||
|
||||
# ex = pl.DataFrame(
|
||||
# {
|
||||
# "column": [c] * k,
|
||||
# "row": idx,
|
||||
# "left": ex_left,
|
||||
# "right": ex_right,
|
||||
# "dtype_left": [str(s1.dtype)] * k,
|
||||
# "dtype_right": [str(s2.dtype)] * k,
|
||||
# }
|
||||
# )
|
||||
|
||||
# # unify schema: always have abs_delta as Float64 (None for non-floats)
|
||||
# if abs_delta is not None:
|
||||
# ex = ex.with_columns(
|
||||
# abs_delta.filter(diff_mask)[:k].cast(pl.Float64).alias("abs_delta")
|
||||
# )
|
||||
# else:
|
||||
# ex = ex.with_columns(pl.lit(None, dtype=pl.Float64).alias("abs_delta"))
|
||||
|
||||
# examples_frames.append(ex)
|
||||
|
||||
# diff_summary = (
|
||||
# pl.DataFrame(summary_rows, schema=["column", "n_different"]).sort(
|
||||
# "n_different", descending=True
|
||||
# )
|
||||
# if summary_rows
|
||||
# else pl.DataFrame(
|
||||
# {
|
||||
# "column": pl.Series([], pl.Utf8),
|
||||
# "n_different": pl.Series([], pl.Int64),
|
||||
# }
|
||||
# )
|
||||
# )
|
||||
# diff_examples = (
|
||||
# pl.concat(examples_frames) if examples_frames else pl.DataFrame()
|
||||
# )
|
||||
|
||||
# return diff_summary, diff_examples
|
||||
|
||||
# # --- usage ---
|
||||
# # diff_summary: one row per column with a count of differing rows
|
||||
# # diff_examples: sample rows showing left/right values (and abs_delta for floats)
|
||||
# summary, examples = diff_frames(
|
||||
# df1, df2, ignore=["timestamp"], float_atol=0.1, float_rtol=0.0, sample=25
|
||||
# )
|
||||
|
||||
# print(summary) # which columns differ and how much
|
||||
# print(examples) # sample mismatches with row indices
|
||||
Reference in New Issue
Block a user