# File: mt/tools/print_results_structure.py
# (Repository-viewer header: 275 lines, 8.3 KiB, Python)

from __future__ import annotations
import json
import pickle
from pathlib import Path
from typing import Any, Dict, List, Set
import numpy as np
# --- CONFIG ---
# Directory whose immediate subdirectories are scanned as experiment folders.
ROOT = Path("/home/fedex/mt/results/done")  # <- adjust if needed
# Model names used to build the per-experiment pickle name results_<model>_0.pkl.
# MODELS = ["deepsad", "isoforest", "ocsvm"]
MODELS = ["ae"]
# How much to show for very large collections
MAX_KEYS = 100  # show up to this many dict keys explicitly
MAX_GROUPS = 10  # distinct element-structure groups to print for a sequence
SAMPLE_PER_GROUP = 1  # recurse into this many representative elements per group
# String added to the indentation for each nesting level of the printed outline.
INDENT = " "
# ---------- Signature helpers ----------
def one_line_sig(obj: Any) -> str:
    """Return a one-line structural signature used for grouping and summaries.

    NumPy arrays report shape and dtype, dicts and list/tuple sequences report
    their length, and everything else (Python and NumPy scalars, strings,
    ``None``, arbitrary objects) is summarised by its bare type name.
    """
    type_name = type(obj).__name__

    if isinstance(obj, np.ndarray):
        return f"ndarray(shape={tuple(obj.shape)}, dtype={obj.dtype})"

    if isinstance(obj, dict):
        # Values are deliberately not expanded here; only the key count is shown.
        return f"dict(len={len(obj)})"

    if isinstance(obj, (list, tuple)):
        return f"{type_name}(len={len(obj)})"

    # Scalars, strings, None and any unrecognised type: just the type name.
    return type_name
def is_atomic(obj: Any) -> bool:
    """Tell whether ``obj`` is a leaf that is reported but never recursed into."""
    leaf_types = (int, float, bool, str, type(None), np.generic, np.ndarray)
    return isinstance(obj, leaf_types)
# ---------- Recursive pretty-printer ----------
def describe(obj: Any, path: str = "", indent: str = "", seen: Set[int] | None = None):
    """Print an indented, recursive outline of the structure of ``obj``.

    Each visited object produces one line ``<indent><path>: <signature>`` (the
    ``<path>: `` part is omitted when ``path`` is empty).

    * Atomic values (see ``is_atomic``) are printed and not expanded.
    * Dicts recurse into the values of their first ``MAX_KEYS`` keys; any
      remaining keys are summarised in a trailing ``... (+N more keys)`` line.
    * Lists and tuples are grouped by the one-line signature of their elements;
      at most ``MAX_GROUPS`` groups are shown and the first
      ``SAMPLE_PER_GROUP`` elements of each group are described.

    Args:
        obj: Object to describe.
        path: Dotted / indexed path of ``obj`` relative to the root.
        indent: Whitespace prefix for the lines printed at this level.
        seen: ``id()`` values of the containers currently being expanded on the
            way down to ``obj``. Callers normally leave this as ``None``.

    Only a container that is reached again while it is still being expanded is
    reported as ``<CYCLE ...>``. An object that merely appears at several
    places (shared sub-structures, or cached values such as small ints,
    ``None`` and interned strings whose ``id()`` repeats) is described at every
    occurrence.
    """
    if seen is None:
        seen = set()

    label = f"{indent}{path}: " if path else indent

    # Leaves cannot take part in a cycle, so they bypass the cycle guard.
    # Tracking their ids would mislabel repeated values as cycles.
    if is_atomic(obj):
        print(f"{label}{one_line_sig(obj)}")
        return

    oid = id(obj)
    if oid in seen:
        print(f"{indent}{path}: <CYCLE {one_line_sig(obj)}>")
        return

    print(f"{label}{one_line_sig(obj)}")

    child_indent = indent + INDENT
    seen.add(oid)
    try:
        # dict: print keys and recurse on each value
        if isinstance(obj, dict):
            keys = list(obj.keys())
            shown = keys[:MAX_KEYS]
            for k in shown:
                key_repr = k if isinstance(k, str) else repr(k)
                next_path = f"{path}.{key_repr}" if path else f"{key_repr}"
                try:
                    describe(obj[k], next_path, child_indent, seen)
                except Exception as e:
                    # Best effort: report the failure and keep going with the next key.
                    print(f"{child_indent}{next_path}: <ERROR {e}>")
            extra = len(keys) - len(shown)
            if extra > 0:
                print(f"{child_indent}... (+{extra} more keys)")
            return

        # sequence (list/tuple): group element indices by structural signature
        if isinstance(obj, (list, tuple)):
            groups: Dict[str, List[int]] = {}
            for i, el in enumerate(obj):
                groups.setdefault(one_line_sig(el), []).append(i)

            group_items = list(groups.items())
            if len(group_items) > MAX_GROUPS:
                group_items = group_items[:MAX_GROUPS]
                print(
                    f"{child_indent}... ({len(groups) - MAX_GROUPS} more groups hidden)"
                )

            for sig, idxs in group_items:
                print(
                    f"{child_indent}- elements with structure {sig}: count={len(idxs)}"
                )
                # Only the first few members of each group are expanded.
                for j in idxs[:SAMPLE_PER_GROUP]:
                    rep_path = f"{path}[{j}]" if path else f"[{j}]"
                    try:
                        describe(obj[j], rep_path, child_indent + INDENT, seen)
                    except Exception as e:
                        print(f"{child_indent}{INDENT}{rep_path}: <ERROR {e}>")
            return

        # Any other type: the header line is all that is reported.
    finally:
        # Leaving this container: it is no longer an ancestor of what follows.
        seen.discard(oid)
# ---------- Traverse & aggregate path->signature (optional summary) ----------
def traverse_paths(
    obj: Any, prefix: str, out: Dict[str, Set[str]], seen: Set[int] | None = None
) -> None:
    """Collect, for every structural path under ``obj``, the signatures seen there.

    ``out`` is updated in place: it maps a path string to the set of one-line
    signatures observed at that path. The root is stored under ``"<root>"``,
    dict values under ``<prefix>.<key>`` and sequence elements under
    ``<prefix>[]``. For a list or tuple only the first element of each distinct
    signature is descended into.

    Args:
        obj: Object to traverse.
        prefix: Path of ``obj`` relative to the root (``""`` for the root).
        out: Accumulator mapping path -> set of signatures; mutated in place.
        seen: ``id()`` values of the containers currently being expanded.
            Callers normally leave this as ``None``.

    The signature of ``obj`` is recorded before the cycle guard is consulted,
    and atomic values never enter the guard. Otherwise values whose ``id()``
    repeats (cached small ints, ``None``, interned strings, shared
    sub-structures) would drop every path after the first from the summary.
    """
    if seen is None:
        seen = set()

    out.setdefault(prefix or "<root>", set()).add(one_line_sig(obj))
    if is_atomic(obj):
        return

    oid = id(obj)
    if oid in seen:
        # True cycle: this container is already being expanded further up.
        return

    seen.add(oid)
    try:
        if isinstance(obj, dict):
            for k, v in obj.items():
                key = k if isinstance(k, str) else repr(k)
                child_path = f"{prefix}.{key}" if prefix else key
                traverse_paths(v, child_path, out, seen)
        elif isinstance(obj, (list, tuple)):
            # All elements share one path; descend only into the first element
            # of each distinct signature rather than into every element.
            child_path = f"{prefix}[]" if prefix else "[]"
            first_by_sig: Dict[str, Any] = {}
            for el in obj:
                first_by_sig.setdefault(one_line_sig(el), el)
            for el in first_by_sig.values():
                traverse_paths(el, child_path, out, seen)
    finally:
        seen.discard(oid)
# ---------- Per-pickle entry point ----------
def inspect_pickle_file(pkl_path: Path) -> Dict[str, Set[str]]:
    """Load one pickle, print its structure, and return its path->signature map.

    The returned dict maps each structural path found in the unpickled object
    to the set of one-line signatures observed at that path.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with pkl_path.open("rb") as handle:
        payload = pickle.load(handle)

    print(f"\n=== {pkl_path.name} ===")
    print("Top-level structure:")
    describe(payload)

    # Concise path -> signature summary, merged by the caller across files.
    summary: Dict[str, Set[str]] = {}
    traverse_paths(payload, "", summary)
    return summary
# ---------- Per-experiment ----------
def read_kfold_num(exp_dir: Path) -> int:
    """Return the k-fold count from ``exp_dir/config.json``, or 0 if unusable.

    A warning is printed and 0 is returned when the config file is missing,
    cannot be read or parsed, is not a JSON object, has a falsy ``k_fold``
    entry, or lacks a ``k_fold_num`` that converts to ``int``. Callers treat a
    non-positive result as "skip this experiment", so one bad config does not
    abort a scan over many experiment folders.

    Args:
        exp_dir: Experiment directory expected to contain ``config.json``.

    Returns:
        ``int(config["k_fold_num"])`` for a k-fold experiment, otherwise 0.
    """
    cfg = exp_dir / "config.json"
    if not cfg.exists():
        print(f"[warn] {exp_dir.name}: missing config.json")
        return 0

    try:
        with cfg.open("r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        print(f"[warn] {exp_dir.name}: unreadable config.json ({e})")
        return 0

    if not isinstance(config, dict) or not config.get("k_fold"):
        print(f"[warn] {exp_dir.name}: config says not k-fold")
        return 0

    try:
        return int(config["k_fold_num"])
    except (KeyError, TypeError, ValueError):
        print(f"[warn] {exp_dir.name}: missing or invalid k_fold_num")
        return 0
def inspect_experiment_folder(exp_dir: Path, models: List[str]) -> Dict[str, Set[str]]:
    """Inspect the fold-0 result pickle of each model in one experiment folder.

    Returns the merged path->signature map over all pickles that could be read,
    or an empty dict when the folder's config does not report a positive
    k-fold count.
    """
    if read_kfold_num(exp_dir) <= 0:
        return {}

    merged: Dict[str, Set[str]] = {}
    for model_name in models:
        # Only fold 0 is inspected, by design.
        pkl_file = exp_dir / f"results_{model_name}_0.pkl"
        if pkl_file.exists():
            try:
                for sig_path, sig_set in inspect_pickle_file(pkl_file).items():
                    merged.setdefault(sig_path, set()).update(sig_set)
            except Exception as err:
                # A pickle that fails to load or print must not stop the others.
                print(f"[error] Failed reading {pkl_file}: {err}")
        else:
            print(f"[warn] Missing {pkl_file.name} in {exp_dir.name}")
    return merged
# ---------- Driver ----------
def main(root: Path, models: List[str]):
    """Scan experiment folders under ``root`` and print a global structure summary.

    Folders are visited in sorted order. Once a folder whose lower-cased name
    contains "efficient" (respectively "lenet") has yielded a non-empty result,
    further folders of that family are skipped; the scan stops early when both
    families are covered.
    """
    if not root.exists():
        print(f"[error] ROOT not found: {root.resolve()}")
        return

    exp_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    if not exp_dirs:
        print(f"[warn] No experiment subdirectories under {root}")
        return
    print(f"Found {len(exp_dirs)} experiment dirs under {root}\n")

    families = ("efficient", "lenet")
    finished: Set[str] = set()  # families that already produced a result
    global_agg: Dict[str, Set[str]] = {}

    for exp in exp_dirs:
        if len(finished) == len(families):
            print("\nBoth efficient and lenet done, stopping early.")
            break

        lowname = exp.name.lower()
        if any(family in lowname for family in finished):
            continue

        print(f"\n================ EXPERIMENT: {exp.name} ================")
        agg = inspect_experiment_folder(exp, models)
        if not agg:
            continue

        finished.update(family for family in families if family in lowname)
        for sig_path, sig_set in agg.items():
            global_agg.setdefault(sig_path, set()).update(sig_set)

    print("\n\n================ GLOBAL STRUCTURE SUMMARY ================")
    for sig_path in sorted(global_agg):
        print(f"\n{sig_path}:")
        for sig in sorted(global_agg[sig_path]):
            print(f" - {sig}")
    print("\nDone.")
# Script entry point: run the scan with the module-level ROOT and MODELS.
if __name__ == "__main__":
    main(ROOT, MODELS)