fixed plots
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
||||
import polars as pl
|
||||
|
||||
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
||||
from plot_scripts.load_results import load_pretraining_results_dataframe
|
||||
from load_results import load_pretraining_results_dataframe
|
||||
|
||||
# ----------------------------
|
||||
# Config
|
||||
@@ -78,8 +78,8 @@ def build_arch_curves_from_df(
|
||||
"overall": (dims, means, stds),
|
||||
} }
|
||||
"""
|
||||
if "split" not in df.columns:
|
||||
raise ValueError("Expected 'split' column in AE dataframe.")
|
||||
# if "split" not in df.columns:
|
||||
# raise ValueError("Expected 'split' column in AE dataframe.")
|
||||
if "scores" not in df.columns:
|
||||
raise ValueError("Expected 'scores' column in AE dataframe.")
|
||||
if "network" not in df.columns or "latent_dim" not in df.columns:
|
||||
@@ -88,7 +88,7 @@ def build_arch_curves_from_df(
|
||||
raise ValueError(f"Expected '{label_field}' column in AE dataframe.")
|
||||
|
||||
# Keep only test split
|
||||
df = df.filter(pl.col("split") == "test")
|
||||
# df = df.filter(pl.col("split") == "test")
|
||||
|
||||
groups: dict[tuple[str, int], dict[str, list[float]]] = {}
|
||||
|
||||
@@ -201,7 +201,7 @@ def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
|
||||
|
||||
plt.xlabel("Latent Dimensionality")
|
||||
plt.ylabel("Test Loss")
|
||||
plt.title(title)
|
||||
# plt.title(title)
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.xticks(all_dims)
|
||||
|
||||
Reference in New Issue
Block a user