# mt/tools/diff_df.py
# (Header pasted from a web file viewer: 545 lines, 17 KiB, Python; timestamp shown: 2025-09-09 14:15:16 +02:00.)
import json
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple
import polars as pl
# Tuple of Python numeric types (not referenced elsewhere in the visible code).
Number = (int, float)

# Polars floating-point dtypes; columns where both sides have one of these
# dtypes are compared with the atol/rtol tolerance mask.
FLOAT_DTYPES = {pl.Float32, pl.Float64}

# Non-nested Polars dtypes that take the vectorized comparison path in
# `recursive_diff_frames` (everything else goes through the per-row recursive
# comparison). Grouped as: signed ints, unsigned ints, floats, text/bool, temporal.
SIMPLE_CASTABLE_DTYPES = (
    (pl.Int8, pl.Int16, pl.Int32, pl.Int64)
    + (pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64)
    + (pl.Float32, pl.Float64)
    + (pl.Utf8, pl.Boolean)
    + (pl.Date, pl.Datetime, pl.Time, pl.Duration)
)
def _is_nan(x):
    """
    Return True only when *x* is a float (or float subclass) holding NaN.

    Non-float inputs give False; any exception raised during the check is
    swallowed and also gives False.
    """
    try:
        if isinstance(x, float):
            return math.isnan(x)
    except Exception:
        pass
    return False
def _repr_safe(v):
    """
    Render *v* as a JSON string, stringifying non-JSON objects with ``str``.

    Non-ASCII characters are kept as-is. If JSON encoding raises (e.g. a
    circular reference), ``repr(v)`` is returned instead.
    """
    try:
        text = json.dumps(v, default=str, ensure_ascii=False)
    except Exception:
        text = repr(v)
    return text
def _to_python(v):
    """
    Convert a leaf-ish value to plain Python types, recursively.

    - Objects with a callable ``.to_list()`` (e.g. a polars Series) or
      ``.tolist()`` (e.g. numpy arrays/scalars) are replaced by the result of
      that call. A Series therefore always becomes a list, even with a single
      element, so ``[5]`` vs ``[5, 6]`` is still diffed as two lists.
    - tuple/list -> list, and dict -> dict, with every element converted.
    - Anything else is returned unchanged.

    NOTE: this module defines ``_to_python`` more than once; the definition
    executed last is the one bound at import time.
    """
    for method_name in ("to_list", "tolist"):
        method = getattr(v, method_name, None)
        if callable(method):
            try:
                v = method()
            except Exception:
                # Best effort: keep the original value and try the next method.
                continue
            break
    if isinstance(v, (list, tuple)):
        return [_to_python(x) for x in v]
    if isinstance(v, dict):
        return {k: _to_python(val) for k, val in v.items()}
    return v
def _safe_equal(a, b):
    """
    Compare *a* and *b* and always return a plain ``bool``.

    The check runs in this order:
    1. Identical objects are equal.
    2. Both sides are normalized with ``_to_python``; two NaNs are equal.
    3. If ``==`` yields a real ``bool``, that result is used.
    4. Otherwise (``==`` raised, or returned a vector-like object), the
       ``_repr_safe`` renderings of both sides are compared.
    """
    if a is b:
        return True
    left, right = _to_python(a), _to_python(b)
    if _is_nan(left) and _is_nan(right):
        return True
    try:
        verdict = left == right
    except Exception:
        verdict = None
    if isinstance(verdict, bool):
        return verdict
    return _repr_safe(left) == _repr_safe(right)
def _num_close(a: float, b: float, atol: float, rtol: float) -> bool:
    """
    Return True when *a* and *b* are equal within ``atol + rtol * |b|``.

    Identical values are always close; this includes matching infinities,
    whose difference would otherwise be NaN. Two NaNs are treated as equal.
    A NaN against any other value is not close.
    """
    if a == b:
        return True
    if _is_nan(a) and _is_nan(b):
        return True
    return abs(a - b) <= (atol + rtol * abs(b))
def _to_python(v: Any) -> Any:
    """
    Convert Polars/numpy values to plain Python objects, recursively.

    - Objects with a callable ``.to_list()`` (e.g. a polars Series, which is
      what indexing a List column yields) or ``.tolist()`` (e.g. numpy
      arrays/scalars) are replaced by the result of that call. A Series
      always becomes a list, even with a single element.
    - tuple/list -> list, and dict (e.g. a Struct value) -> dict, with every
      element converted.
    - Anything else (Python scalars, dates, strings, None) is returned
      unchanged.

    NOTE: this definition rebinds the ``_to_python`` name defined earlier in
    this module, so it is the implementation that callers actually use.
    """
    for method_name in ("to_list", "tolist"):
        method = getattr(v, method_name, None)
        if callable(method):
            try:
                v = method()
            except Exception:
                # Best effort: keep the original value and try the next method.
                continue
            break
    if isinstance(v, (list, tuple)):
        return [_to_python(x) for x in v]
    if isinstance(v, dict):
        return {k: _to_python(val) for k, val in v.items()}
    return v
def _repr_safe(v: Any) -> str:
    """
    Return a JSON rendering of *v*, keeping non-ASCII characters as-is.

    Objects that JSON cannot encode natively are rendered via ``str``. If
    encoding still fails (e.g. a circular reference), fall back to
    ``repr(v)``.
    """
    try:
        rendered = json.dumps(v, default=str, ensure_ascii=False)
    except Exception:
        rendered = repr(v)
    return rendered
def _iter_dict_keys(d: Dict[str, Any]) -> Iterable[str]:
    """Return the keys of *d* as a new sorted list, so iteration order is deterministic."""
    ordered = list(d.keys())
    ordered.sort()
    return ordered
def _recursive_leaf_diffs(a, b, path, out, float_atol, float_rtol):
    """
    Walk two (possibly nested) values in lockstep, appending one record to
    *out* per differing leaf.

    Each record is ``{"path", "left", "right", "abs_delta"}``. ``path`` is
    rooted at ``"$"`` and extended with ``.key``, ``[i]`` or ``.length``.

    Comparison rules:
    - ``None`` vs ``None`` is equal.
    - int/float pairs (``bool`` excluded) are equal when exactly equal, both
      NaN, or ``|a - b| <= float_atol + float_rtol * |b|``. A NaN against a
      non-NaN is always a difference. ``abs_delta`` is ``|a - b|`` as a float,
      or ``None`` if that overflows.
    - str/str and bool/bool pairs are compared exactly.
    - Lists (tuples are treated as lists) are compared element-wise. A length
      mismatch adds a ``.length`` record plus one record per unmatched element.
    - Dicts are compared per key over the union of keys (sorted; ``repr`` order
      if the keys are not mutually orderable). A missing key is reported with
      ``None`` on the missing side.
    - Anything else (type mismatch, opaque objects) falls back to
      ``_safe_equal``.
    """
    if a is None and b is None:
        return
    a = _to_python(a)
    b = _to_python(b)
    # Tuples are compared like lists.
    if isinstance(a, tuple):
        a = list(a)
    if isinstance(b, tuple):
        b = list(b)
    here = path or "$"

    def record(where, left, right, abs_delta=None):
        out.append({"path": where, "left": left, "right": right, "abs_delta": abs_delta})

    # Numbers. bool is an int subclass, but bools are compared exactly below,
    # so a numeric tolerance can never hide a True/False flip.
    a_is_num = isinstance(a, (int, float)) and not isinstance(a, bool)
    b_is_num = isinstance(b, (int, float)) and not isinstance(b, bool)
    if a_is_num and b_is_num:
        if a == b:
            # Exact match. Also covers inf == inf (inf - inf is NaN) and compares
            # ints without rounding them through float().
            return
        a_nan, b_nan = _is_nan(a), _is_nan(b)
        if a_nan and b_nan:
            return
        try:
            raw_delta = abs(a - b)
            # A NaN on only one side is always a difference.
            within = not (a_nan or b_nan) and raw_delta <= (
                float_atol + float_rtol * abs(b)
            )
            abs_delta = float(raw_delta)
        except OverflowError:
            # An int too large to mix with floats: report it, without a delta.
            within, abs_delta = False, None
        if not within:
            record(here, a, b, abs_delta)
        return

    # Exact comparison for same-typed strings/bools.
    if type(a) is type(b) and isinstance(a, (str, bool)):
        if not _safe_equal(a, b):
            record(here, a, b)
        return

    # Lists.
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            record(f"{here}.length", len(a), len(b))
        for i, (x, y) in enumerate(zip(a, b)):
            _recursive_leaf_diffs(x, y, f"{here}[{i}]", out, float_atol, float_rtol)
        shared = min(len(a), len(b))
        for i in range(shared, len(a)):
            record(f"{here}[{i}]", a[i], None)
        for i in range(shared, len(b)):
            record(f"{here}[{i}]", None, b[i])
        return

    # Dicts.
    if isinstance(a, dict) and isinstance(b, dict):
        keys = set(a) | set(b)
        try:
            ordered = sorted(keys)
        except TypeError:
            # Mixed, mutually unorderable key types: still produce a stable order.
            ordered = sorted(keys, key=repr)
        for k in ordered:
            sub = f"{here}.{k}"
            if k not in a:
                record(sub, None, b[k])
            elif k not in b:
                record(sub, a[k], None)
            else:
                _recursive_leaf_diffs(a[k], b[k], sub, out, float_atol, float_rtol)
        return

    # Fallback (type mismatch / opaque objects).
    if not _safe_equal(a, b):
        record(here, a, b)
def _boolean_mask_simple_equals(s1: pl.Series, s2: pl.Series) -> pl.Series:
    """
    Build an element-wise equality mask for two same-length, non-nested Series.

    The mask is True where the values are equal or both null. It is False
    where they differ, and also where exactly one side is null: a null
    against a value is a real difference and must not be masked as equal.
    """
    both_null = s1.is_null() & s2.is_null()
    # `s1 == s2` is null wherever either side is null. Filling those with False
    # and OR-ing in `both_null` gives the correct answer without relying on how
    # nulls propagate through `|`.
    same = (s1 == s2).fill_null(False)
    return same | both_null
def _boolean_mask_float_close(
    s1: pl.Series, s2: pl.Series, atol: float, rtol: float
) -> pl.Series:
    """
    Build an element-wise "close enough" mask for two same-length float Series.

    The mask is True where the values are:
    - exactly equal (this includes matching infinities, whose difference is NaN),
    - within ``atol + rtol * |s2|``,
    - both NaN, or
    - both null.

    It is False otherwise, including where exactly one side is null or exactly
    one side is NaN.
    """
    both_null = s1.is_null() & s2.is_null()
    # Each term is null-filled on its own, so the result contains no nulls and
    # does not depend on how nulls propagate through `|`.
    exact = (s1 == s2).fill_null(False)
    near = ((s1 - s2).abs() <= (atol + rtol * s2.abs())).fill_null(False)
    both_nan = (s1.is_nan() & s2.is_nan()).fill_null(False)
    return exact | near | both_nan | both_null
def _candidate_rows_for_nested(col_left: pl.Series, col_right: pl.Series) -> List[int]:
    """
    Prefilter for nested/exotic columns: return the indices of rows whose
    values render to different ``_repr_safe`` (JSON) strings.

    This is only a prefilter. Some renderings differ without a meaningful
    difference (e.g. ``1`` vs ``1.0``, or floats within tolerance).
    """
    left_values = col_left.to_list()
    right_values = col_right.to_list()
    return [
        idx
        for idx, (lv, rv) in enumerate(zip(left_values, right_values))
        if _repr_safe(lv) != _repr_safe(rv)
    ]
# Output schemas of `recursive_diff_frames`. They are used for both the empty
# and the non-empty case, so output dtypes never depend on whether diffs exist.
_DIFF_SUMMARY_SCHEMA = {"column": pl.Utf8, "n_rows_with_diffs": pl.Int64}
_DIFF_LEAVES_SCHEMA = {
    "column": pl.Utf8,
    "row": pl.Int64,
    "path": pl.Utf8,
    "left": pl.Utf8,
    "right": pl.Utf8,
    "abs_delta": pl.Float64,
}


def _collect_leaf_diffs(
    a: Any, b: Any, float_atol: float, float_rtol: float
) -> List[Dict[str, Any]]:
    """Return the leaf-level diff records for one pair of cell values (empty list if equal)."""
    out: List[Dict[str, Any]] = []
    _recursive_leaf_diffs(
        a, b, path="", out=out, float_atol=float_atol, float_rtol=float_rtol
    )
    return out


def _leaf_record(column: str, row: int, diff: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one leaf diff into a `diff_leaves` row (left/right as strings, abs_delta as float or None)."""
    abs_delta = diff.get("abs_delta")
    try:
        abs_delta = float(abs_delta) if abs_delta is not None else None
    except (TypeError, ValueError, OverflowError):
        abs_delta = None  # non-numeric or too large for a float
    return {
        "column": str(column),
        "row": int(row),
        "path": str(diff["path"] or "$"),
        "left": _repr_safe(_to_python(diff["left"])),
        "right": _repr_safe(_to_python(diff["right"])),
        "abs_delta": abs_delta,
    }


def recursive_diff_frames(
    left: pl.DataFrame,
    right: pl.DataFrame,
    ignore: Optional[List[str]] = None,
    float_atol: float = 0.0,
    float_rtol: float = 0.0,
    max_rows_per_column: int = 20,
    max_leafs_per_row: int = 200,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """
    Deep-diff two DataFrames row by row (matched by position), recursing into
    List/Struct values.

    Args:
        left, right: Frames to compare. They must have the same height and,
            after dropping ``ignore``, the same column names.
        ignore: Column names to exclude from the comparison.
        float_atol, float_rtol: Numbers count as equal when
            ``|l - r| <= float_atol + float_rtol * |r|``; NaN matches NaN.
        max_rows_per_column: At most this many differing rows per column are
            expanded into leaf diffs. All differing rows are still counted in
            the summary.
        max_leafs_per_row: At most this many leaf diffs are kept per row.

    Returns:
        ``(diff_summary, diff_leaves)``, where:

        - ``diff_summary`` has columns ``column`` (Utf8) and
          ``n_rows_with_diffs`` (Int64). Only columns with diffs appear,
          sorted by count descending, then by column name. For nested columns
          a row is counted only if at least one leaf actually differs.
        - ``diff_leaves`` has columns ``column``, ``row``, ``path``, ``left``
          and ``right`` (JSON-ish strings), and ``abs_delta`` (Float64, null
          for non-numeric leaves). The schema is the same when it is empty.

    Raises:
        ValueError: If the row counts or the compared column sets differ.
    """
    ignore = set(ignore or [])
    if left.height != right.height:
        raise ValueError(f"Row count differs: {left.height} vs {right.height}")
    lcols = set(left.columns) - ignore
    rcols = set(right.columns) - ignore
    if lcols != rcols:
        raise ValueError(
            f"Column sets differ after ignoring.\nleft_only={sorted(lcols - rcols)}\nright_only={sorted(rcols - lcols)}"
        )

    summary_rows: List[Tuple[str, int]] = []
    leaves_rows: List[Dict[str, Any]] = []
    for c in sorted(lcols):
        s1, s2 = left[c], right[c]
        left_vals: Optional[List[Any]] = None
        right_vals: Optional[List[Any]] = None
        simple_dtype = (
            s1.dtype in SIMPLE_CASTABLE_DTYPES and s2.dtype in SIMPLE_CASTABLE_DTYPES
        )
        if simple_dtype:
            # Fast path: vectorized comparison for non-nested dtypes.
            if s1.dtype in FLOAT_DTYPES and s2.dtype in FLOAT_DTYPES:
                ok_mask = _boolean_mask_float_close(s1, s2, float_atol, float_rtol)
            else:
                ok_mask = _boolean_mask_simple_equals(s1, s2)
            # A null in the mask counts as "different".
            diff_idx = (~ok_mask).fill_null(True).arg_true().to_list()
        else:
            # Nested or exotic dtype: use the cheap JSON prefilter, then confirm
            # each candidate with the tolerant recursive comparison. This keeps
            # rows whose renderings differ without a real diff (floats within
            # tolerance, 1 vs 1.0) out of the count.
            left_vals, right_vals = s1.to_list(), s2.to_list()
            diff_idx = [
                i
                for i in _candidate_rows_for_nested(s1, s2)
                if _collect_leaf_diffs(
                    left_vals[i], right_vals[i], float_atol, float_rtol
                )
            ]
        if not diff_idx:
            continue
        summary_rows.append((c, len(diff_idx)))

        # Limit how many rows per column are fully expanded.
        for row in diff_idx[:max_rows_per_column]:
            if left_vals is None:
                a, b = s1[row], s2[row]
            else:
                # Reuse the to_list() values, which are plain lists/dicts. Indexing
                # a List column with s[row] can return a pl.Series instead.
                a, b = left_vals[row], right_vals[row]
            leaf_diffs = _collect_leaf_diffs(a, b, float_atol, float_rtol)
            # Cap the number of leaf diffs to avoid explosion.
            leaves_rows.extend(
                _leaf_record(c, row, d) for d in leaf_diffs[:max_leafs_per_row]
            )

    diff_summary = pl.DataFrame(
        {
            "column": [name for name, _ in summary_rows],
            "n_rows_with_diffs": [count for _, count in summary_rows],
        },
        schema=_DIFF_SUMMARY_SCHEMA,
    ).sort(["n_rows_with_diffs", "column"], descending=[True, False])
    diff_leaves = pl.DataFrame(
        {name: [r[name] for r in leaves_rows] for name in _DIFF_LEAVES_SCHEMA},
        schema=_DIFF_LEAVES_SCHEMA,
    )
    return diff_summary, diff_leaves
# FLOAT_DTYPES = {pl.Float32, pl.Float64}
# def diff_frames(
# left: pl.DataFrame,
# right: pl.DataFrame,
# ignore: Optional[List[str]] = None,
# float_atol: float = 0.0,
# float_rtol: float = 0.0,
# sample: int = 20,
# ) -> Tuple[pl.DataFrame, pl.DataFrame]:
# ignore = set(ignore or [])
# if left.height != right.height:
# raise ValueError(f"Row count differs: {left.height} vs {right.height}")
# lcols = set(left.columns) - ignore
# rcols = set(right.columns) - ignore
# if lcols != rcols:
# raise ValueError(
# f"Column sets differ after ignoring.\nleft_only={sorted(lcols - rcols)}\nright_only={sorted(rcols - lcols)}"
# )
# cols = sorted(lcols)
# row_idx = pl.Series("row", range(left.height), dtype=pl.Int64)
# def _float_diff_mask(s1: pl.Series, s2: pl.Series) -> pl.Series:
# both_null = s1.is_null() & s2.is_null()
# both_nan = s1.is_nan() & s2.is_nan()
# abs_diff = (s1 - s2).abs()
# near = abs_diff <= (float_atol + float_rtol * s2.abs())
# return ~(near | both_null | both_nan)
# def _nonfloat_diff_mask(s1: pl.Series, s2: pl.Series) -> pl.Series:
# both_null = s1.is_null() & s2.is_null()
# return ~((s1 == s2) | both_null).fill_null(True)
# examples_frames = []
# summary_rows = []
# for c in cols:
# s1, s2 = left[c], right[c]
# if s1.dtype in FLOAT_DTYPES and s2.dtype in FLOAT_DTYPES:
# diff_mask = _float_diff_mask(s1, s2)
# abs_delta = (s1 - s2).abs()
# else:
# diff_mask = _nonfloat_diff_mask(s1, s2)
# abs_delta = None
# diff_mask = diff_mask.cast(pl.Boolean)
# n_diff = int(diff_mask.sum())
# if n_diff == 0:
# continue
# summary_rows.append((c, n_diff))
# k = min(sample, n_diff)
# idx = row_idx.filter(diff_mask)[:k]
# def to_utf8_safe(s: pl.Series) -> pl.Series:
# # Fast path for simple scalars
# if s.dtype in (
# pl.Int8,
# pl.Int16,
# pl.Int32,
# pl.Int64,
# pl.UInt8,
# pl.UInt16,
# pl.UInt32,
# pl.UInt64,
# pl.Float32,
# pl.Float64,
# pl.Utf8,
# pl.Boolean,
# pl.Date,
# pl.Datetime,
# pl.Time,
# pl.Duration,
# ):
# return s.cast(pl.Utf8)
# # Fallback for nested/complex types: List, Struct, etc.
# return s.map_elements(
# lambda v: json.dumps(v, default=str, allow_nan=True),
# return_dtype=pl.Utf8,
# )
# ex_left = to_utf8_safe(s1.filter(diff_mask)[:k])
# ex_right = to_utf8_safe(s2.filter(diff_mask)[:k])
# ex = pl.DataFrame(
# {
# "column": [c] * k,
# "row": idx,
# "left": ex_left,
# "right": ex_right,
# "dtype_left": [str(s1.dtype)] * k,
# "dtype_right": [str(s2.dtype)] * k,
# }
# )
# # unify schema: always have abs_delta as Float64 (None for non-floats)
# if abs_delta is not None:
# ex = ex.with_columns(
# abs_delta.filter(diff_mask)[:k].cast(pl.Float64).alias("abs_delta")
# )
# else:
# ex = ex.with_columns(pl.lit(None, dtype=pl.Float64).alias("abs_delta"))
# examples_frames.append(ex)
# diff_summary = (
# pl.DataFrame(summary_rows, schema=["column", "n_different"]).sort(
# "n_different", descending=True
# )
# if summary_rows
# else pl.DataFrame(
# {
# "column": pl.Series([], pl.Utf8),
# "n_different": pl.Series([], pl.Int64),
# }
# )
# )
# diff_examples = (
# pl.concat(examples_frames) if examples_frames else pl.DataFrame()
# )
# return diff_summary, diff_examples
# # --- usage ---
# # diff_summary: one row per column with a count of differing rows
# # diff_examples: sample rows showing left/right values (and abs_delta for floats)
# summary, examples = diff_frames(
# df1, df2, ignore=["timestamp"], float_atol=0.1, float_rtol=0.0, sample=25
# )
# print(summary) # which columns differ and how much
# print(examples) # sample mismatches with row indices