Files
mt/tools/compare_projections.py

206 lines
6.9 KiB
Python
Raw Normal View History

2025-08-13 14:17:12 +02:00
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import numpy as np
# Configuration: the directory holding the original projections and the
# "new_projection" sub-directory holding the projections compared against them.
old_projections_path = Path("/home/fedex/mt/data/subter")
new_projections_path = old_projections_path / "new_projection"
def get_file_info(file_path: Path) -> Optional[dict]:
    """Get detailed information about a .npy file without loading its data.

    Only the preamble (magic string + header) is read.

    Args:
        file_path: Path to the ``.npy`` file.

    Returns:
        A dict with ``"size"`` (bytes on disk), ``"format"`` (the
        ``(major, minor)`` version tuple from the magic string) and
        ``"header"`` (the ``(shape, fortran_order, dtype)`` tuple), or
        ``None`` if the file does not exist or cannot be parsed (the error
        is printed).
    """
    if not file_path.exists():
        return None
    try:
        with file_path.open("rb") as f:
            version = np.lib.format.read_magic(f)
            # Format 1.0 stores the header length in 2 bytes; 2.0 and later
            # use 4 bytes, so the 1.0 reader would mis-parse newer files.
            if version == (1, 0):
                header = np.lib.format.read_array_header_1_0(f)
            else:
                header = np.lib.format.read_array_header_2_0(f)
            return {
                "size": file_path.stat().st_size,
                "format": version,
                "header": header,
            }
    except Exception as e:  # best-effort probe: report and let the caller skip
        print(f"Error reading file {file_path}: {e}")
        return None
def get_array_info(arr: np.ndarray) -> dict:
    """Summarise dtype, memory footprint, shape and non-finite counts of *arr*.

    Returns a dict with keys ``dtype`` (string form), ``itemsize``,
    ``nbytes``, ``byteorder``, ``shape``, ``nan_count`` and ``inf_count``.
    """
    dt = arr.dtype
    nan_mask = np.isnan(arr)
    inf_mask = np.isinf(arr)
    return dict(
        dtype=str(dt),
        itemsize=arr.itemsize,
        nbytes=arr.nbytes,
        byteorder=dt.byteorder,
        shape=arr.shape,
        nan_count=np.sum(nan_mask),
        inf_count=np.sum(inf_mask),
    )
def compare_arrays(old_arr: np.ndarray, new_arr: np.ndarray) -> dict:
    """Compare two numpy arrays and return detailed differences.

    Positions where the values compare equal, or where both are NaN, count
    as matching.

    Returns:
        A dict that always contains ``"old_info"``, ``"new_info"`` (from
        ``get_array_info``) and ``"shape_mismatch"``. If the shapes differ it
        is returned with only those keys. Otherwise it also contains:

        - ``"num_mismatches"``: number of non-matching positions.
        - ``"mismatch_indices"``: the ``np.where`` tuple, or ``None`` when
          there are no mismatches.
        - ``"mean_difference"`` / ``"std_difference"``: mean of
          ``|old - new|`` and std of ``old - new``, both computed over the
          positions where the difference is finite (NaN if there are none).
        - ``"new_zeros"``: number of zeros in ``new_arr``.

        When there are mismatches it additionally holds
        ``"example_old_values"`` / ``"example_new_values"`` (the first 10
        mismatched values) and ``"all_new_mismatches_zero"``.
    """
    old_info = get_array_info(old_arr)
    new_info = get_array_info(new_arr)
    differences = {
        "old_info": old_info,
        "new_info": new_info,
        "shape_mismatch": old_arr.shape != new_arr.shape,
    }
    if differences["shape_mismatch"]:
        return differences

    # NaN != NaN, so positions that are NaN in both arrays must be excluded
    # explicitly or they would be reported as mismatches.
    both_nan = np.isnan(old_arr) & np.isnan(new_arr)
    unequal = (old_arr != new_arr) & ~both_nan
    mismatch_indices = np.where(unequal)
    num_mismatches = mismatch_indices[0].size

    # Subtract in at least float64: avoids wrap-around for unsigned integer
    # dtypes and precision loss for low-precision floats.
    work_dtype = np.result_type(old_arr.dtype, new_arr.dtype, np.float64)
    delta = old_arr.astype(work_dtype, copy=False) - new_arr.astype(
        work_dtype, copy=False
    )
    # A single NaN/inf difference would turn the statistics into NaN/inf,
    # so only finite differences are aggregated. Genuine mismatches at those
    # positions are still counted in num_mismatches.
    finite_delta = delta[np.isfinite(delta)]
    if finite_delta.size:
        mean_difference = np.mean(np.abs(finite_delta))
        std_difference = np.std(finite_delta)
    else:
        mean_difference = std_difference = np.nan

    differences.update(
        {
            "num_mismatches": num_mismatches,
            "mismatch_indices": mismatch_indices if num_mismatches > 0 else None,
            "mean_difference": mean_difference,
            "std_difference": std_difference,
            "new_zeros": np.sum(new_arr == 0),
        }
    )

    # If there are mismatches, store some example values
    if num_mismatches > 0:
        old_mismatched = old_arr[mismatch_indices]
        new_mismatched = new_arr[mismatch_indices]
        differences.update(
            {
                "example_old_values": old_mismatched[:10],
                "example_new_values": new_mismatched[:10],
                "all_new_mismatches_zero": np.all(new_mismatched == 0),
            }
        )
    return differences
def print_detailed_comparison(name: str, diff: dict):
    """Print detailed comparison results for a single file.

    Args:
        name: Label printed in the heading.
        diff: Result dict as produced by ``compare_arrays``. Value-comparison
            keys are only read when ``diff["shape_mismatch"]`` is false.
    """
    print(f"\nDetailed comparison for: {name}")
    print("=" * 80)

    # Storage information (old → new)
    old_info = diff["old_info"]
    new_info = diff["new_info"]
    print("Storage Information:")
    print(f" Dtype: {old_info['dtype']} → {new_info['dtype']}")
    print(f" Item size: {old_info['itemsize']} → {new_info['itemsize']} bytes")
    print(f" Total bytes: {old_info['nbytes']} → {new_info['nbytes']}")
    # Empty arrays have nbytes == 0; avoid a ZeroDivisionError.
    if old_info["nbytes"]:
        print(f" Byte ratio: {new_info['nbytes'] / old_info['nbytes']:.2f}")
    else:
        print(" Byte ratio: n/a")

    # Shape information: no value comparison exists for mismatched shapes.
    if diff["shape_mismatch"]:
        print(f"Shape mismatch: {old_info['shape']} → {new_info['shape']}")
        return

    # Value differences
    if diff["num_mismatches"] > 0:
        print("\nValue Differences:")
        print(f" Number of mismatches: {diff['num_mismatches']}")
        print(f" Mean difference: {diff['mean_difference']:.2e}")
        print(f" Std difference: {diff['std_difference']:.2e}")
        print(" Example mismatches (old → new):")
        for old_val, new_val in zip(
            diff["example_old_values"], diff["example_new_values"]
        ):
            print(f" {old_val:.6e} → {new_val:.6e}")
        if diff["all_new_mismatches_zero"]:
            print(" Note: All mismatched values in new array are zero")
    else:
        print("\nNo value mismatches found.")
def summarize_differences(all_differences: Dict[str, dict]) -> str:
    """Create a summary of all differences.

    Args:
        all_differences: Mapping of file name to the dict returned by
            ``compare_arrays``. Entries with a shape mismatch carry no
            value-comparison keys; they are counted only as shape mismatches.

    Returns:
        The summary text, one item per line.
    """
    # Shape-mismatched results have no "num_mismatches" key; treat them as
    # having no value mismatches instead of raising KeyError.
    value_mismatched = {
        name: diff
        for name, diff in all_differences.items()
        if diff.get("num_mismatches", 0) > 0
    }
    files_with_shape_mismatch = sum(
        1 for d in all_differences.values() if d["shape_mismatch"]
    )

    summary = [
        "\nSUMMARY OF ALL DIFFERENCES",
        "=" * 80,
        f"Total files compared: {len(all_differences)}",
        f"Files with shape mismatches: {files_with_shape_mismatch}",
        f"Files with value mismatches: {len(value_mismatched)}",
    ]
    if value_mismatched:
        summary.append("\nFiles with differences:")
        for name, diff in value_mismatched.items():
            summary.append(f" {name}:")
            summary.append(f" Mismatches: {diff['num_mismatches']}")
            summary.append(f" Mean difference: {diff['mean_difference']:.2e}")
            if diff.get("all_new_mismatches_zero", False):
                summary.append(" All mismatches are zeros in new file")
    return "\n".join(summary)
def main():
    """Compare every ``.npy`` file present in both projection directories.

    Reports file stems that exist in only one directory, prints a detailed
    comparison for each common file, then prints a summary. Files that fail
    the header probe or raise during loading/comparison are reported and
    skipped.
    """
    # Sorted so that output order does not depend on filesystem listing order.
    old_files = sorted(old_projections_path.glob("*.npy"))
    new_files = sorted(new_projections_path.glob("*.npy"))

    # Check for missing files (printed sorted: set order varies between runs)
    old_names = {f.stem for f in old_files}
    new_names = {f.stem for f in new_files}
    if missing := (old_names - new_names):
        print(f"Files missing in new directory: {sorted(missing)}")
    if missing := (new_names - old_names):
        print(f"Files missing in old directory: {sorted(missing)}")

    # Compare common files
    all_differences = {}
    for old_file in old_files:
        stem = old_file.stem
        if stem not in new_names:
            continue
        print(f"\nComparing {stem}...")
        new_file = new_projections_path / f"{stem}.npy"

        # Probe both headers before loading; each probe prints its own error.
        old_info = get_file_info(old_file)
        new_info = get_file_info(new_file)
        if not old_info or not new_info:
            print(f"Skipping {stem} due to file reading errors")
            continue

        try:
            old_arr = np.load(old_file)
            new_arr = np.load(new_file)
            differences = compare_arrays(old_arr, new_arr)
            print_detailed_comparison(stem, differences)
        except Exception as e:  # top-level boundary: report and move on
            print(f"Error processing {stem}: {e}")
            continue
        all_differences[stem] = differences

    # Print summary
    if all_differences:
        print(summarize_differences(all_differences))
    else:
        print("\nNo valid comparisons were made.")
# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()