"""Deep-diff utilities for Polars DataFrames.

``recursive_diff_frames`` compares two same-height DataFrames column by
column.  Simple (non-nested) dtypes are compared with vectorized boolean
masks, floats with an ``atol + rtol * |b|`` tolerance; nested dtypes
(List/Struct) are prefiltered via JSON reprs and then diffed leaf by leaf,
producing JSONPath-like paths in the output.
"""

import json
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import polars as pl

# Python scalar number types (kept public for callers that import it).
Number = (int, float)

FLOAT_DTYPES = {pl.Float32, pl.Float64}

# Non-nested dtypes that support a cheap vectorized equality comparison.
SIMPLE_CASTABLE_DTYPES = (
    pl.Int8,
    pl.Int16,
    pl.Int32,
    pl.Int64,
    pl.UInt8,
    pl.UInt16,
    pl.UInt32,
    pl.UInt64,
    pl.Float32,
    pl.Float64,
    pl.Utf8,
    pl.Boolean,
    pl.Date,
    pl.Datetime,
    pl.Time,
    pl.Duration,
)


def _is_nan(x: Any) -> bool:
    """Return True iff *x* is a float NaN (never raises)."""
    return isinstance(x, float) and math.isnan(x)


def _repr_safe(v: Any) -> str:
    """Stable JSON-ish string form of *v*; falls back to ``repr``."""
    try:
        return json.dumps(v, default=str, ensure_ascii=False)
    except Exception:
        return repr(v)


def _to_python(v: Any) -> Any:
    """
    Convert any leaf-ish object to plain Python types:

    - ``pl.Series`` -> list (or scalar if length == 1)
    - objects with ``.tolist()`` (e.g. numpy scalars/arrays) -> list
    - tuple -> list; list/dict recursed element-wise

    NOTE: a later duplicate ``def _to_python(v): return v`` used to shadow
    this function with a no-op, silently disabling all normalization; the
    duplicate has been removed.
    """
    if isinstance(v, pl.Series):
        seq = v.to_list()
        return seq[0] if len(seq) == 1 else seq
    if hasattr(v, "tolist"):
        try:
            return v.tolist()
        except Exception:
            pass
    if isinstance(v, (list, tuple)):
        return [_to_python(x) for x in v]
    if isinstance(v, dict):
        return {k: _to_python(val) for k, val in v.items()}
    return v


def _safe_equal(a: Any, b: Any) -> bool:
    """
    Plain-bool equality between *a* and *b* that never yields a
    vector/Series result.  NaN == NaN counts as equal; when ``==`` does not
    produce a bool (or raises), stable JSON reprs are compared instead.
    """
    if a is b:
        return True
    a_n = _to_python(a)
    b_n = _to_python(b)
    if _is_nan(a_n) and _is_nan(b_n):
        return True
    try:
        eq = a_n == b_n
        if isinstance(eq, bool):
            return eq
    except Exception:
        pass
    return _repr_safe(a_n) == _repr_safe(b_n)


def _num_close(a: float, b: float, atol: float, rtol: float) -> bool:
    """True when ``|a - b| <= atol + rtol * |b|`` (NaN == NaN counts as close)."""
    if _is_nan(a) and _is_nan(b):
        return True
    return abs(a - b) <= (atol + rtol * abs(b))


def _iter_dict_keys(d: Dict[str, Any]) -> Iterable[str]:
    """Keys of *d* in sorted order — stable, predictable output ordering."""
    return sorted(d.keys())


def _recursive_leaf_diffs(
    a: Any,
    b: Any,
    path: str,
    out: List[Dict[str, Any]],
    float_atol: float,
    float_rtol: float,
) -> None:
    """
    Append one record per differing leaf of *a* vs *b* to *out*.

    Each record is a dict with keys ``path`` (JSONPath-like, rooted at
    ``$``), ``left``, ``right`` and ``abs_delta`` (float difference for
    numeric leaves, None otherwise).  ``None == None`` and NaN == NaN are
    treated as equal; numbers use ``|a-b| <= atol + rtol*|b|``.
    """
    if a is None and b is None:
        return

    # Normalize early (Series/numpy/tuple -> plain Python).
    a = _to_python(a)
    b = _to_python(b)
    if isinstance(a, tuple):
        a = list(a)
    if isinstance(b, tuple):
        b = list(b)

    root = path or "$"

    # Numbers with tolerance.  bool is an int subclass, so exclude it here
    # and let the exact-equality branch below handle it — otherwise e.g.
    # True vs False would be "close" under any tolerance >= 1.
    a_is_num = isinstance(a, (int, float)) and not isinstance(a, bool)
    b_is_num = isinstance(b, (int, float)) and not isinstance(b, bool)
    if a_is_num and b_is_num:
        if _is_nan(a) and _is_nan(b):
            return
        delta = abs(float(a) - float(b))
        if delta > (float_atol + float_rtol * abs(float(b))):
            out.append({"path": root, "left": a, "right": b, "abs_delta": delta})
        return

    # Exact comparison for same-typed strings/bools.
    if type(a) is type(b) and isinstance(a, (str, bool)):
        if not _safe_equal(a, b):
            out.append({"path": root, "left": a, "right": b, "abs_delta": None})
        return

    # Lists: report length mismatch, recurse pairwise, then report extras.
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            out.append(
                {
                    "path": f"{root}.length",
                    "left": len(a),
                    "right": len(b),
                    "abs_delta": None,
                }
            )
        n = min(len(a), len(b))
        for i in range(n):
            _recursive_leaf_diffs(
                a[i], b[i], f"{root}[{i}]", out, float_atol, float_rtol
            )
        for i in range(n, len(a)):
            out.append(
                {"path": f"{root}[{i}]", "left": a[i], "right": None, "abs_delta": None}
            )
        for i in range(n, len(b)):
            out.append(
                {"path": f"{root}[{i}]", "left": None, "right": b[i], "abs_delta": None}
            )
        return

    # Dicts: union of keys in sorted order; missing keys become None sides.
    if isinstance(a, dict) and isinstance(b, dict):
        for k in sorted(set(a.keys()) | set(b.keys())):
            kpath = f"{root}.{k}"
            if k not in a:
                out.append(
                    {"path": kpath, "left": None, "right": b[k], "abs_delta": None}
                )
            elif k not in b:
                out.append(
                    {"path": kpath, "left": a[k], "right": None, "abs_delta": None}
                )
            else:
                _recursive_leaf_diffs(a[k], b[k], kpath, out, float_atol, float_rtol)
        return

    # Fallback: type mismatch or opaque objects.
    if not _safe_equal(a, b):
        out.append({"path": root, "left": a, "right": b, "abs_delta": None})


def _boolean_mask_simple_equals(s1: pl.Series, s2: pl.Series) -> pl.Series:
    """
    Boolean mask: True where values are equal, with null == null counting
    as equal.

    A one-sided null makes ``s1 == s2`` null; that used to be filled with
    True (hiding the difference entirely) — it is now filled with False so
    one-sided nulls are reported as differences.
    """
    both_null = s1.is_null() & s2.is_null()
    # Kleene logic: null | True -> True, so both-null rows survive fill_null(False).
    return ((s1 == s2) | both_null).fill_null(False)


def _boolean_mask_float_close(
    s1: pl.Series, s2: pl.Series, atol: float, rtol: float
) -> pl.Series:
    """Boolean mask: floats within ``atol + rtol*|s2|``; null==null and NaN==NaN pass."""
    both_null = s1.is_null() & s2.is_null()
    both_nan = s1.is_nan() & s2.is_nan()
    near = (s1 - s2).abs() <= (atol + rtol * s2.abs())
    return (near | both_null | both_nan).fill_null(False)


def _candidate_rows_for_nested(col_left: pl.Series, col_right: pl.Series) -> List[int]:
    """
    Cheap prefilter for nested dtypes: row indices whose JSON reprs differ.
    May over-report (e.g. float formatting); the recursive leaf diff makes
    the final call.
    """
    left_vals = col_left.to_list()
    right_vals = col_right.to_list()
    return [
        i
        for i, (x, y) in enumerate(zip(left_vals, right_vals))
        if _repr_safe(x) != _repr_safe(y)
    ]


def recursive_diff_frames(
    left: pl.DataFrame,
    right: pl.DataFrame,
    ignore: Optional[List[str]] = None,
    float_atol: float = 0.0,
    float_rtol: float = 0.0,
    max_rows_per_column: int = 20,
    max_leafs_per_row: int = 200,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """
    Deep diff DataFrames, recursing into List/Struct/dict-like values.

    Parameters
    ----------
    left, right : pl.DataFrame
        Frames to compare; must have equal heights and (after *ignore*)
        equal column sets.
    ignore : list of column names to skip.
    float_atol, float_rtol : float tolerances (``|a-b| <= atol + rtol*|b|``).
    max_rows_per_column : cap on rows fully expanded per differing column.
    max_leafs_per_row : cap on leaf diffs recorded per differing row.

    Returns
    -------
    (diff_summary, diff_leaves)
        - diff_summary: [column, n_rows_with_diffs]
        - diff_leaves:  [column, row, path, left, right, abs_delta] where
          left/right are JSON strings and abs_delta is Float64 (null for
          non-numeric leaves).

    Raises
    ------
    ValueError
        If row counts differ or column sets differ after ignoring.
    """
    ignore_set = set(ignore or [])

    # Basic guards: shape must match before any per-column work.
    if left.height != right.height:
        raise ValueError(f"Row count differs: {left.height} vs {right.height}")
    lcols = set(left.columns) - ignore_set
    rcols = set(right.columns) - ignore_set
    if lcols != rcols:
        raise ValueError(
            f"Column sets differ after ignoring.\nleft_only={sorted(lcols - rcols)}\nright_only={sorted(rcols - lcols)}"
        )

    summary_rows: List[Tuple[str, int]] = []
    leaves_rows: List[Dict[str, Any]] = []

    for c in sorted(lcols):
        s1, s2 = left[c], right[c]

        # Fast path for simple, non-nested types with vectorized comparison.
        simple_dtype = (
            s1.dtype in SIMPLE_CASTABLE_DTYPES and s2.dtype in SIMPLE_CASTABLE_DTYPES
        )
        is_floaty = s1.dtype in FLOAT_DTYPES and s2.dtype in FLOAT_DTYPES

        if simple_dtype and not is_floaty:
            mask = _boolean_mask_simple_equals(s1, s2)
            diff_idx = [i for i, ok in enumerate(mask) if not ok]
        elif simple_dtype and is_floaty:
            mask = _boolean_mask_float_close(s1, s2, float_atol, float_rtol)
            diff_idx = [i for i, ok in enumerate(mask) if not ok]
        else:
            # Nested or exotic dtype -> candidate rows via JSON compare.
            diff_idx = _candidate_rows_for_nested(s1, s2)

        if not diff_idx:
            continue
        summary_rows.append((c, len(diff_idx)))

        # Limit how many rows per column we fully expand.
        for row in diff_idx[:max_rows_per_column]:
            leaf_diffs: List[Dict[str, Any]] = []
            _recursive_leaf_diffs(
                s1[row],
                s2[row],
                path="",
                out=leaf_diffs,
                float_atol=float_atol,
                float_rtol=float_rtol,
            )
            # The recursive function already applies the float tolerance,
            # so whatever is left is a genuine difference.  Cap the number
            # of leaf diffs to avoid explosion.
            for d in leaf_diffs[:max_leafs_per_row]:
                abs_delta_val = d.get("abs_delta")
                try:
                    abs_delta_norm = (
                        float(abs_delta_val) if abs_delta_val is not None else None
                    )
                except Exception:
                    # Just in case something weird sneaks in.
                    abs_delta_norm = None
                leaves_rows.append(
                    {
                        "column": str(c),
                        "row": int(row),
                        "path": str(d["path"] or "$"),
                        # Stringify complex left/right to keep a stable Utf8 schema.
                        "left": _repr_safe(_to_python(d["left"])),
                        "right": _repr_safe(_to_python(d["right"])),
                        "abs_delta": abs_delta_norm,
                    }
                )

    # Build both outputs with explicit schemas so the empty case has the
    # same dtypes as the populated case (previously the empty diff_leaves
    # frame had Null-typed columns).
    diff_summary = pl.DataFrame(
        {
            "column": [c for c, _ in summary_rows],
            "n_rows_with_diffs": [n for _, n in summary_rows],
        },
        schema={"column": pl.Utf8, "n_rows_with_diffs": pl.Int64},
    ).sort("n_rows_with_diffs", descending=True)

    diff_leaves = pl.DataFrame(
        {
            "column": [r["column"] for r in leaves_rows],
            "row": [r["row"] for r in leaves_rows],
            "path": [r["path"] for r in leaves_rows],
            "left": [r["left"] for r in leaves_rows],
            "right": [r["right"] for r in leaves_rows],
            "abs_delta": [r["abs_delta"] for r in leaves_rows],
        },
        schema={
            "column": pl.Utf8,
            "row": pl.Int64,
            "path": pl.Utf8,
            "left": pl.Utf8,
            "right": pl.Utf8,
            "abs_delta": pl.Float64,
        },
    )

    return diff_summary, diff_leaves