fixed plots

This commit is contained in:
Jan Kowalczyk
2025-10-21 19:04:19 +02:00
parent 8f983b890f
commit 7b5accb6c5
25 changed files with 1917 additions and 165 deletions

View File

@@ -12,7 +12,7 @@ import numpy as np
import polars as pl
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from plot_scripts.load_results import load_pretraining_results_dataframe
from load_results import load_pretraining_results_dataframe
# ----------------------------
# Config
@@ -78,8 +78,8 @@ def build_arch_curves_from_df(
"overall": (dims, means, stds),
} }
"""
if "split" not in df.columns:
raise ValueError("Expected 'split' column in AE dataframe.")
# if "split" not in df.columns:
# raise ValueError("Expected 'split' column in AE dataframe.")
if "scores" not in df.columns:
raise ValueError("Expected 'scores' column in AE dataframe.")
if "network" not in df.columns or "latent_dim" not in df.columns:
@@ -88,7 +88,7 @@ def build_arch_curves_from_df(
raise ValueError(f"Expected '{label_field}' column in AE dataframe.")
# Keep only test split
df = df.filter(pl.col("split") == "test")
# df = df.filter(pl.col("split") == "test")
groups: dict[tuple[str, int], dict[str, list[float]]] = {}
@@ -201,7 +201,7 @@ def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
plt.xlabel("Latent Dimensionality")
plt.ylabel("Test Loss")
plt.title(title)
# plt.title(title)
plt.legend()
plt.grid(True, alpha=0.3)
plt.xticks(all_dims)