reworked results chpt

This commit is contained in:
Jan Kowalczyk
2025-09-27 19:01:59 +02:00
parent c270783225
commit e00d1a33e3
4 changed files with 36 additions and 52 deletions

View File

@@ -143,6 +143,11 @@ def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float,
return (float(y0), float(y1))
def _get_dim_mapping(dims: list[int]) -> dict[int, int]:
"""Map actual dimensions to evenly spaced positions (0, 1, 2, ...)"""
return {dim: i for i, dim in enumerate(dims)}
def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
fig, axes = plt.subplots(
len(SEMI_LABELING_REGIMES),
@@ -155,6 +160,9 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
if len(SEMI_LABELING_REGIMES) == 1:
axes = [axes]
# Create dimension mapping
dim_mapping = _get_dim_mapping(LATENT_DIMS)
for ax, regime in zip(axes, SEMI_LABELING_REGIMES):
semi_n, semi_a = regime
data = {}
@@ -163,7 +171,9 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
for dim in LATENT_DIMS:
key = (ev, net, dim, semi_n, semi_a)
if key in agg:
xs.append(dim)
xs.append(
dim_mapping[dim]
) # Use mapped position instead of actual dim
ys.append(agg[key].mean)
es.append(agg[key].std)
data[net] = (xs, ys, es)
@@ -172,12 +182,16 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
xs, ys, es = data[net]
if not xs:
continue
ax.set_xticks(LATENT_DIMS)
ax.yaxis.set_major_locator(MaxNLocator(nbins=5)) # e.g., always 5 ticks
# Set evenly spaced ticks with actual dimension labels
ax.set_xticks(list(dim_mapping.values()))
ax.set_xticklabels(LATENT_DIMS)
ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
ax.scatter(
xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)"
)
x_fit, y_fit = _lin_trend(xs, ys)
x_fit, y_fit = _lin_trend(xs, ys) # Now using mapped positions
ax.plot(
x_fit,
y_fit,