Compare commits: 33de01b150 ... ed80faf1e2 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | ed80faf1e2 |  |
|  | 3d968c305c |  |
tools/demo_loaded_data.py (new file, 71 lines)

@@ -0,0 +1,71 @@
```python
from pathlib import Path

import polars as pl

from load_results import load_pretraining_results_dataframe, load_results_dataframe


# ------------------------------------------------------------
# Example “analysis-ready” queries (Polars idioms)
# ------------------------------------------------------------
def demo_queries(df: pl.DataFrame):
    # q1: lazy is fine, then collect
    q1 = (
        df.lazy()
        .filter(
            (pl.col("network") == "LeNet")
            & (pl.col("latent_dim") == 1024)
            & (pl.col("semi_normals") == 0)
            & (pl.col("semi_anomalous") == 0)
            & (pl.col("eval") == "exp_based")
        )
        .group_by(["model"])
        .agg(pl.col("auc").mean().alias("mean_auc"))
        .sort(["mean_auc"], descending=True)
        .collect()
    )

    # q2: do the filtering eagerly, then pivot (LazyFrame has no .pivot)
    base = df.filter(
        (pl.col("model") == "deepsad")
        & (pl.col("eval") == "exp_based")
        & (pl.col("network") == "LeNet")
        & (pl.col("semi_normals") == 0)
        & (pl.col("semi_anomalous") == 0)
    ).select("fold", "latent_dim", "auc")
    q2 = base.pivot(
        values="auc",
        index="fold",
        columns="latent_dim",
        aggregate_function="first",  # or "mean" if duplicates exist
    ).sort("fold")

    # roc_subset: eager filter/select, then explode struct fields
    roc_subset = (
        df.filter(
            (pl.col("model") == "ocsvm")
            & (pl.col("eval") == "manual_based")
            & (pl.col("network") == "efficient")
            & (pl.col("latent_dim") == 1024)
            & (pl.col("semi_normals") == 0)
            & (pl.col("semi_anomalous") == 0)
        )
        .select("fold", "roc_curve")
        .with_columns(
            pl.col("roc_curve").struct.field("fpr").alias("fpr"),
            pl.col("roc_curve").struct.field("tpr").alias("tpr"),
            pl.col("roc_curve").struct.field("thr").alias("thr"),
        )
    )

    return q1, q2, roc_subset


def main():
    root = Path("/home/fedex/mt/results/done")
    df = load_results_dataframe(root, allow_cache=True)
    demo_queries(df)


if __name__ == "__main__":
    main()
```
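For reference, a minimal sketch (not part of this commit) of how the `roc_subset` frame returned by `demo_queries` could be turned into per-fold ROC plots; it assumes matplotlib is available and that the two scripts above are importable:

```python
from pathlib import Path

import matplotlib.pyplot as plt

from demo_loaded_data import demo_queries
from load_results import load_results_dataframe

df = load_results_dataframe(Path("/home/fedex/mt/results/done"), allow_cache=True)
_, _, roc_subset = demo_queries(df)

# fpr/tpr are list columns, so iterate row-wise and draw one curve per fold
for row in roc_subset.iter_rows(named=True):
    plt.plot(row["fpr"], row["tpr"], label=f"fold {row['fold']}")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend()
plt.show()
```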
tools/devenv.lock (new file, 103 lines)

@@ -0,0 +1,103 @@
```json
{
  "nodes": {
    "devenv": {
      "locked": {
        "dir": "src/modules",
        "lastModified": 1754730435,
        "owner": "cachix",
        "repo": "devenv",
        "rev": "d1388a093a7225c2abe8c244109c5a4490de4077",
        "type": "github"
      },
      "original": {
        "dir": "src/modules",
        "owner": "cachix",
        "repo": "devenv",
        "type": "github"
      }
    },
    "flake-compat": {
      "flake": false,
      "locked": {
        "lastModified": 1747046372,
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "git-hooks": {
      "inputs": {
        "flake-compat": "flake-compat",
        "gitignore": "gitignore",
        "nixpkgs": [
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1754416808,
        "owner": "cachix",
        "repo": "git-hooks.nix",
        "rev": "9c52372878df6911f9afc1e2a1391f55e4dfc864",
        "type": "github"
      },
      "original": {
        "owner": "cachix",
        "repo": "git-hooks.nix",
        "type": "github"
      }
    },
    "gitignore": {
      "inputs": {
        "nixpkgs": [
          "git-hooks",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1709087332,
        "owner": "hercules-ci",
        "repo": "gitignore.nix",
        "rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
        "type": "github"
      },
      "original": {
        "owner": "hercules-ci",
        "repo": "gitignore.nix",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1754299112,
        "owner": "cachix",
        "repo": "devenv-nixpkgs",
        "rev": "16c21c9f5c6fb978466e91182a248dd8ca1112ac",
        "type": "github"
      },
      "original": {
        "owner": "cachix",
        "ref": "rolling",
        "repo": "devenv-nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "devenv": "devenv",
        "git-hooks": "git-hooks",
        "nixpkgs": "nixpkgs",
        "pre-commit-hooks": [
          "git-hooks"
        ]
      }
    }
  },
  "root": "root",
  "version": 7
}
```
tools/devenv.nix (new file, 27 lines)

@@ -0,0 +1,27 @@
```nix
{ pkgs, ... }:
let
  native_dependencies = with pkgs.python312Packages; [
    torch-bin
    torchvision-bin
    aggdraw # for visualtorch
    numpy
    scipy
    matplotlib
  ];
  tools = with pkgs; [
    ruff
  ];
in
{
  packages = native_dependencies ++ tools;
  languages.python = {
    enable = true;
    package = pkgs.python312;
    uv = {
      enable = true;
      sync.enable = true;
    };
    venv.enable = true;
  };

}
```
tools/devenv.yaml (new file, 17 lines)

@@ -0,0 +1,17 @@
```yaml
# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json
inputs:
  nixpkgs:
    url: github:cachix/devenv-nixpkgs/rolling

allowUnfree: true
cudaSupport: true
# If you're using non-OSS software, you can set allowUnfree to true.
# allowUnfree: true

# If you're willing to use a package that's vulnerable
# permittedInsecurePackages:
#  - "openssl-1.1.1w"

# If you have more than one devenv you can merge them
#imports:
#  - ./backend
```
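For context (not part of the diff): with the devenv CLI installed, running `devenv shell` from the `tools/` directory should build the pinned Python 3.12 environment declared above (CUDA-enabled torch-bin/torchvision-bin, uv-synced project dependencies, ruff) and enter it; `devenv.lock` pins the inputs declared in `devenv.yaml`.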
tools/flake.lock (generated, deleted, 192 lines)

@@ -1,192 +0,0 @@
```json
{
  "nodes": {
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1710146030,
        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1710146030,
        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "nix-github-actions": {
      "inputs": {
        "nixpkgs": [
          "poetry2nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1703863825,
        "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
        "owner": "nix-community",
        "repo": "nix-github-actions",
        "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
        "type": "github"
      },
      "original": {
        "owner": "nix-community",
        "repo": "nix-github-actions",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1718676766,
        "narHash": "sha256-0se0JqeNSZcNmqhsHMN9N4cVV/XkPhtSVJwhLs2RGUg=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "31e107dc564e53cf2843bedf6a8b85faa2f845e3",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable-small",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "nixpkgs-newest": {
      "locked": {
        "lastModified": 1745391562,
        "narHash": "sha256-sPwcCYuiEopaafePqlG826tBhctuJsLx/mhKKM5Fmjo=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "8a2f738d9d1f1d986b5a4cd2fd2061a7127237d7",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "poetry2nix": {
      "inputs": {
        "flake-utils": "flake-utils_2",
        "nix-github-actions": "nix-github-actions",
        "nixpkgs": [
          "nixpkgs"
        ],
        "systems": "systems_3",
        "treefmt-nix": "treefmt-nix"
      },
      "locked": {
        "lastModified": 1718656656,
        "narHash": "sha256-/8pXTFOfb7+KrFi+g8G/dFehDkc96/O5eL8L+FjzG1w=",
        "owner": "nix-community",
        "repo": "poetry2nix",
        "rev": "2c6d07717af20e45fa5b2c823729126be91a3cdf",
        "type": "github"
      },
      "original": {
        "owner": "nix-community",
        "repo": "poetry2nix",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "flake-utils": "flake-utils",
        "nixpkgs": "nixpkgs",
        "nixpkgs-newest": "nixpkgs-newest",
        "poetry2nix": "poetry2nix"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_3": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "id": "systems",
        "type": "indirect"
      }
    },
    "treefmt-nix": {
      "inputs": {
        "nixpkgs": [
          "poetry2nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1718522839,
        "narHash": "sha256-ULzoKzEaBOiLRtjeY3YoGFJMwWSKRYOic6VNw2UyTls=",
        "owner": "numtide",
        "repo": "treefmt-nix",
        "rev": "68eb1dc333ce82d0ab0c0357363ea17c31ea1f81",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "treefmt-nix",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
```
tools/flake.nix (deleted, 70 lines; file header not captured in the compare view)

@@ -1,70 +0,0 @@
```nix
{
  description = "Application packaged using poetry2nix";

  inputs = {
    flake-utils.url = "github:numtide/flake-utils";
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
    # Added newest nixpkgs for an updated poetry package.
    nixpkgs-newest.url = "github:NixOS/nixpkgs/nixos-unstable";
    poetry2nix = {
      url = "github:nix-community/poetry2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      poetry2nix,
      nixpkgs-newest,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = nixpkgs.legacyPackages.${system};
        # Use the newest nixpkgs exclusively for the poetry package.
        pkgsNew = nixpkgs-newest.legacyPackages.${system};
        inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication defaultPoetryOverrides;
        inherit poetry2nix;
      in
      {
        packages = {
          myapp = mkPoetryApplication {
            projectDir = self;
            preferWheels = true;
            overrides = defaultPoetryOverrides.extend (
              self: super: {
                umap = super.umap.overridePythonAttrs (old: {
                  buildInputs = (old.buildInputs or [ ]) ++ [ super.setuptools ];
                });
              }
            );
          };
          default = self.packages.${system}.myapp;
        };

        # Shell for app dependencies.
        #
        #     nix develop
        #
        # Use this shell for developing your app.
        devShells.default = pkgs.mkShell {
          inputsFrom = [
            self.packages.${system}.myapp
          ];
        };

        # Shell for poetry.
        #
        #     nix develop .#poetry
        #
        # Here we use the poetry package from the newest nixpkgs input while keeping
        # all other dependencies locked.
        devShells.poetry = pkgs.mkShell {
          packages = [ pkgsNew.poetry ];
        };
      }
    );
}
```
tools/load_results.py (new file, 597 lines)

@@ -0,0 +1,597 @@
```python
from __future__ import annotations

import json
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import polars as pl

# ------------------------------------------------------------
# Config you can tweak
# ------------------------------------------------------------
MODELS = ["deepsad", "isoforest", "ocsvm"]
EVALS = ["exp_based", "manual_based"]

SCHEMA_STATIC = {
    # identifiers / dims
    "network": pl.Utf8,  # e.g. "LeNet", "efficient"
    "latent_dim": pl.Int32,
    "semi_normals": pl.Int32,
    "semi_anomalous": pl.Int32,
    "model": pl.Utf8,  # "deepsad" | "isoforest" | "ocsvm"
    "eval": pl.Utf8,  # "exp_based" | "manual_based"
    "fold": pl.Int32,
    # metrics
    "auc": pl.Float64,
    "ap": pl.Float64,
    # per-sample scores: list of (idx, label, score)
    "scores": pl.List(
        pl.Struct(
            {
                "sample_idx": pl.Int32,  # dataloader idx
                "orig_label": pl.Int8,  # {-1,0,1}
                "score": pl.Float64,  # anomaly score
            }
        )
    ),
    # curves (normalized)
    "roc_curve": pl.Struct(
        {
            "fpr": pl.List(pl.Float64),
            "tpr": pl.List(pl.Float64),
            "thr": pl.List(pl.Float64),
        }
    ),
    "prc_curve": pl.Struct(
        {
            "precision": pl.List(pl.Float64),
            "recall": pl.List(pl.Float64),
            "thr": pl.List(pl.Float64),  # may be len(precision)-1
        }
    ),
    # deepsad-only per-eval arrays (None for other models)
    "sample_indices": pl.List(pl.Int32),
    "sample_labels": pl.List(pl.Int8),
    "valid_mask": pl.List(pl.Boolean),
    # timings / housekeeping
    "train_time": pl.Float64,
    "test_time": pl.Float64,
    "folder": pl.Utf8,
    "k_fold_num": pl.Int32,
    "config_json": pl.Utf8,  # full config.json as string (for reference)
}

# Pretraining-only (AE) schema — lighter defaults
PRETRAIN_SCHEMA = {
    # identifiers / dims
    "network": pl.Utf8,  # e.g. "LeNet", "efficient"
    "latent_dim": pl.Int32,
    "semi_normals": pl.Int32,
    "semi_anomalous": pl.Int32,
    "model": pl.Utf8,  # always "ae"
    "fold": pl.Int32,
    "split": pl.Utf8,  # "train" | "test"
    # timings and optimization
    "time": pl.Float64,
    "loss": pl.Float64,
    # per-sample arrays (as lists)
    "indices": pl.List(pl.Int32),
    "labels_exp_based": pl.List(pl.Int32),
    "labels_manual_based": pl.List(pl.Int32),
    "semi_targets": pl.List(pl.Int32),
    "file_ids": pl.List(pl.Int32),
    "frame_ids": pl.List(pl.Int32),
    "scores": pl.List(pl.Float32),  # use Float32 to match source and save space
    # file id -> name mapping from the result dict
    "file_names": pl.List(pl.Struct({"file_id": pl.Int32, "name": pl.Utf8})),
    # housekeeping
    "folder": pl.Utf8,
    "k_fold_num": pl.Int32,
    "config_json": pl.Utf8,  # full config.json as string (for reference)
}


# ------------------------------------------------------------
# Helpers: curve/scores normalizers (tuples/ndarrays -> dict/list)
# ------------------------------------------------------------
def _tolist(x):
    if x is None:
        return None
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, (list, tuple)):
        return list(x)
    # best-effort scalar wrap
    try:
        return [x]
    except Exception:
        return None


def normalize_float_list(a) -> Optional[List[float]]:
    if a is None:
        return None
    if isinstance(a, np.ndarray):
        a = a.tolist()
    return [None if x is None else float(x) for x in a]


def normalize_file_names(d) -> Optional[List[dict]]:
    """
    Convert the 'file_names' dict (keys like numpy.int64 -> str) to a
    list[ {file_id:int, name:str} ], sorted by file_id.
    """
    if not isinstance(d, dict):
        return None
    out: List[dict] = []
    for k, v in d.items():
        try:
            file_id = int(k)
        except Exception:
            # keys are printed as np.int64 in the structure; best-effort cast
            continue
        out.append({"file_id": file_id, "name": str(v)})
    out.sort(key=lambda x: x["file_id"])
    return out


def normalize_roc(obj: Any) -> Optional[dict]:
    if obj is None:
        return None
    fpr = tpr = thr = None
    if isinstance(obj, (tuple, list)):
        if len(obj) >= 2:
            fpr, tpr = _tolist(obj[0]), _tolist(obj[1])
        if len(obj) >= 3:
            thr = _tolist(obj[2])
    elif isinstance(obj, dict):
        fpr = _tolist(obj.get("fpr") or obj.get("x"))
        tpr = _tolist(obj.get("tpr") or obj.get("y"))
        thr = _tolist(obj.get("thr") or obj.get("thresholds"))
    else:
        return None
    if fpr is None or tpr is None:
        return None
    return {"fpr": fpr, "tpr": tpr, "thr": thr}


def normalize_prc(obj: Any) -> Optional[dict]:
    if obj is None:
        return None
    precision = recall = thr = None
    if isinstance(obj, (tuple, list)):
        if len(obj) >= 2:
            precision, recall = _tolist(obj[0]), _tolist(obj[1])
        if len(obj) >= 3:
            thr = _tolist(obj[2])
    elif isinstance(obj, dict):
        precision = _tolist(obj.get("precision") or obj.get("y"))
        recall = _tolist(obj.get("recall") or obj.get("x"))
        thr = _tolist(obj.get("thr") or obj.get("thresholds"))
    else:
        return None
    if precision is None or recall is None:
        return None
    return {"precision": precision, "recall": recall, "thr": thr}


def normalize_scores_to_struct(seq) -> Optional[List[dict]]:
    """
    Input: list of (idx, label, score) tuples (as produced in your test()).
    Output: list of dicts with keys sample_idx, orig_label, score.
    """
    if seq is None:
        return None
    if isinstance(seq, np.ndarray):
        seq = seq.tolist()
    if not isinstance(seq, (list, tuple)):
        return None
    out: List[dict] = []
    for item in seq:
        if isinstance(item, (list, tuple)) and len(item) >= 3:
            idx, lab, sc = item[0], item[1], item[2]
            out.append(
                {
                    "sample_idx": None if idx is None else int(idx),
                    "orig_label": None if lab is None else int(lab),
                    "score": None if sc is None else float(sc),
                }
            )
        else:
            # fallback: single numeric -> score
            sc = (
                float(item)
                if isinstance(item, (int, float, np.integer, np.floating))
                else None
            )
            out.append({"sample_idx": None, "orig_label": None, "score": sc})
    return out


def normalize_int_list(a) -> Optional[List[int]]:
    if a is None:
        return None
    if isinstance(a, np.ndarray):
        a = a.tolist()
    return list(a)


def normalize_bool_list(a) -> Optional[List[bool]]:
    if a is None:
        return None
    if isinstance(a, np.ndarray):
        a = a.tolist()
    return [bool(x) for x in a]
```
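A small illustration (hypothetical values, not part of the commit) of what these normalizers return, assuming `load_results.py` is importable:

```python
import numpy as np

from load_results import normalize_roc, normalize_scores_to_struct

# a (fpr, tpr) tuple of ndarrays becomes a plain dict of lists; no thresholds -> thr is None
print(normalize_roc((np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.8, 1.0]))))
# {'fpr': [0.0, 0.5, 1.0], 'tpr': [0.0, 0.8, 1.0], 'thr': None}

# (idx, label, score) tuples become struct-shaped dicts matching SCHEMA_STATIC["scores"]
print(normalize_scores_to_struct([(7, -1, 0.93), (8, 1, 0.12)]))
# [{'sample_idx': 7, 'orig_label': -1, 'score': 0.93},
#  {'sample_idx': 8, 'orig_label': 1, 'score': 0.12}]
```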
```python
# ------------------------------------------------------------
# Low-level: read one experiment folder
# ------------------------------------------------------------
def read_config(exp_dir: Path) -> dict:
    cfg = exp_dir / "config.json"
    with cfg.open("r") as f:
        c = json.load(f)
    if not c.get("k_fold"):
        raise ValueError(f"{exp_dir.name}: not trained as k-fold")
    return c


def read_pickle(p: Path) -> Any:
    with p.open("rb") as f:
        return pickle.load(f)


# ------------------------------------------------------------
# Extractors for each model
# ------------------------------------------------------------
def rows_from_deepsad(data: dict, evals: List[str]) -> Dict[str, dict]:
    """
    deepsad under data['test'][eval], with extra per-eval arrays and AP present.
    """
    out: Dict[str, dict] = {}
    test = data.get("test", {})
    for ev in evals:
        evd = test.get(ev)
        if not isinstance(evd, dict):
            continue
        out[ev] = {
            "auc": float(evd["auc"])
            if "auc" in evd and evd["auc"] is not None
            else None,
            "roc": normalize_roc(evd.get("roc")),
            "prc": normalize_prc(evd.get("prc")),
            "ap": float(evd["ap"]) if "ap" in evd and evd["ap"] is not None else None,
            "scores": normalize_scores_to_struct(evd.get("scores")),
            "sample_indices": normalize_int_list(evd.get("indices")),
            "sample_labels": normalize_int_list(evd.get("labels")),
            "valid_mask": normalize_bool_list(evd.get("valid_mask")),
            "train_time": data.get("train", {}).get("time"),
            "test_time": test.get("time"),
        }
    return out


def rows_from_isoforest(data: dict, evals: List[str]) -> Dict[str, dict]:
    """
    Keys: test_auc_<eval>, test_roc_<eval>, test_prc_<eval>, test_ap_<eval>, test_scores_<eval>.
    """
    out: Dict[str, dict] = {}
    for ev in evals:
        auc = data.get(f"test_auc_{ev}")
        if auc is None:
            continue
        out[ev] = {
            "auc": float(auc),
            "roc": normalize_roc(data.get(f"test_roc_{ev}")),
            "prc": normalize_prc(data.get(f"test_prc_{ev}")),
            "ap": float(data.get(f"test_ap_{ev}"))
            if data.get(f"test_ap_{ev}") is not None
            else None,
            "scores": normalize_scores_to_struct(data.get(f"test_scores_{ev}")),
            "sample_indices": None,
            "sample_labels": None,
            "valid_mask": None,
            "train_time": data.get("train_time"),
            "test_time": data.get("test_time"),
        }
    return out


def rows_from_ocsvm_default(data: dict, evals: List[str]) -> Dict[str, dict]:
    """
    Default OCSVM only (ignore linear variant entirely).
    """
    out: Dict[str, dict] = {}
    for ev in evals:
        auc = data.get(f"test_auc_{ev}")
        if auc is None:
            continue
        out[ev] = {
            "auc": float(auc),
            "roc": normalize_roc(data.get(f"test_roc_{ev}")),
            "prc": normalize_prc(data.get(f"test_prc_{ev}")),
            "ap": float(data.get(f"test_ap_{ev}"))
            if data.get(f"test_ap_{ev}") is not None
            else None,
            "scores": normalize_scores_to_struct(data.get(f"test_scores_{ev}")),
            "sample_indices": None,
            "sample_labels": None,
            "valid_mask": None,
            "train_time": data.get("train_time"),
            "test_time": data.get("test_time"),
        }
    return out


# ------------------------------------------------------------
# Build the Polars DataFrame
# ------------------------------------------------------------
def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame:
    """
    Walks experiment subdirs under `root`. For each (model, fold) it adds rows:

    Columns (SCHEMA_STATIC):
        network, latent_dim, semi_normals, semi_anomalous,
        model, eval, fold,
        auc, ap, scores{sample_idx,orig_label,score},
        roc_curve{fpr,tpr,thr}, prc_curve{precision,recall,thr},
        sample_indices, sample_labels, valid_mask,
        train_time, test_time,
        folder, k_fold_num
    """
    if allow_cache:
        cache = root / "results_cache.parquet"
        if cache.exists():
            try:
                df = pl.read_parquet(cache)
                print(f"[info] loaded cached results frame from {cache}")
                return df
            except Exception as e:
                print(f"[warn] failed to load cache {cache}: {e}")

    rows: List[dict] = []

    exp_dirs = [p for p in root.iterdir() if p.is_dir()]
    for exp_dir in sorted(exp_dirs):
        try:
            cfg = read_config(exp_dir)
            cfg_json = json.dumps(cfg, sort_keys=True)
        except Exception as e:
            print(f"[warn] skipping {exp_dir.name}: {e}")
            continue

        network = cfg.get("net_name")
        latent_dim = int(cfg.get("latent_space_dim"))
        semi_normals = int(cfg.get("num_known_normal"))
        semi_anomalous = int(cfg.get("num_known_outlier"))
        k = int(cfg.get("k_fold_num"))

        for model in MODELS:
            for fold in range(k):
                pkl = exp_dir / f"results_{model}_{fold}.pkl"
                if not pkl.exists():
                    continue

                try:
                    data = read_pickle(pkl)
                except Exception as e:
                    print(f"[warn] failed to read {pkl.name}: {e}")
                    continue

                if model == "deepsad":
                    per_eval = rows_from_deepsad(data, EVALS)  # eval -> dict
                elif model == "isoforest":
                    per_eval = rows_from_isoforest(data, EVALS)  # eval -> dict
                elif model == "ocsvm":
                    per_eval = rows_from_ocsvm_default(data, EVALS)  # eval -> dict
                else:
                    per_eval = {}

                for ev, vals in per_eval.items():
                    rows.append(
                        {
                            "network": network,
                            "latent_dim": latent_dim,
                            "semi_normals": semi_normals,
                            "semi_anomalous": semi_anomalous,
                            "model": model,
                            "eval": ev,
                            "fold": fold,
                            "auc": vals["auc"],
                            "ap": vals["ap"],
                            "scores": vals["scores"],
                            "roc_curve": vals["roc"],
                            "prc_curve": vals["prc"],
                            "sample_indices": vals.get("sample_indices"),
                            "sample_labels": vals.get("sample_labels"),
                            "valid_mask": vals.get("valid_mask"),
                            "train_time": vals["train_time"],
                            "test_time": vals["test_time"],
                            "folder": str(exp_dir),
                            "k_fold_num": k,
                            "config_json": cfg_json,
                        }
                    )

    # If empty, return a typed empty frame
    if not rows:
        return pl.DataFrame(schema=SCHEMA_STATIC)

    df = pl.DataFrame(rows, schema=SCHEMA_STATIC)

    # Cast to efficient dtypes (categoricals etc.) – no extra sanitation
    df = df.with_columns(
        pl.col("network", "model", "eval").cast(pl.Categorical),
        pl.col(
            "latent_dim", "semi_normals", "semi_anomalous", "fold", "k_fold_num"
        ).cast(pl.Int32),
        pl.col("auc", "ap", "train_time", "test_time").cast(pl.Float64),
        # NOTE: no cast on 'scores' here; it's already List(Struct) per schema.
    )

    if allow_cache:
        try:
            df.write_parquet(cache)
            print(f"[info] cached results frame to {cache}")
        except Exception as e:
            print(f"[warn] failed to write cache {cache}: {e}")

    return df


def load_pretraining_results_dataframe(
    root: Path,
    allow_cache: bool = True,
    include_train: bool = False,  # <— default: store only TEST to keep cache tiny
    keep_file_names: bool = False,  # <— drop file_names by default; they’re repeated
    parquet_compression: str = "zstd",
    parquet_compression_level: int = 7,  # <— stronger compression than default
) -> pl.DataFrame:
    """
    Loads only AE pretraining results: files named `results_ae_<fold>.pkl`.
    Produces one row per (experiment, fold, split). By default we:
      - include only the TEST split (include_train=False)
      - store scores as Float32
      - drop the repeated file_names mapping to save space
      - write Parquet with zstd(level=7)
    """
    if allow_cache:
        cache = root / "pretraining_results_cache.parquet"
        if cache.exists():
            try:
                df = pl.read_parquet(cache)
                print(f"[info] loaded cached pretraining frame from {cache}")
                return df
            except Exception as e:
                print(f"[warn] failed to load pretraining cache {cache}: {e}")

    rows: List[dict] = []

    exp_dirs = [p for p in root.iterdir() if p.is_dir()]
    for exp_dir in sorted(exp_dirs):
        try:
            cfg = read_config(exp_dir)
            cfg_json = json.dumps(cfg, sort_keys=True)
        except Exception as e:
            print(f"[warn] skipping {exp_dir.name} (pretraining): {e}")
            continue

        network = cfg.get("net_name")
        latent_dim = int(cfg.get("latent_space_dim"))
        semi_normals = int(cfg.get("num_known_normal"))
        semi_anomalous = int(cfg.get("num_known_outlier"))
        k = int(cfg.get("k_fold_num"))

        # Only test split by default (include_train=False)
        splits = ("train", "test") if include_train else ("test",)

        for fold in range(k):
            pkl = exp_dir / f"results_ae_{fold}.pkl"
            if not pkl.exists():
                continue

            try:
                data = read_pickle(pkl)  # expected: {"train": {...}, "test": {...}}
            except Exception as e:
                print(f"[warn] failed to read {pkl.name}: {e}")
                continue

            for split in splits:
                splitd = data.get(split)
                if not isinstance(splitd, dict):
                    continue

                rows.append(
                    {
                        "network": network,
                        "latent_dim": latent_dim,
                        "semi_normals": semi_normals,
                        "semi_anomalous": semi_anomalous,
                        "model": "ae",
                        "fold": fold,
                        "split": split,
                        "time": float(splitd.get("time"))
                        if splitd.get("time") is not None
                        else None,
                        "loss": float(splitd.get("loss"))
                        if splitd.get("loss") is not None
                        else None,
                        # ints as Int32, scores as Float32 to save space
                        "indices": normalize_int_list(splitd.get("indices")),
                        "labels_exp_based": normalize_int_list(
                            splitd.get("labels_exp_based")
                        ),
                        "labels_manual_based": normalize_int_list(
                            splitd.get("labels_manual_based")
                        ),
                        "semi_targets": normalize_int_list(splitd.get("semi_targets")),
                        "file_ids": normalize_int_list(splitd.get("file_ids")),
                        "frame_ids": normalize_int_list(splitd.get("frame_ids")),
                        "scores": (
                            None
                            if splitd.get("scores") is None
                            else [
                                float(x)
                                for x in (
                                    splitd["scores"].tolist()
                                    if isinstance(splitd["scores"], np.ndarray)
                                    else splitd["scores"]
                                )
                            ]
                        ),
                        "file_names": normalize_file_names(splitd.get("file_names"))
                        if keep_file_names
                        else None,
                        "folder": str(exp_dir),
                        "k_fold_num": k,
                        "config_json": cfg_json,
                    }
                )

    if not rows:
        return pl.DataFrame(schema=PRETRAIN_SCHEMA)

    df = pl.DataFrame(rows, schema=PRETRAIN_SCHEMA)

    # Cast/optimize a bit (categoricals, ints, floats)
    df = df.with_columns(
        pl.col("network", "model", "split").cast(pl.Categorical),
        pl.col(
            "latent_dim", "semi_normals", "semi_anomalous", "fold", "k_fold_num"
        ).cast(pl.Int32),
        pl.col("time", "loss").cast(pl.Float64),
        pl.col("scores").cast(pl.List(pl.Float32)),  # ensure downcast took
    )

    if allow_cache:
        try:
            cache = root / "pretraining_results_cache.parquet"
            df.write_parquet(
                cache,
                compression=parquet_compression,
                compression_level=parquet_compression_level,
                statistics=True,
            )
            print(
                f"[info] cached pretraining frame to {cache} "
                f"({parquet_compression}, level={parquet_compression_level})"
            )
        except Exception as e:
            print(f"[warn] failed to write pretraining cache {cache}: {e}")

    return df


def main():
    root = Path("/home/fedex/mt/results/done")
    df = load_results_dataframe(root, allow_cache=True)
    print(df.shape, df.head())

    df_pre = load_pretraining_results_dataframe(root, allow_cache=True)
    print("pretraining:", df_pre.shape, df_pre.head())


if __name__ == "__main__":
    main()
```
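A minimal sketch (not part of the commit) of how the Parquet caches written by these loaders can be invalidated when new experiment folders land under the results root; the cache file names are the ones used above:

```python
from pathlib import Path

from load_results import load_pretraining_results_dataframe, load_results_dataframe

root = Path("/home/fedex/mt/results/done")

# drop stale caches so the loaders re-walk the experiment folders
(root / "results_cache.parquet").unlink(missing_ok=True)
(root / "pretraining_results_cache.parquet").unlink(missing_ok=True)

df = load_results_dataframe(root, allow_cache=True)  # rebuilds and re-caches
df_pre = load_pretraining_results_dataframe(root, allow_cache=True)
print(df.shape, df_pre.shape)
```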
@@ -1,118 +1,176 @@
|
|||||||
import pickle
|
# ae_elbow_from_df.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
import unittest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tabulate import tabulate
|
import polars as pl
|
||||||
|
|
||||||
# Configuration
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
||||||
results_folders = {
|
from load_results import load_pretraining_results_dataframe
|
||||||
"LeNet": {
|
|
||||||
"path": Path(
|
|
||||||
"/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/test/DeepSAD/subter_ae_elbow_v2/"
|
|
||||||
),
|
|
||||||
"batch_size": 256,
|
|
||||||
},
|
|
||||||
"LeNet Efficient": {
|
|
||||||
"path": Path(
|
|
||||||
"/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/test/DeepSAD/subter_efficient_ae_elbow"
|
|
||||||
),
|
|
||||||
"batch_size": 64,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
output_path = Path("/home/fedex/mt/plots/ae_elbow_lenet")
|
|
||||||
datetime_folder_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
||||||
|
|
||||||
latest_folder_path = output_path / "latest"
|
# ----------------------------
|
||||||
archive_folder_path = output_path / "archive"
|
# Config
|
||||||
output_datetime_path = output_path / datetime_folder_name
|
# ----------------------------
|
||||||
|
ROOT = Path("/home/fedex/mt/results/done") # experiments root you pass to the loader
|
||||||
|
OUTPUT_DIR = Path("/home/fedex/mt/plots/ae_elbow_lenet_from_df")
|
||||||
|
|
||||||
# Create output directories
|
# Which label field to use from the DF; "labels_exp_based" or "labels_manual_based"
|
||||||
output_path.mkdir(exist_ok=True, parents=True)
|
LABEL_FIELD = "labels_exp_based"
|
||||||
output_datetime_path.mkdir(exist_ok=True, parents=True)
|
|
||||||
latest_folder_path.mkdir(exist_ok=True, parents=True)
|
|
||||||
archive_folder_path.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_batch_mean_loss(scores, batch_size):
|
# ----------------------------
|
||||||
"""Calculate mean loss over batches similar to the original testing code."""
|
# Helpers
|
||||||
n_samples = len(scores)
|
# ----------------------------
|
||||||
n_batches = (n_samples + batch_size - 1) // batch_size
|
def canonicalize_network(name: str) -> str:
|
||||||
|
"""Map various net_name strings to clean labels for plotting."""
|
||||||
batch_losses = []
|
low = (name or "").lower()
|
||||||
for i in range(0, n_samples, batch_size):
|
if "lenet" in low:
|
||||||
batch_scores = scores[i : i + batch_size]
|
return "LeNet"
|
||||||
batch_losses.append(np.mean(batch_scores))
|
if "efficient" in low:
|
||||||
|
return "Efficient"
|
||||||
return np.sum(batch_losses) / n_batches
|
# fallback: show whatever was stored
|
||||||
|
return name or "unknown"
|
||||||
|
|
||||||
|
|
||||||
def test_loss_calculation(results, batch_size):
|
def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
|
||||||
"""Test if our loss calculation matches the original implementation."""
|
"""Mean of per-batch means (matches how the original test loss was computed)."""
|
||||||
test = unittest.TestCase()
|
n = len(scores)
|
||||||
folds = results["ae_results"]
|
if n == 0:
|
||||||
dim = results["dimension"]
|
return np.nan
|
||||||
|
if batch_size <= 0:
|
||||||
|
batch_size = n # single batch fallback
|
||||||
|
n_batches = (n + batch_size - 1) // batch_size
|
||||||
|
acc = 0.0
|
||||||
|
for i in range(0, n, batch_size):
|
||||||
|
acc += float(np.mean(scores[i : i + batch_size]))
|
||||||
|
return acc / n_batches
|
||||||
|
|
||||||
for fold_key in folds:
|
|
||||||
fold_data = folds[fold_key]["test"]
|
|
||||||
scores = np.array(fold_data["scores"])
|
|
||||||
original_loss = fold_data["loss"]
|
|
||||||
calculated_loss = calculate_batch_mean_loss(scores, batch_size)
|
|
||||||
|
|
||||||
|
def extract_batch_size(cfg_json: str) -> int:
|
||||||
|
"""
|
||||||
|
Prefer AE batch size; fall back to general batch_size; then a safe default.
|
||||||
|
We only rely on config_json (no lifted fields).
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
test.assertAlmostEqual(
|
cfg = json.loads(cfg_json) if cfg_json else {}
|
||||||
original_loss,
|
except Exception:
|
||||||
calculated_loss,
|
cfg = {}
|
||||||
places=5,
|
return int(cfg.get("ae_batch_size") or cfg.get("batch_size") or 256)
|
||||||
msg=f"Loss mismatch for dim={dim}, {fold_key}",
|
|
||||||
)
|
|
||||||
except AssertionError as e:
|
|
||||||
print(f"Warning: {str(e)}")
|
|
||||||
print(f"Original: {original_loss:.6f}, Calculated: {calculated_loss:.6f}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def plot_loss_curve(dims, means, stds, title, color, output_path):
|
def build_arch_curves_from_df(
|
||||||
"""Create and save a single loss curve plot."""
|
df: pl.DataFrame,
|
||||||
plt.figure(figsize=(8, 5))
|
label_field: str = "labels_exp_based",
|
||||||
plt.plot(dims, means, marker="o", color=color, label="Mean Test Loss")
|
only_nets: set[str] | None = None,
|
||||||
plt.fill_between(
|
):
|
||||||
dims,
|
"""
|
||||||
np.array(means) - np.array(stds),
|
From the AE pretraining DF, compute (dims, means, stds) for normal/anomaly/overall
|
||||||
np.array(means) + np.array(stds),
|
grouped by network and latent_dim. Returns:
|
||||||
color=color,
|
{ net_label: {
|
||||||
alpha=0.2,
|
"normal": (dims, means, stds),
|
||||||
label="Std Dev",
|
"anomaly": (dims, means, stds),
|
||||||
)
|
"overall": (dims, means, stds),
|
||||||
plt.xlabel("Latent Dimension")
|
} }
|
||||||
plt.ylabel("Test Loss")
|
"""
|
||||||
plt.title(title)
|
if "split" not in df.columns:
|
||||||
plt.legend()
|
raise ValueError("Expected 'split' column in AE dataframe.")
|
||||||
plt.grid(True, alpha=0.3)
|
if "scores" not in df.columns:
|
||||||
plt.xticks(dims)
|
raise ValueError("Expected 'scores' column in AE dataframe.")
|
||||||
plt.tight_layout()
|
if "network" not in df.columns or "latent_dim" not in df.columns:
|
||||||
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
raise ValueError("Expected 'network' and 'latent_dim' columns in AE dataframe.")
|
||||||
plt.close()
|
if label_field not in df.columns:
|
||||||
|
raise ValueError(f"Expected '{label_field}' column in AE dataframe.")
|
||||||
|
|
||||||
|
# Keep only test split
|
||||||
|
df = df.filter(pl.col("split") == "test")
|
||||||
|
|
||||||
|
groups: dict[tuple[str, int], dict[str, list[float]]] = {}
|
||||||
|
|
||||||
|
for row in df.iter_rows(named=True):
|
||||||
|
net_label = canonicalize_network(row["network"])
|
||||||
|
if only_nets and net_label not in only_nets:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dim = int(row["latent_dim"])
|
||||||
|
batch_size = extract_batch_size(row.get("config_json"))
|
||||||
|
scores = np.asarray(row["scores"] or [], dtype=float)
|
||||||
|
|
||||||
|
labels = row.get(label_field)
|
||||||
|
labels = np.asarray(labels, dtype=int) if labels is not None else None
|
||||||
|
|
||||||
|
overall_loss = calculate_batch_mean_loss(scores, batch_size)
|
||||||
|
|
||||||
|
# Split by labels if available; otherwise we only aggregate overall
|
||||||
|
normal_loss = np.nan
|
||||||
|
anomaly_loss = np.nan
|
||||||
|
if labels is not None and labels.size == scores.size:
|
||||||
|
normal_scores = scores[labels == 1]
|
||||||
|
anomaly_scores = scores[labels == -1]
|
||||||
|
if normal_scores.size > 0:
|
||||||
|
normal_loss = calculate_batch_mean_loss(normal_scores, batch_size)
|
||||||
|
if anomaly_scores.size > 0:
|
||||||
|
anomaly_loss = calculate_batch_mean_loss(anomaly_scores, batch_size)
|
||||||
|
|
||||||
|
key = (net_label, dim)
|
||||||
|
if key not in groups:
|
||||||
|
groups[key] = {"normal": [], "anomaly": [], "overall": []}
|
||||||
|
groups[key]["overall"].append(overall_loss)
|
||||||
|
groups[key]["normal"].append(normal_loss)
|
||||||
|
groups[key]["anomaly"].append(anomaly_loss)
|
||||||
|
|
||||||
|
# Aggregate across folds -> per (net, dim) mean/std
|
||||||
|
per_net_dims: dict[str, set[int]] = {}
|
||||||
|
for net, dim in groups:
|
||||||
|
per_net_dims.setdefault(net, set()).add(dim)
|
||||||
|
|
||||||
|
result: dict[str, dict[str, tuple[list[int], list[float], list[float]]]] = {}
|
||||||
|
for net, dims in per_net_dims.items():
|
||||||
|
dims_sorted = sorted(dims)
|
||||||
|
|
||||||
|
def collect(kind: str):
|
||||||
|
means, stds = [], []
|
||||||
|
for d in dims_sorted:
|
||||||
|
xs = [
|
||||||
|
x
|
||||||
|
for (n2, d2), v in groups.items()
|
||||||
|
if n2 == net and d2 == d
|
||||||
|
for x in v[kind]
|
||||||
|
if x is not None and not np.isnan(x)
|
||||||
|
]
|
||||||
|
if len(xs) == 0:
|
||||||
|
means.append(np.nan)
|
||||||
|
stds.append(np.nan)
|
||||||
|
else:
|
||||||
|
means.append(float(np.mean(xs)))
|
||||||
|
stds.append(float(np.std(xs)))
|
||||||
|
return dims_sorted, means, stds
|
||||||
|
|
||||||
|
result[net] = {
|
||||||
|
"normal": collect("normal"),
|
||||||
|
"anomaly": collect("anomaly"),
|
||||||
|
"overall": collect("overall"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
|
def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
|
||||||
"""Create and save a loss curve plot with multiple architectures.
|
"""
|
||||||
|
arch_results: {arch_name: (dims, means, stds)}
|
||||||
Args:
|
|
||||||
arch_results: Dict of format {arch_name: (dims, means, stds)}
|
|
||||||
title: Plot title
|
|
||||||
output_path: Where to save the plot
|
|
||||||
colors: Optional dict of colors for each architecture
|
|
||||||
"""
|
"""
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
|
|
||||||
|
# default color map if not provided
|
||||||
if colors is None:
|
if colors is None:
|
||||||
colors = {
|
colors = {
|
||||||
"LeNet": "blue",
|
"LeNet": "tab:blue",
|
||||||
"LeNet Asymmetric": "red",
|
"Efficient": "tab:orange",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get unique dimensions across all architectures
|
# Get unique dimensions across all architectures
|
||||||
@@ -121,7 +179,17 @@ def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for arch_name, (dims, means, stds) in arch_results.items():
|
for arch_name, (dims, means, stds) in arch_results.items():
|
||||||
color = colors.get(arch_name, "gray")
|
color = colors.get(arch_name)
|
||||||
|
# Plot line
|
||||||
|
if color is None:
|
||||||
|
plt.plot(dims, means, marker="o", label=arch_name)
|
||||||
|
plt.fill_between(
|
||||||
|
dims,
|
||||||
|
np.array(means) - np.array(stds),
|
||||||
|
np.array(means) + np.array(stds),
|
||||||
|
alpha=0.2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
plt.plot(dims, means, marker="o", color=color, label=arch_name)
|
plt.plot(dims, means, marker="o", color=color, label=arch_name)
|
||||||
plt.fill_between(
|
plt.fill_between(
|
||||||
dims,
|
dims,
|
||||||
@@ -131,209 +199,71 @@ def plot_multi_loss_curve(arch_results, title, output_path, colors=None):
|
|||||||
alpha=0.2,
|
alpha=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.xlabel("Latent Dimension")
|
plt.xlabel("Latent Dimensionality")
|
||||||
plt.ylabel("Test Loss")
|
plt.ylabel("Test Loss")
|
||||||
plt.title(title)
|
plt.title(title)
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True, alpha=0.3)
|
plt.grid(True, alpha=0.3)
|
||||||
plt.xticks(all_dims) # Set x-axis ticks to match data points
|
plt.xticks(all_dims)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def evaluate_autoencoder_loss():
|
def main():
|
||||||
"""Main function to evaluate autoencoder loss across different latent dimensions."""
|
# Load AE DF (uses your cache if enabled in the loader)
|
||||||
# Results storage for each architecture
|
df = load_pretraining_results_dataframe(ROOT, allow_cache=True, include_train=False)
|
||||||
arch_results = {
|
|
||||||
name: {"dims": [], "normal": [], "anomaly": []} for name in results_folders
|
|
||||||
}
|
|
||||||
|
|
||||||
# Process each architecture
|
# Optional: filter to just LeNet vs Efficient; drop this set() to plot all nets
|
||||||
for arch_name, config in results_folders.items():
|
wanted_nets = {"LeNet", "Efficient"}
|
||||||
results_folder = config["path"]
|
|
||||||
batch_size = config["batch_size"]
|
curves = build_arch_curves_from_df(
|
||||||
result_files = sorted(
|
df,
|
||||||
results_folder.glob("ae_elbow_results_subter_*_kfold.pkl")
|
label_field=LABEL_FIELD,
|
||||||
|
only_nets=wanted_nets,
|
||||||
)
|
)
|
||||||
|
|
||||||
dimensions = []
|
# Prepare output dirs
|
||||||
normal_means = []
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
normal_stds = []
|
ts_dir = OUTPUT_DIR / "archive" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
anomaly_means = []
|
ts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
anomaly_stds = []
|
|
||||||
|
|
||||||
# Verify loss calculation
|
def pick(kind: str):
|
||||||
print(
|
# kind in {"normal","anomaly","overall"}
|
||||||
f"\nVerifying loss calculation for {arch_name} (batch_size={batch_size})..."
|
return {name: payload[kind] for name, payload in curves.items()}
|
||||||
)
|
|
||||||
for result_file in result_files:
|
|
||||||
with open(result_file, "rb") as f:
|
|
||||||
results = pickle.load(f)
|
|
||||||
test_loss_calculation(results, batch_size)
|
|
||||||
print(f"Loss calculation verified successfully for {arch_name}!")
|
|
||||||
|
|
||||||
# Process files for this architecture
|
|
||||||
for result_file in result_files:
|
|
||||||
with open(result_file, "rb") as f:
|
|
||||||
results = pickle.load(f)
|
|
||||||
dim = int(results["dimension"])
|
|
||||||
folds = results["ae_results"]
|
|
||||||
|
|
||||||
normal_fold_losses = []
|
|
||||||
anomaly_fold_losses = []
|
|
||||||
|
|
||||||
all_scores = [] # Collect all scores for overall calculation
|
|
||||||
all_fold_scores = [] # Collect all fold scores for std calculation
|
|
||||||
|
|
||||||
for fold_key in folds:
|
|
||||||
fold_data = folds[fold_key]["test"]
|
|
||||||
scores = np.array(fold_data["scores"])
|
|
||||||
labels = np.array(fold_data["labels_exp_based"])
|
|
||||||
|
|
||||||
normal_scores = scores[labels == 1]
|
|
||||||
anomaly_scores = scores[labels == -1]
|
|
||||||
|
|
||||||
normal_fold_losses.append(
|
|
||||||
calculate_batch_mean_loss(normal_scores, batch_size)
|
|
||||||
)
|
|
||||||
anomaly_fold_losses.append(
|
|
||||||
calculate_batch_mean_loss(anomaly_scores, batch_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
all_scores.append(scores) # Add scores to all_scores
|
|
||||||
all_fold_scores.append(fold_data["scores"]) # Add fold scores
|
|
||||||
|
|
||||||
dimensions.append(dim)
|
|
||||||
normal_means.append(np.mean(normal_fold_losses))
|
|
||||||
normal_stds.append(np.std(normal_fold_losses))
|
|
||||||
anomaly_means.append(np.mean(anomaly_fold_losses))
|
|
||||||
anomaly_stds.append(np.std(anomaly_fold_losses))
|
|
||||||
|
|
||||||
# Sort by dimension
|
|
||||||
sorted_data = sorted(
|
|
||||||
zip(dimensions, normal_means, normal_stds, anomaly_means, anomaly_stds)
|
|
||||||
)
|
|
||||||
dims, n_means, n_stds, a_means, a_stds = zip(*sorted_data)
|
|
||||||
|
|
||||||
# Store results for this architecture
|
|
||||||
arch_results[arch_name] = {
|
|
||||||
"dims": dims,
|
|
||||||
"normal": (dims, n_means, n_stds),
|
|
||||||
"anomaly": (dims, a_means, a_stds),
|
|
||||||
"overall": (
|
|
||||||
dims,
|
|
||||||
[
|
|
||||||
calculate_batch_mean_loss(scores, batch_size)
|
|
||||||
for scores in all_scores
|
|
||||||
], # Use all scores
|
|
||||||
[
|
|
||||||
np.std(
|
|
||||||
[
|
|
||||||
calculate_batch_mean_loss(fold_scores, batch_size)
|
|
||||||
for fold_scores in fold_scores_list
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for fold_scores_list in all_fold_scores
|
|
||||||
],
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
    # Create the three plots with all architectures
    plot_multi_loss_curve(
-        {name: results["normal"] for name, results in arch_results.items()},
-        "Normal Class Test Loss vs. Latent Dimension",
-        output_datetime_path / "ae_elbow_test_loss_normal.png",
+        pick("normal"),
+        "Normal Class Test Loss vs. Latent Dimensionality",
+        ts_dir / "ae_elbow_test_loss_normal.png",
    )

    plot_multi_loss_curve(
-        {name: results["anomaly"] for name, results in arch_results.items()},
-        "Anomaly Class Test Loss vs. Latent Dimension",
-        output_datetime_path / "ae_elbow_test_loss_anomaly.png",
+        pick("anomaly"),
+        "Anomaly Class Test Loss vs. Latent Dimensionality",
+        ts_dir / "ae_elbow_test_loss_anomaly.png",
    )

    plot_multi_loss_curve(
-        {name: results["overall"] for name, results in arch_results.items()},
-        "Overall Test Loss vs. Latent Dimension",
-        output_datetime_path / "ae_elbow_test_loss_overall.png",
+        pick("overall"),
+        "Overall Test Loss vs. Latent Dimensionality",
+        ts_dir / "ae_elbow_test_loss_overall.png",
    )

+    # Copy this script to preserve the code used for the outputs
+    script_path = Path(__file__)
+    shutil.copy2(script_path, ts_dir)
+
+    # Optionally mirror latest
+    latest = OUTPUT_DIR / "latest"
+    latest.mkdir(exist_ok=True, parents=True)
+    for f in ts_dir.iterdir():
+        if f.is_file():
+            shutil.copy2(f, latest / f.name)
+
+    print(f"Saved plots to: {ts_dir}")
+    print(f"Also updated: {latest}")

-def print_loss_comparison(results_folders):
-    """Print comparison tables of original vs calculated losses for each architecture."""
-    print("\nLoss Comparison Tables")
-    print("=" * 80)
-
-    for arch_name, config in results_folders.items():
-        results_folder = config["path"]
-        batch_size = config["batch_size"]
-        result_files = sorted(
-            results_folder.glob("ae_elbow_results_subter_*_kfold.pkl")
-        )
-
-        # Prepare table data
-        table_data = []
-        headers = ["Dimension", "Original", "Calculated", "Diff"]
-
-        for result_file in result_files:
-            with open(result_file, "rb") as f:
-                results = pickle.load(f)
-
-            dim = int(results["dimension"])
-            folds = results["ae_results"]
-
-            # Calculate mean original loss across folds
-            orig_losses = []
-            calc_losses = []
-            for fold_key in folds:
-                fold_data = folds[fold_key]["test"]
-                orig_losses.append(fold_data["loss"])
-                calc_losses.append(
-                    calculate_batch_mean_loss(np.array(fold_data["scores"]), batch_size)
-                )
-
-            orig_mean = np.mean(orig_losses)
-            calc_mean = np.mean(calc_losses)
-            diff = abs(orig_mean - calc_mean)
-
-            table_data.append([dim, orig_mean, calc_mean, diff])
-
-        # Sort by dimension
-        table_data.sort(key=lambda x: x[0])
-
-        print(f"\n{arch_name}:")
-        print(
-            tabulate(
-                table_data,
-                headers=headers,
-                floatfmt=".6f",
-                tablefmt="pipe",
-                numalign="right",
-            )
-        )
-
-        print("\n" + "=" * 80)


if __name__ == "__main__":
-    # Print loss comparisons for all architectures
-    print_loss_comparison(results_folders)
-
-    # Run main analysis
-    evaluate_autoencoder_loss()
-
-    # Archive management
-    # Delete current latest folder
-    shutil.rmtree(latest_folder_path, ignore_errors=True)
-    latest_folder_path.mkdir(exist_ok=True, parents=True)
-
-    # Copy contents to latest folder
-    for file in output_datetime_path.iterdir():
-        shutil.copy2(file, latest_folder_path)
-
-    # Copy this script for reference
-    shutil.copy2(__file__, output_datetime_path)
-    shutil.copy2(__file__, latest_folder_path)
-
-    # Move output to archive
-    shutil.move(output_datetime_path, archive_folder_path)
+    main()
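The comparison above rechecks each fold's stored loss against a value recomputed from the per-sample scores via calculate_batch_mean_loss, which is not shown in this diff. A minimal sketch, assuming it chunks the scores into batches of batch_size, averages each chunk, and then averages the chunk means (the way a running epoch loss is usually accumulated):

# Hypothetical reimplementation for illustration only; the real helper lives
# elsewhere in the repository and may differ in detail.
import numpy as np


def calculate_batch_mean_loss(scores: np.ndarray, batch_size: int) -> float:
    """Mean of per-batch means over a flat array of per-sample losses."""
    scores = np.asarray(scores, dtype=np.float64)
    batch_means = [
        scores[i : i + batch_size].mean() for i in range(0, len(scores), batch_size)
    ]
    return float(np.mean(batch_means))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    per_sample = rng.random(1000)
    # With a batch size that divides the sample count this equals the plain mean;
    # with a ragged last batch the two can differ slightly, which is exactly the
    # "Diff" column the comparison table reports.
    print(calculate_batch_mean_loss(per_sample, 128), per_sample.mean())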
164 tools/plot_scripts/data_spherical_projection_as_trained.py Normal file
@@ -0,0 +1,164 @@
import argparse
import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from PIL import Image

# --- Setup output folders ---
output_path = Path("/home/fedex/mt/plots/data_2d_projections_training")
datetime_folder_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_datetime_path = output_path / datetime_folder_name
latest_folder_path = output_path / "latest"
archive_folder_path = output_path / "archive"

for folder in (
    output_path,
    output_datetime_path,
    latest_folder_path,
    archive_folder_path,
):
    folder.mkdir(exist_ok=True, parents=True)

# --- Parse command-line arguments ---
parser = argparse.ArgumentParser(
    description="Plot two 2D projections as used in training (unstretched, grayscale)"
)
parser.add_argument(
    "--input1",
    type=Path,
    default=Path(
        "/home/fedex/mt/data/subter/new_projection/1_loop_closure_illuminated_2023-01-23.npy"
    ),
    help="Path to first .npy file containing 2D projection data",
)
parser.add_argument(
    "--input2",
    type=Path,
    default=Path(
        "/home/fedex/mt/data/subter/new_projection/3_smoke_human_walking_2023-01-23.npy"
    ),
    help="Path to second .npy file containing 2D projection data",
)
parser.add_argument(
    "--frame1",
    type=int,
    default=955,
    help="Frame index to plot from first file (0-indexed)",
)
parser.add_argument(
    "--frame2",
    type=int,
    default=242,
    help="Frame index to plot from second file (0-indexed)",
)

args = parser.parse_args()

# --- Load the numpy projection data ---
proj_data1 = np.load(args.input1)
proj_data2 = np.load(args.input2)

# Choose the desired frames
try:
    frame1 = proj_data1[args.frame1]
    frame2 = proj_data2[args.frame2]
except IndexError as e:
    raise ValueError(f"Frame index out of range: {e}")

# Debug info: Print the percentage of missing data in each frame
print(f"Frame 1 missing data percentage: {np.isnan(frame1).mean() * 100:.2f}%")
print(f"Frame 2 missing data percentage: {np.isnan(frame2).mean() * 100:.2f}%")

# --- Create a figure with 2 vertical subplots ---
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10, 5))

# Create custom colormap for missing data visualization
missing_color = [1, 0, 0, 1]  # Red with full alpha
cmap_missing = ListedColormap([missing_color])

# Replace the plotting section
for ax, frame, title in zip(
    (ax1, ax2),
    (frame1, frame2),
    (
        "Normal LiDAR Frame",
        "Degraded LiDAR Frame (Smoke)",
    ),
):
    # Create mask for missing data (directly from NaN values)
    missing_mask = np.isnan(frame)

    # Plot the valid data in grayscale
    frame_valid = np.copy(frame)
    frame_valid[missing_mask] = 0  # Set missing values to black in base image
    im = ax.imshow(frame_valid, cmap="gray", aspect="equal", vmin=0, vmax=0.8)

    # Overlay missing data in red with reduced alpha
    ax.imshow(
        missing_mask,
        cmap=ListedColormap([[1, 0, 0, 1]]),  # Pure red
        alpha=0.3,  # Reduced alpha for better visibility
    )

    ax.set_title(title)
    ax.axis("off")

# Adjust layout
plt.tight_layout()

# Create a more informative legend
legend_elements = [
    mpatches.Patch(facecolor="red", alpha=0.7, label="Missing Data"),
    mpatches.Patch(facecolor="white", label="Close Distance (0m)"),
    mpatches.Patch(facecolor="gray", label="Mid Distance"),
    mpatches.Patch(facecolor="black", label="Far Distance (70m)"),
]

# Add legend with better positioning and formatting
fig.legend(
    handles=legend_elements,
    loc="center right",
    bbox_to_anchor=(0.98, 0.5),
    title="Distance Information",
    framealpha=1.0,
)

# Save the plot
output_file = output_datetime_path / "data_2d_projections_training.png"
plt.savefig(output_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
plt.close()

print(f"Plot saved to: {output_file}")

# --- Create grayscale training images ---
for degradation_status, frame_number, frame in (
    ("normal", args.frame1, frame1),
    ("smoke", args.frame2, frame2),
):
    frame_gray = np.nan_to_num(frame, nan=0).astype(np.float32)
    gray_image = Image.fromarray(frame_gray, mode="F")
    gray_output_file = (
        output_datetime_path
        / f"frame_{frame_number}_training_{degradation_status}.tiff"
    )
    gray_image.save(gray_output_file)
    print(f"Training image saved to: {gray_output_file}")

# --- Handle folder structure ---
shutil.rmtree(latest_folder_path, ignore_errors=True)
latest_folder_path.mkdir(exist_ok=True, parents=True)

for file in output_datetime_path.iterdir():
    shutil.copy2(file, latest_folder_path)

script_path = Path(__file__)
shutil.copy2(script_path, output_datetime_path)
shutil.copy2(script_path, latest_folder_path)

shutil.move(output_datetime_path, archive_folder_path)
print(f"Output archived to: {archive_folder_path}")
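For picking frames without hard-coding indices, a short companion sketch (not part of the diff; the path and the (frames, H, W) shape are assumptions) that ranks every frame in a projection file by its missing-return percentage, reusing the NaN-means-missing convention from the script above:

from pathlib import Path

import numpy as np

# Placeholder path; substitute one of the subter .npy projection files.
proj = np.load(Path("/path/to/projection.npy"))  # assumed shape: (frames, H, W), NaN = no return
missing_pct = np.isnan(proj).mean(axis=(1, 2)) * 100  # one percentage per frame
worst = int(np.argmax(missing_pct))
print(f"frames: {proj.shape[0]}, worst frame: {worst} ({missing_pct[worst]:.2f}% missing)")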
597 tools/plot_scripts/load_results.py Normal file
@@ -0,0 +1,597 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Config you can tweak
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
MODELS = ["deepsad", "isoforest", "ocsvm"]
|
||||||
|
EVALS = ["exp_based", "manual_based"]
|
||||||
|
|
||||||
|
SCHEMA_STATIC = {
|
||||||
|
# identifiers / dims
|
||||||
|
"network": pl.Utf8, # e.g. "LeNet", "efficient"
|
||||||
|
"latent_dim": pl.Int32,
|
||||||
|
"semi_normals": pl.Int32,
|
||||||
|
"semi_anomalous": pl.Int32,
|
||||||
|
"model": pl.Utf8, # "deepsad" | "isoforest" | "ocsvm"
|
||||||
|
"eval": pl.Utf8, # "exp_based" | "manual_based"
|
||||||
|
"fold": pl.Int32,
|
||||||
|
# metrics
|
||||||
|
"auc": pl.Float64,
|
||||||
|
"ap": pl.Float64,
|
||||||
|
# per-sample scores: list of (idx, label, score)
|
||||||
|
"scores": pl.List(
|
||||||
|
pl.Struct(
|
||||||
|
{
|
||||||
|
"sample_idx": pl.Int32, # dataloader idx
|
||||||
|
"orig_label": pl.Int8, # {-1,0,1}
|
||||||
|
"score": pl.Float64, # anomaly score
|
||||||
|
}
|
||||||
|
)
|
||||||
|
),
|
||||||
|
# curves (normalized)
|
||||||
|
"roc_curve": pl.Struct(
|
||||||
|
{
|
||||||
|
"fpr": pl.List(pl.Float64),
|
||||||
|
"tpr": pl.List(pl.Float64),
|
||||||
|
"thr": pl.List(pl.Float64),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"prc_curve": pl.Struct(
|
||||||
|
{
|
||||||
|
"precision": pl.List(pl.Float64),
|
||||||
|
"recall": pl.List(pl.Float64),
|
||||||
|
"thr": pl.List(pl.Float64), # may be len(precision)-1
|
||||||
|
}
|
||||||
|
),
|
||||||
|
# deepsad-only per-eval arrays (None for other models)
|
||||||
|
"sample_indices": pl.List(pl.Int32),
|
||||||
|
"sample_labels": pl.List(pl.Int8),
|
||||||
|
"valid_mask": pl.List(pl.Boolean),
|
||||||
|
# timings / housekeeping
|
||||||
|
"train_time": pl.Float64,
|
||||||
|
"test_time": pl.Float64,
|
||||||
|
"folder": pl.Utf8,
|
||||||
|
"k_fold_num": pl.Int32,
|
||||||
|
"config_json": pl.Utf8, # full config.json as string (for reference)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pretraining-only (AE) schema
|
||||||
|
# Pretraining-only (AE) schema — lighter defaults
|
||||||
|
PRETRAIN_SCHEMA = {
|
||||||
|
# identifiers / dims
|
||||||
|
"network": pl.Utf8, # e.g. "LeNet", "efficient"
|
||||||
|
"latent_dim": pl.Int32,
|
||||||
|
"semi_normals": pl.Int32,
|
||||||
|
"semi_anomalous": pl.Int32,
|
||||||
|
"model": pl.Utf8, # always "ae"
|
||||||
|
"fold": pl.Int32,
|
||||||
|
"split": pl.Utf8, # "train" | "test"
|
||||||
|
# timings and optimization
|
||||||
|
"time": pl.Float64,
|
||||||
|
"loss": pl.Float64,
|
||||||
|
# per-sample arrays (as lists)
|
||||||
|
"indices": pl.List(pl.Int32),
|
||||||
|
"labels_exp_based": pl.List(pl.Int32),
|
||||||
|
"labels_manual_based": pl.List(pl.Int32),
|
||||||
|
"semi_targets": pl.List(pl.Int32),
|
||||||
|
"file_ids": pl.List(pl.Int32),
|
||||||
|
"frame_ids": pl.List(pl.Int32),
|
||||||
|
"scores": pl.List(pl.Float32), # <— use Float32 to match source and save space
|
||||||
|
# file id -> name mapping from the result dict
|
||||||
|
"file_names": pl.List(pl.Struct({"file_id": pl.Int32, "name": pl.Utf8})),
|
||||||
|
# housekeeping
|
||||||
|
"folder": pl.Utf8,
|
||||||
|
"k_fold_num": pl.Int32,
|
||||||
|
"config_json": pl.Utf8, # full config.json as string (for reference)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Helpers: curve/scores normalizers (tuples/ndarrays -> dict/list)
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
def _tolist(x):
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return x.tolist()
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
return list(x)
|
||||||
|
# best-effort scalar wrap
|
||||||
|
try:
|
||||||
|
return [x]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_float_list(a) -> Optional[List[float]]:
|
||||||
|
if a is None:
|
||||||
|
return None
|
||||||
|
if isinstance(a, np.ndarray):
|
||||||
|
a = a.tolist()
|
||||||
|
return [None if x is None else float(x) for x in a]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_file_names(d) -> Optional[List[dict]]:
|
||||||
|
"""
|
||||||
|
Convert the 'file_names' dict (keys like numpy.int64 -> str) to a
|
||||||
|
list[ {file_id:int, name:str} ], sorted by file_id.
|
||||||
|
"""
|
||||||
|
if not isinstance(d, dict):
|
||||||
|
return None
|
||||||
|
out: List[dict] = []
|
||||||
|
for k, v in d.items():
|
||||||
|
try:
|
||||||
|
file_id = int(k)
|
||||||
|
except Exception:
|
||||||
|
# keys are printed as np.int64 in the structure; best-effort cast
|
||||||
|
continue
|
||||||
|
out.append({"file_id": file_id, "name": str(v)})
|
||||||
|
out.sort(key=lambda x: x["file_id"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_roc(obj: Any) -> Optional[dict]:
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
fpr = tpr = thr = None
|
||||||
|
if isinstance(obj, (tuple, list)):
|
||||||
|
if len(obj) >= 2:
|
||||||
|
fpr, tpr = _tolist(obj[0]), _tolist(obj[1])
|
||||||
|
if len(obj) >= 3:
|
||||||
|
thr = _tolist(obj[2])
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
fpr = _tolist(obj.get("fpr") or obj.get("x"))
|
||||||
|
tpr = _tolist(obj.get("tpr") or obj.get("y"))
|
||||||
|
thr = _tolist(obj.get("thr") or obj.get("thresholds"))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
if fpr is None or tpr is None:
|
||||||
|
return None
|
||||||
|
return {"fpr": fpr, "tpr": tpr, "thr": thr}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_prc(obj: Any) -> Optional[dict]:
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
precision = recall = thr = None
|
||||||
|
if isinstance(obj, (tuple, list)):
|
||||||
|
if len(obj) >= 2:
|
||||||
|
precision, recall = _tolist(obj[0]), _tolist(obj[1])
|
||||||
|
if len(obj) >= 3:
|
||||||
|
thr = _tolist(obj[2])
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
precision = _tolist(obj.get("precision") or obj.get("y"))
|
||||||
|
recall = _tolist(obj.get("recall") or obj.get("x"))
|
||||||
|
thr = _tolist(obj.get("thr") or obj.get("thresholds"))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
if precision is None or recall is None:
|
||||||
|
return None
|
||||||
|
return {"precision": precision, "recall": recall, "thr": thr}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_scores_to_struct(seq) -> Optional[List[dict]]:
|
||||||
|
"""
|
||||||
|
Input: list of (idx, label, score) tuples (as produced in your test()).
|
||||||
|
Output: list of dicts with keys sample_idx, orig_label, score.
|
||||||
|
"""
|
||||||
|
if seq is None:
|
||||||
|
return None
|
||||||
|
if isinstance(seq, np.ndarray):
|
||||||
|
seq = seq.tolist()
|
||||||
|
if not isinstance(seq, (list, tuple)):
|
||||||
|
return None
|
||||||
|
out: List[dict] = []
|
||||||
|
for item in seq:
|
||||||
|
if isinstance(item, (list, tuple)) and len(item) >= 3:
|
||||||
|
idx, lab, sc = item[0], item[1], item[2]
|
||||||
|
out.append(
|
||||||
|
{
|
||||||
|
"sample_idx": None if idx is None else int(idx),
|
||||||
|
"orig_label": None if lab is None else int(lab),
|
||||||
|
"score": None if sc is None else float(sc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# fallback: single numeric -> score
|
||||||
|
sc = (
|
||||||
|
float(item)
|
||||||
|
if isinstance(item, (int, float, np.integer, np.floating))
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
out.append({"sample_idx": None, "orig_label": None, "score": sc})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_int_list(a) -> Optional[List[int]]:
|
||||||
|
if a is None:
|
||||||
|
return None
|
||||||
|
if isinstance(a, np.ndarray):
|
||||||
|
a = a.tolist()
|
||||||
|
return list(a)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_bool_list(a) -> Optional[List[bool]]:
|
||||||
|
if a is None:
|
||||||
|
return None
|
||||||
|
if isinstance(a, np.ndarray):
|
||||||
|
a = a.tolist()
|
||||||
|
return [bool(x) for x in a]
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Low-level: read one experiment folder
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
def read_config(exp_dir: Path) -> dict:
|
||||||
|
cfg = exp_dir / "config.json"
|
||||||
|
with cfg.open("r") as f:
|
||||||
|
c = json.load(f)
|
||||||
|
if not c.get("k_fold"):
|
||||||
|
raise ValueError(f"{exp_dir.name}: not trained as k-fold")
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def read_pickle(p: Path) -> Any:
|
||||||
|
with p.open("rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Extractors for each model
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
def rows_from_deepsad(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||||
|
"""
|
||||||
|
deepsad under data['test'][eval], with extra per-eval arrays and AP present.
|
||||||
|
"""
|
||||||
|
out: Dict[str, dict] = {}
|
||||||
|
test = data.get("test", {})
|
||||||
|
for ev in evals:
|
||||||
|
evd = test.get(ev)
|
||||||
|
if not isinstance(evd, dict):
|
||||||
|
continue
|
||||||
|
out[ev] = {
|
||||||
|
"auc": float(evd["auc"])
|
||||||
|
if "auc" in evd and evd["auc"] is not None
|
||||||
|
else None,
|
||||||
|
"roc": normalize_roc(evd.get("roc")),
|
||||||
|
"prc": normalize_prc(evd.get("prc")),
|
||||||
|
"ap": float(evd["ap"]) if "ap" in evd and evd["ap"] is not None else None,
|
||||||
|
"scores": normalize_scores_to_struct(evd.get("scores")),
|
||||||
|
"sample_indices": normalize_int_list(evd.get("indices")),
|
||||||
|
"sample_labels": normalize_int_list(evd.get("labels")),
|
||||||
|
"valid_mask": normalize_bool_list(evd.get("valid_mask")),
|
||||||
|
"train_time": data.get("train", {}).get("time"),
|
||||||
|
"test_time": test.get("time"),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def rows_from_isoforest(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||||
|
"""
|
||||||
|
Keys: test_auc_<eval>, test_roc_<eval>, test_prc_<eval>, test_ap_<eval>, test_scores_<eval>.
|
||||||
|
"""
|
||||||
|
out: Dict[str, dict] = {}
|
||||||
|
for ev in evals:
|
||||||
|
auc = data.get(f"test_auc_{ev}")
|
||||||
|
if auc is None:
|
||||||
|
continue
|
||||||
|
out[ev] = {
|
||||||
|
"auc": float(auc),
|
||||||
|
"roc": normalize_roc(data.get(f"test_roc_{ev}")),
|
||||||
|
"prc": normalize_prc(data.get(f"test_prc_{ev}")),
|
||||||
|
"ap": float(data.get(f"test_ap_{ev}"))
|
||||||
|
if data.get(f"test_ap_{ev}") is not None
|
||||||
|
else None,
|
||||||
|
"scores": normalize_scores_to_struct(data.get(f"test_scores_{ev}")),
|
||||||
|
"sample_indices": None,
|
||||||
|
"sample_labels": None,
|
||||||
|
"valid_mask": None,
|
||||||
|
"train_time": data.get("train_time"),
|
||||||
|
"test_time": data.get("test_time"),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def rows_from_ocsvm_default(data: dict, evals: List[str]) -> Dict[str, dict]:
|
||||||
|
"""
|
||||||
|
Default OCSVM only (ignore linear variant entirely).
|
||||||
|
"""
|
||||||
|
out: Dict[str, dict] = {}
|
||||||
|
for ev in evals:
|
||||||
|
auc = data.get(f"test_auc_{ev}")
|
||||||
|
if auc is None:
|
||||||
|
continue
|
||||||
|
out[ev] = {
|
||||||
|
"auc": float(auc),
|
||||||
|
"roc": normalize_roc(data.get(f"test_roc_{ev}")),
|
||||||
|
"prc": normalize_prc(data.get(f"test_prc_{ev}")),
|
||||||
|
"ap": float(data.get(f"test_ap_{ev}"))
|
||||||
|
if data.get(f"test_ap_{ev}") is not None
|
||||||
|
else None,
|
||||||
|
"scores": normalize_scores_to_struct(data.get(f"test_scores_{ev}")),
|
||||||
|
"sample_indices": None,
|
||||||
|
"sample_labels": None,
|
||||||
|
"valid_mask": None,
|
||||||
|
"train_time": data.get("train_time"),
|
||||||
|
"test_time": data.get("test_time"),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Build the Polars DataFrame
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
def load_results_dataframe(root: Path, allow_cache: bool = True) -> pl.DataFrame:
|
||||||
|
"""
|
||||||
|
Walks experiment subdirs under `root`. For each (model, fold) it adds rows:
|
||||||
|
Columns (SCHEMA_STATIC):
|
||||||
|
network, latent_dim, semi_normals, semi_anomalous,
|
||||||
|
model, eval, fold,
|
||||||
|
auc, ap, scores{sample_idx,orig_label,score},
|
||||||
|
roc_curve{fpr,tpr,thr}, prc_curve{precision,recall,thr},
|
||||||
|
sample_indices, sample_labels, valid_mask,
|
||||||
|
train_time, test_time,
|
||||||
|
folder, k_fold_num
|
||||||
|
"""
|
||||||
|
if allow_cache:
|
||||||
|
cache = root / "results_cache.parquet"
|
||||||
|
if cache.exists():
|
||||||
|
try:
|
||||||
|
df = pl.read_parquet(cache)
|
||||||
|
print(f"[info] loaded cached results frame from {cache}")
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to load cache {cache}: {e}")
|
||||||
|
|
||||||
|
rows: List[dict] = []
|
||||||
|
|
||||||
|
exp_dirs = [p for p in root.iterdir() if p.is_dir()]
|
||||||
|
for exp_dir in sorted(exp_dirs):
|
||||||
|
try:
|
||||||
|
cfg = read_config(exp_dir)
|
||||||
|
cfg_json = json.dumps(cfg, sort_keys=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] skipping {exp_dir.name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
network = cfg.get("net_name")
|
||||||
|
latent_dim = int(cfg.get("latent_space_dim"))
|
||||||
|
semi_normals = int(cfg.get("num_known_normal"))
|
||||||
|
semi_anomalous = int(cfg.get("num_known_outlier"))
|
||||||
|
k = int(cfg.get("k_fold_num"))
|
||||||
|
|
||||||
|
for model in MODELS:
|
||||||
|
for fold in range(k):
|
||||||
|
pkl = exp_dir / f"results_{model}_{fold}.pkl"
|
||||||
|
if not pkl.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = read_pickle(pkl)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to read {pkl.name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if model == "deepsad":
|
||||||
|
per_eval = rows_from_deepsad(data, EVALS) # eval -> dict
|
||||||
|
elif model == "isoforest":
|
||||||
|
per_eval = rows_from_isoforest(data, EVALS) # eval -> dict
|
||||||
|
elif model == "ocsvm":
|
||||||
|
per_eval = rows_from_ocsvm_default(data, EVALS) # eval -> dict
|
||||||
|
else:
|
||||||
|
per_eval = {}
|
||||||
|
|
||||||
|
for ev, vals in per_eval.items():
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"network": network,
|
||||||
|
"latent_dim": latent_dim,
|
||||||
|
"semi_normals": semi_normals,
|
||||||
|
"semi_anomalous": semi_anomalous,
|
||||||
|
"model": model,
|
||||||
|
"eval": ev,
|
||||||
|
"fold": fold,
|
||||||
|
"auc": vals["auc"],
|
||||||
|
"ap": vals["ap"],
|
||||||
|
"scores": vals["scores"],
|
||||||
|
"roc_curve": vals["roc"],
|
||||||
|
"prc_curve": vals["prc"],
|
||||||
|
"sample_indices": vals.get("sample_indices"),
|
||||||
|
"sample_labels": vals.get("sample_labels"),
|
||||||
|
"valid_mask": vals.get("valid_mask"),
|
||||||
|
"train_time": vals["train_time"],
|
||||||
|
"test_time": vals["test_time"],
|
||||||
|
"folder": str(exp_dir),
|
||||||
|
"k_fold_num": k,
|
||||||
|
"config_json": cfg_json,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# If empty, return a typed empty frame
|
||||||
|
if not rows:
|
||||||
|
return pl.DataFrame(schema=SCHEMA_STATIC)
|
||||||
|
|
||||||
|
df = pl.DataFrame(rows, schema=SCHEMA_STATIC)
|
||||||
|
|
||||||
|
# Cast to efficient dtypes (categoricals etc.) – no extra sanitation
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("network", "model", "eval").cast(pl.Categorical),
|
||||||
|
pl.col(
|
||||||
|
"latent_dim", "semi_normals", "semi_anomalous", "fold", "k_fold_num"
|
||||||
|
).cast(pl.Int32),
|
||||||
|
pl.col("auc", "ap", "train_time", "test_time").cast(pl.Float64),
|
||||||
|
# NOTE: no cast on 'scores' here; it's already List(Struct) per schema.
|
||||||
|
)
|
||||||
|
|
||||||
|
if allow_cache:
|
||||||
|
try:
|
||||||
|
df.write_parquet(cache)
|
||||||
|
print(f"[info] cached results frame to {cache}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to write cache {cache}: {e}")
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def load_pretraining_results_dataframe(
|
||||||
|
root: Path,
|
||||||
|
allow_cache: bool = True,
|
||||||
|
include_train: bool = False, # <— default: store only TEST to keep cache tiny
|
||||||
|
keep_file_names: bool = False, # <— drop file_names by default; they’re repeated
|
||||||
|
parquet_compression: str = "zstd",
|
||||||
|
parquet_compression_level: int = 7, # <— stronger compression than default
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""
|
||||||
|
Loads only AE pretraining results: files named `results_ae_<fold>.pkl`.
|
||||||
|
Produces one row per (experiment, fold, split). By default we:
|
||||||
|
- include only the TEST split (include_train=False)
|
||||||
|
- store scores as Float32
|
||||||
|
- drop the repeated file_names mapping to save space
|
||||||
|
- write Parquet with zstd(level=7)
|
||||||
|
"""
|
||||||
|
if allow_cache:
|
||||||
|
cache = root / "pretraining_results_cache.parquet"
|
||||||
|
if cache.exists():
|
||||||
|
try:
|
||||||
|
df = pl.read_parquet(cache)
|
||||||
|
print(f"[info] loaded cached pretraining frame from {cache}")
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to load pretraining cache {cache}: {e}")
|
||||||
|
|
||||||
|
rows: List[dict] = []
|
||||||
|
|
||||||
|
exp_dirs = [p for p in root.iterdir() if p.is_dir()]
|
||||||
|
for exp_dir in sorted(exp_dirs):
|
||||||
|
try:
|
||||||
|
cfg = read_config(exp_dir)
|
||||||
|
cfg_json = json.dumps(cfg, sort_keys=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] skipping {exp_dir.name} (pretraining): {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
network = cfg.get("net_name")
|
||||||
|
latent_dim = int(cfg.get("latent_space_dim"))
|
||||||
|
semi_normals = int(cfg.get("num_known_normal"))
|
||||||
|
semi_anomalous = int(cfg.get("num_known_outlier"))
|
||||||
|
k = int(cfg.get("k_fold_num"))
|
||||||
|
|
||||||
|
# Only test split by default (include_train=False)
|
||||||
|
splits = ("train", "test") if include_train else ("test",)
|
||||||
|
|
||||||
|
for fold in range(k):
|
||||||
|
pkl = exp_dir / f"results_ae_{fold}.pkl"
|
||||||
|
if not pkl.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = read_pickle(pkl) # expected: {"train": {...}, "test": {...}}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to read {pkl.name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for split in splits:
|
||||||
|
splitd = data.get(split)
|
||||||
|
if not isinstance(splitd, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"network": network,
|
||||||
|
"latent_dim": latent_dim,
|
||||||
|
"semi_normals": semi_normals,
|
||||||
|
"semi_anomalous": semi_anomalous,
|
||||||
|
"model": "ae",
|
||||||
|
"fold": fold,
|
||||||
|
"split": split,
|
||||||
|
"time": float(splitd.get("time"))
|
||||||
|
if splitd.get("time") is not None
|
||||||
|
else None,
|
||||||
|
"loss": float(splitd.get("loss"))
|
||||||
|
if splitd.get("loss") is not None
|
||||||
|
else None,
|
||||||
|
# ints as Int32, scores as Float32 to save space
|
||||||
|
"indices": normalize_int_list(splitd.get("indices")),
|
||||||
|
"labels_exp_based": normalize_int_list(
|
||||||
|
splitd.get("labels_exp_based")
|
||||||
|
),
|
||||||
|
"labels_manual_based": normalize_int_list(
|
||||||
|
splitd.get("labels_manual_based")
|
||||||
|
),
|
||||||
|
"semi_targets": normalize_int_list(splitd.get("semi_targets")),
|
||||||
|
"file_ids": normalize_int_list(splitd.get("file_ids")),
|
||||||
|
"frame_ids": normalize_int_list(splitd.get("frame_ids")),
|
||||||
|
"scores": (
|
||||||
|
None
|
||||||
|
if splitd.get("scores") is None
|
||||||
|
else [
|
||||||
|
float(x)
|
||||||
|
for x in (
|
||||||
|
splitd["scores"].tolist()
|
||||||
|
if isinstance(splitd["scores"], np.ndarray)
|
||||||
|
else splitd["scores"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"file_names": normalize_file_names(splitd.get("file_names"))
|
||||||
|
if keep_file_names
|
||||||
|
else None,
|
||||||
|
"folder": str(exp_dir),
|
||||||
|
"k_fold_num": k,
|
||||||
|
"config_json": cfg_json,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return pl.DataFrame(schema=PRETRAIN_SCHEMA)
|
||||||
|
|
||||||
|
df = pl.DataFrame(rows, schema=PRETRAIN_SCHEMA)
|
||||||
|
|
||||||
|
# Cast/optimize a bit (categoricals, ints, floats)
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("network", "model", "split").cast(pl.Categorical),
|
||||||
|
pl.col(
|
||||||
|
"latent_dim", "semi_normals", "semi_anomalous", "fold", "k_fold_num"
|
||||||
|
).cast(pl.Int32),
|
||||||
|
pl.col("time", "loss").cast(pl.Float64),
|
||||||
|
pl.col("scores").cast(pl.List(pl.Float32)), # ensure downcast took
|
||||||
|
)
|
||||||
|
|
||||||
|
if allow_cache:
|
||||||
|
try:
|
||||||
|
cache = root / "pretraining_results_cache.parquet"
|
||||||
|
df.write_parquet(
|
||||||
|
cache,
|
||||||
|
compression=parquet_compression,
|
||||||
|
compression_level=parquet_compression_level,
|
||||||
|
statistics=True,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[info] cached pretraining frame to {cache} "
|
||||||
|
f"({parquet_compression}, level={parquet_compression_level})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] failed to write pretraining cache {cache}: {e}")
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
root = Path("/home/fedex/mt/results/done")
|
||||||
|
df = load_results_dataframe(root, allow_cache=True)
|
||||||
|
print(df.shape, df.head())
|
||||||
|
|
||||||
|
df_pre = load_pretraining_results_dataframe(root, allow_cache=True)
|
||||||
|
print("pretraining:", df_pre.shape, df_pre.head())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
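A minimal usage sketch for the pretraining loader above (not part of this diff): it assumes the same results root used elsewhere in these scripts and summarizes the per-fold AE test losses by architecture and latent size, touching only columns defined in PRETRAIN_SCHEMA.

from pathlib import Path

import polars as pl

from load_results import load_pretraining_results_dataframe

# Only the test split is stored by default, but filtering keeps this robust
# if include_train=True was used when the cache was built.
df_pre = load_pretraining_results_dataframe(Path("/home/fedex/mt/results/done"), allow_cache=True)
summary = (
    df_pre.filter(pl.col("split") == "test")
    .group_by(["network", "latent_dim"])
    .agg(
        pl.col("loss").mean().alias("mean_test_loss"),
        pl.col("loss").std().alias("std_test_loss"),
        pl.len().alias("n_folds"),
    )
    .sort(["network", "latent_dim"])
)
print(summary)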
358 tools/plot_scripts/results_latent_space_comparisons.py Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
|
||||||
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
|
||||||
|
from load_results import load_results_dataframe
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Config
|
||||||
|
# ----------------------------
|
||||||
|
ROOT = Path("/home/fedex/mt/results/done") # experiments root you pass to the loader
|
||||||
|
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_comparisons")
|
||||||
|
|
||||||
|
SEMI_LABELING_REGIMES = [(0, 0), (50, 10), (500, 100)]
|
||||||
|
|
||||||
|
# Semi-supervised setting to select
|
||||||
|
SEMI_NORMALS = 50
|
||||||
|
SEMI_ANOMALOUS = 10
|
||||||
|
|
||||||
|
# Which evaluation columns to plot
|
||||||
|
EVALS = ["exp_based", "manual_based"]
|
||||||
|
|
||||||
|
# Latent dimensions to show as 7 subplots
|
||||||
|
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
|
||||||
|
|
||||||
|
# Interpolation grids
|
||||||
|
ROC_GRID = np.linspace(0.0, 1.0, 200)
|
||||||
|
PRC_GRID = np.linspace(0.0, 1.0, 200)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Helpers
|
||||||
|
# ----------------------------
|
||||||
|
def canonicalize_network(name: str) -> str:
|
||||||
|
"""Map net_name strings to clean labels for plotting."""
|
||||||
|
low = (name or "").lower()
|
||||||
|
if "lenet" in low:
|
||||||
|
return "LeNet"
|
||||||
|
if "efficient" in low:
|
||||||
|
return "Efficient"
|
||||||
|
return name or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray):
|
||||||
|
"""
|
||||||
|
Interpolate a list of (x, y) curves onto a common grid.
|
||||||
|
Returns mean_y, std_y on the grid. Skips empty or invalid curves.
|
||||||
|
"""
|
||||||
|
if not curves:
|
||||||
|
return np.full_like(grid, np.nan, dtype=float), np.full_like(
|
||||||
|
grid, np.nan, dtype=float
|
||||||
|
)
|
||||||
|
|
||||||
|
interps = []
|
||||||
|
for x, y in curves:
|
||||||
|
if x is None or y is None:
|
||||||
|
continue
|
||||||
|
x = np.asarray(x, dtype=float)
|
||||||
|
y = np.asarray(y, dtype=float)
|
||||||
|
if x.size == 0 or y.size == 0 or x.size != y.size:
|
||||||
|
continue
|
||||||
|
# ensure sorted by x and unique
|
||||||
|
order = np.argsort(x)
|
||||||
|
x = x[order]
|
||||||
|
y = y[order]
|
||||||
|
# deduplicate x (np.interp requires ascending x)
|
||||||
|
uniq_x, uniq_idx = np.unique(x, return_index=True)
|
||||||
|
y = y[uniq_idx]
|
||||||
|
x = uniq_x
|
||||||
|
# bound grid to valid interp range
|
||||||
|
gmin = max(grid[0], x[0])
|
||||||
|
gmax = min(grid[-1], x[-1])
|
||||||
|
g = np.clip(grid, gmin, gmax)
|
||||||
|
yi = np.interp(g, x, y)
|
||||||
|
interps.append(yi)
|
||||||
|
|
||||||
|
if not interps:
|
||||||
|
return np.full_like(grid, np.nan, dtype=float), np.full_like(
|
||||||
|
grid, np.nan, dtype=float
|
||||||
|
)
|
||||||
|
|
||||||
|
A = np.vstack(interps)
|
||||||
|
return np.nanmean(A, axis=0), np.nanstd(A, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""Add 'net_label' column (LeNet/Efficient/fallback)."""
|
||||||
|
return df.with_columns(
|
||||||
|
pl.when(
|
||||||
|
pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet")
|
||||||
|
)
|
||||||
|
.then(pl.lit("LeNet"))
|
||||||
|
.when(
|
||||||
|
pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient")
|
||||||
|
)
|
||||||
|
.then(pl.lit("Efficient"))
|
||||||
|
.otherwise(pl.col("network").cast(pl.Utf8))
|
||||||
|
.alias("net_label")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_rows(
|
||||||
|
df: pl.DataFrame,
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
eval_type: str,
|
||||||
|
latent_dim: int,
|
||||||
|
net_label: str | None,
|
||||||
|
semi_normals: int,
|
||||||
|
semi_anomalous: int,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""Polars filter: by model/eval/latent and optionally net_label."""
|
||||||
|
exprs = [
|
||||||
|
pl.col("model") == model,
|
||||||
|
pl.col("eval") == eval_type,
|
||||||
|
pl.col("latent_dim") == latent_dim,
|
||||||
|
pl.col("semi_normals") == semi_normals,
|
||||||
|
pl.col("semi_anomalous") == semi_anomalous,
|
||||||
|
]
|
||||||
|
if net_label is not None:
|
||||||
|
exprs.append(pl.col("net_label") == net_label)
|
||||||
|
return df.filter(pl.all_horizontal(exprs))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
|
||||||
|
"""
|
||||||
|
From a list of rows (Python dicts), return list of (x, y) curves for given kind.
|
||||||
|
kind: "roc" or "prc"
|
||||||
|
"""
|
||||||
|
curves = []
|
||||||
|
for r in rows:
|
||||||
|
if kind == "roc":
|
||||||
|
c = r.get("roc_curve")
|
||||||
|
if not c:
|
||||||
|
continue
|
||||||
|
x, y = c.get("fpr"), c.get("tpr")
|
||||||
|
else:
|
||||||
|
c = r.get("prc_curve")
|
||||||
|
if not c:
|
||||||
|
continue
|
||||||
|
x, y = c.get("recall"), c.get("precision")
|
||||||
|
if x is None or y is None:
|
||||||
|
continue
|
||||||
|
curves.append((np.asarray(x, dtype=float), np.asarray(y, dtype=float)))
|
||||||
|
return curves
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_dim_axes(fig_title: str):
|
||||||
|
"""Return figure, axes array laid out 2x4; last axis is for legend."""
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
nrows=4, ncols=2, figsize=(12, 16), constrained_layout=True
|
||||||
|
)
|
||||||
|
fig.suptitle(fig_title, fontsize=14)
|
||||||
|
axes = axes.ravel()
|
||||||
|
return fig, axes
|
||||||
|
|
||||||
|
|
||||||
|
def _add_legend_to_axis(ax, handles_labels):
|
||||||
|
ax.axis("off")
|
||||||
|
handles, labels = handles_labels
|
||||||
|
ax.legend(
|
||||||
|
handles,
|
||||||
|
labels,
|
||||||
|
loc="center",
|
||||||
|
frameon=False,
|
||||||
|
ncol=1,
|
||||||
|
fontsize=11,
|
||||||
|
borderaxespad=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_grid_from_df(
|
||||||
|
df: pl.DataFrame,
|
||||||
|
eval_type: str,
|
||||||
|
kind: str,
|
||||||
|
semi_normals: int,
|
||||||
|
semi_anomalous: int,
|
||||||
|
out_path: Path,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a 2x4 grid of subplots, one per latent dim; 8th panel holds legend.
|
||||||
|
kind: 'roc' or 'prc'
|
||||||
|
"""
|
||||||
|
fig_title = f"{kind.upper()} — {eval_type} (semi = {semi_normals}/{semi_anomalous})"
|
||||||
|
fig, axes = _ensure_dim_axes(fig_title)
|
||||||
|
|
||||||
|
# plotting order & colors
|
||||||
|
series = [
|
||||||
|
(
|
||||||
|
"isoforest",
|
||||||
|
None,
|
||||||
|
"IsolationForest",
|
||||||
|
"tab:purple",
|
||||||
|
), # baselines from Efficient only (handled below)
|
||||||
|
("ocsvm", None, "OC-SVM", "tab:green"),
|
||||||
|
("deepsad", "LeNet", "DeepSAD (LeNet)", "tab:blue"),
|
||||||
|
("deepsad", "Efficient", "DeepSAD (Efficient)", "tab:orange"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Handles for legend (build from first subplot that has data)
|
||||||
|
legend_handles = []
|
||||||
|
legend_labels = []
|
||||||
|
have_legend = False
|
||||||
|
|
||||||
|
for i, dim in enumerate(LATENT_DIMS):
|
||||||
|
if i >= 7:
|
||||||
|
break # last slot reserved for legend
|
||||||
|
ax = axes[i]
|
||||||
|
ax.set_title(f"latent_dim = {dim}")
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
if kind == "roc":
|
||||||
|
ax.set_xlim(0, 1)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
ax.set_xlabel("FPR")
|
||||||
|
ax.set_ylabel("TPR")
|
||||||
|
grid = ROC_GRID
|
||||||
|
else:
|
||||||
|
ax.set_xlim(0, 1)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
grid = PRC_GRID
|
||||||
|
|
||||||
|
plotted_any = False
|
||||||
|
|
||||||
|
for model, net_needed, label, color in series:
|
||||||
|
# baselines: use Efficient only
|
||||||
|
net_filter = net_needed
|
||||||
|
if model in ("isoforest", "ocsvm"):
|
||||||
|
net_filter = "Efficient"
|
||||||
|
|
||||||
|
sub = _select_rows(
|
||||||
|
df,
|
||||||
|
model=model,
|
||||||
|
eval_type=eval_type,
|
||||||
|
latent_dim=dim,
|
||||||
|
net_label=net_filter,
|
||||||
|
semi_normals=semi_normals,
|
||||||
|
semi_anomalous=semi_anomalous,
|
||||||
|
)
|
||||||
|
if sub.height == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
rows = sub.select("roc_curve" if kind == "roc" else "prc_curve").to_dicts()
|
||||||
|
|
||||||
|
curves = _extract_curves(rows, kind)
|
||||||
|
if not curves:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mean_y, std_y = _interp_mean_std(curves, grid)
|
||||||
|
# Guard for all-NaN
|
||||||
|
if np.all(np.isnan(mean_y)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
ax.plot(grid, mean_y, label=label, color=color)
|
||||||
|
ax.fill_between(
|
||||||
|
grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color
|
||||||
|
)
|
||||||
|
plotted_any = True
|
||||||
|
|
||||||
|
if not have_legend:
|
||||||
|
legend_handles.append(Line2D([0], [0], color=color, lw=2))
|
||||||
|
legend_labels.append(label)
|
||||||
|
|
||||||
|
if not plotted_any:
|
||||||
|
ax.text(
|
||||||
|
0.5, 0.5, "No data", ha="center", va="center", fontsize=12, alpha=0.7
|
||||||
|
)
|
||||||
|
ax.set_xlim(0, 1)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
|
||||||
|
if not have_legend and legend_handles:
|
||||||
|
have_legend = True
|
||||||
|
|
||||||
|
# Legend in 8th slot
|
||||||
|
_add_legend_to_axis(axes[7], (legend_handles, legend_labels))
|
||||||
|
|
||||||
|
# Save
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, dpi=150, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load main results DF (uses your cache if enabled in the loader)
|
||||||
|
df = load_results_dataframe(ROOT, allow_cache=True)
|
||||||
|
|
||||||
|
# Add clean network labels
|
||||||
|
complete_df = _net_label_col(df)
|
||||||
|
|
||||||
|
# Prepare output dirs
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
archive_dir = OUTPUT_DIR / "archive"
|
||||||
|
archive_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts_dir = archive_dir / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
ts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for semi_normals, semi_anomalous in SEMI_LABELING_REGIMES:
|
||||||
|
# Restrict to our semi-supervised setting
|
||||||
|
df = complete_df.filter(
|
||||||
|
(pl.col("semi_normals") == semi_normals)
|
||||||
|
& (pl.col("semi_anomalous") == semi_anomalous)
|
||||||
|
& (pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
|
||||||
|
& (pl.col("eval").is_in(EVALS))
|
||||||
|
& (pl.col("latent_dim").is_in(LATENT_DIMS))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot 4 figures
|
||||||
|
for eval_type in EVALS:
|
||||||
|
# ROC
|
||||||
|
plot_grid_from_df(
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
kind="roc",
|
||||||
|
semi_normals=semi_normals,
|
||||||
|
semi_anomalous=semi_anomalous,
|
||||||
|
out_path=ts_dir
|
||||||
|
/ f"roc_semi_{semi_normals}_{semi_anomalous}_{eval_type}.png",
|
||||||
|
)
|
||||||
|
# PRC
|
||||||
|
plot_grid_from_df(
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
kind="prc",
|
||||||
|
semi_normals=semi_normals,
|
||||||
|
semi_anomalous=semi_anomalous,
|
||||||
|
out_path=ts_dir
|
||||||
|
/ f"prc_{semi_normals}_{semi_anomalous}_{eval_type}.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy this script to preserve the code used for the outputs
|
||||||
|
script_path = Path(__file__)
|
||||||
|
shutil.copy2(script_path, ts_dir)
|
||||||
|
|
||||||
|
# Mirror latest
|
||||||
|
latest = OUTPUT_DIR / "latest"
|
||||||
|
latest.mkdir(exist_ok=True, parents=True)
|
||||||
|
for f in latest.iterdir():
|
||||||
|
if f.is_file():
|
||||||
|
f.unlink()
|
||||||
|
for f in ts_dir.iterdir():
|
||||||
|
if f.is_file():
|
||||||
|
shutil.copy2(f, latest / f.name)
|
||||||
|
|
||||||
|
print(f"Saved plots to: {ts_dir}")
|
||||||
|
print(f"Also updated: {latest}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
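A standalone sketch of the curve-averaging idea behind _interp_mean_std above: each fold's ROC is interpolated onto a shared FPR grid and then averaged pointwise, which is what produces the mean line and the ±std band in the panels. The two folds below are synthetic and for illustration only.

import numpy as np

grid = np.linspace(0.0, 1.0, 5)  # shared FPR grid (the real scripts use 200 points)
folds = [
    (np.array([0.0, 0.2, 1.0]), np.array([0.0, 0.7, 1.0])),  # fold 1: (fpr, tpr)
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.6, 1.0])),  # fold 2: (fpr, tpr)
]
# Interpolate every fold onto the grid, then take pointwise mean and std.
interped = np.vstack([np.interp(grid, x, y) for x, y in folds])
mean_tpr, std_tpr = interped.mean(axis=0), interped.std(axis=0)
print(mean_tpr.round(3), std_tpr.round(3))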
363 tools/plot_scripts/results_semi_labels_comparison.py Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
# curves_2x1_by_net_with_regimes_from_df.py
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
from scipy.stats import sem, t
|
||||||
|
|
||||||
|
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
|
||||||
|
from load_results import load_results_dataframe
|
||||||
|
|
||||||
|
# ---------------------------------
|
||||||
|
# Config
|
||||||
|
# ---------------------------------
|
||||||
|
ROOT = Path("/home/fedex/mt/results/done")
|
||||||
|
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")
|
||||||
|
|
||||||
|
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
|
||||||
|
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
|
||||||
|
EVALS = ["exp_based", "manual_based"]
|
||||||
|
|
||||||
|
# Interp grids
|
||||||
|
ROC_GRID = np.linspace(0.0, 1.0, 200)
|
||||||
|
PRC_GRID = np.linspace(0.0, 1.0, 200)
|
||||||
|
|
||||||
|
# Baselines are duplicated across nets; use Efficient-only to avoid repetition
|
||||||
|
BASELINE_NET = "Efficient"
|
||||||
|
|
||||||
|
# Colors/styles
|
||||||
|
COLOR_BASELINES = {
|
||||||
|
"isoforest": "tab:purple",
|
||||||
|
"ocsvm": "tab:green",
|
||||||
|
}
|
||||||
|
COLOR_REGIMES = {
|
||||||
|
(0, 0): "tab:blue",
|
||||||
|
(50, 10): "tab:orange",
|
||||||
|
(500, 100): "tab:red",
|
||||||
|
}
|
||||||
|
LINESTYLES = {
|
||||||
|
(0, 0): "-",
|
||||||
|
(50, 10): "--",
|
||||||
|
(500, 100): "-.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------
|
||||||
|
def _net_label_col(df: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
return df.with_columns(
|
||||||
|
pl.when(
|
||||||
|
pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("lenet")
|
||||||
|
)
|
||||||
|
.then(pl.lit("LeNet"))
|
||||||
|
.when(
|
||||||
|
pl.col("network").cast(pl.Utf8).str.to_lowercase().str.contains("efficient")
|
||||||
|
)
|
||||||
|
.then(pl.lit("Efficient"))
|
||||||
|
.otherwise(pl.col("network").cast(pl.Utf8))
|
||||||
|
.alias("net_label")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_ci(values: list[float], confidence: float = 0.95) -> tuple[float, float]:
|
||||||
|
"""Return mean and half-width of the (approx) confidence interval. Robust to n<2."""
|
||||||
|
arr = np.asarray([v for v in values if v is not None], dtype=float)
|
||||||
|
if arr.size == 0:
|
||||||
|
return np.nan, np.nan
|
||||||
|
if arr.size == 1:
|
||||||
|
return float(arr[0]), 0.0
|
||||||
|
m = float(arr.mean())
|
||||||
|
s = sem(arr, nan_policy="omit")
|
||||||
|
h = s * t.ppf((1 + confidence) / 2.0, arr.size - 1)
|
||||||
|
return m, float(h)
|
||||||
|
|
||||||
|
|
||||||
|
def _interp_mean_std(curves: list[tuple[np.ndarray, np.ndarray]], grid: np.ndarray):
|
||||||
|
"""Interpolate many (x,y) onto grid and return mean±std; robust to duplicates/empty."""
|
||||||
|
if not curves:
|
||||||
|
return np.full_like(grid, np.nan, float), np.full_like(grid, np.nan, float)
|
||||||
|
interps = []
|
||||||
|
for x, y in curves:
|
||||||
|
if x is None or y is None:
|
||||||
|
continue
|
||||||
|
x = np.asarray(x, float)
|
||||||
|
y = np.asarray(y, float)
|
||||||
|
if x.size == 0 or y.size == 0 or x.size != y.size:
|
||||||
|
continue
|
||||||
|
order = np.argsort(x)
|
||||||
|
x, y = x[order], y[order]
|
||||||
|
x, uniq_idx = np.unique(x, return_index=True)
|
||||||
|
y = y[uniq_idx]
|
||||||
|
g = np.clip(grid, x[0], x[-1])
|
||||||
|
yi = np.interp(g, x, y)
|
||||||
|
interps.append(yi)
|
||||||
|
if not interps:
|
||||||
|
return np.full_like(grid, np.nan, float), np.full_like(grid, np.nan, float)
|
||||||
|
A = np.vstack(interps)
|
||||||
|
return np.nanmean(A, axis=0), np.nanstd(A, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_curves(rows: list[dict], kind: str) -> list[tuple[np.ndarray, np.ndarray]]:
|
||||||
|
curves = []
|
||||||
|
for r in rows:
|
||||||
|
if kind == "roc":
|
||||||
|
c = r.get("roc_curve")
|
||||||
|
if not c:
|
||||||
|
continue
|
||||||
|
x, y = c.get("fpr"), c.get("tpr")
|
||||||
|
else:
|
||||||
|
c = r.get("prc_curve")
|
||||||
|
if not c:
|
||||||
|
continue
|
||||||
|
x, y = c.get("recall"), c.get("precision")
|
||||||
|
if x is None or y is None:
|
||||||
|
continue
|
||||||
|
curves.append((np.asarray(x, float), np.asarray(y, float)))
|
||||||
|
return curves
|
||||||
|
|
||||||
|
|
||||||
|
def _select_rows(
|
||||||
|
df: pl.DataFrame,
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
eval_type: str,
|
||||||
|
latent_dim: int,
|
||||||
|
semi_normals: int | None = None,
|
||||||
|
semi_anomalous: int | None = None,
|
||||||
|
net_label: str | None = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
exprs = [
|
||||||
|
pl.col("model") == model,
|
||||||
|
pl.col("eval") == eval_type,
|
||||||
|
pl.col("latent_dim") == latent_dim,
|
||||||
|
]
|
||||||
|
if semi_normals is not None:
|
||||||
|
exprs.append(pl.col("semi_normals") == semi_normals)
|
||||||
|
if semi_anomalous is not None:
|
||||||
|
exprs.append(pl.col("semi_anomalous") == semi_anomalous)
|
||||||
|
if net_label is not None:
|
||||||
|
exprs.append(pl.col("net_label") == net_label)
|
||||||
|
return df.filter(pl.all_horizontal(exprs))
|
||||||
|
|
||||||
|
|
||||||
|
def _auc_list(sub: pl.DataFrame) -> list[float]:
|
||||||
|
return [x for x in sub.select("auc").to_series().to_list() if x is not None]
|
||||||
|
|
||||||
|
|
||||||
|
def _ap_list(sub: pl.DataFrame) -> list[float]:
|
||||||
|
return [x for x in sub.select("ap").to_series().to_list() if x is not None]
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_panel(
|
||||||
|
ax,
|
||||||
|
df: pl.DataFrame,
|
||||||
|
*,
|
||||||
|
eval_type: str,
|
||||||
|
net_for_deepsad: str,
|
||||||
|
latent_dim: int,
|
||||||
|
kind: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).
|
||||||
|
Legend entries include mean±CI of AUC/AP.
|
||||||
|
"""
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_xlim(0, 1)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
if kind == "roc":
|
||||||
|
ax.set_xlabel("FPR")
|
||||||
|
ax.set_ylabel("TPR")
|
||||||
|
grid = ROC_GRID
|
||||||
|
else:
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
grid = PRC_GRID
|
||||||
|
|
||||||
|
handles, labels = [], []
|
||||||
|
|
||||||
|
# ----- Baselines (Efficient)
|
||||||
|
for model in ("isoforest", "ocsvm"):
|
||||||
|
sub_b = _select_rows(
|
||||||
|
df,
|
||||||
|
model=model,
|
||||||
|
eval_type=eval_type,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
net_label=BASELINE_NET,
|
||||||
|
)
|
||||||
|
if sub_b.height == 0:
|
||||||
|
continue
|
||||||
|
rows = sub_b.select("roc_curve" if kind == "roc" else "prc_curve").to_dicts()
|
||||||
|
curves = _extract_curves(rows, kind)
|
||||||
|
mean_y, std_y = _interp_mean_std(curves, grid)
|
||||||
|
if np.all(np.isnan(mean_y)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Metric for legend
|
||||||
|
metric_vals = _auc_list(sub_b) if kind == "roc" else _ap_list(sub_b)
|
||||||
|
m, ci = mean_ci(metric_vals)
|
||||||
|
lab = f"{model} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||||
|
|
||||||
|
color = COLOR_BASELINES[model]
|
||||||
|
h = ax.plot(grid, mean_y, lw=2, color=color, label=lab)[0]
|
||||||
|
ax.fill_between(grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color)
|
||||||
|
handles.append(h)
|
||||||
|
labels.append(lab)
|
||||||
|
|
||||||
|
# ----- DeepSAD (this panel's net) across semi-regimes
|
||||||
|
for regime in SEMI_REGIMES:
|
||||||
|
sn, sa = regime
|
||||||
|
sub_d = _select_rows(
|
||||||
|
df,
|
||||||
|
model="deepsad",
|
||||||
|
eval_type=eval_type,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
semi_normals=sn,
|
||||||
|
semi_anomalous=sa,
|
||||||
|
net_label=net_for_deepsad,
|
||||||
|
)
|
||||||
|
if sub_d.height == 0:
|
||||||
|
continue
|
||||||
|
rows = sub_d.select("roc_curve" if kind == "roc" else "prc_curve").to_dicts()
|
||||||
|
curves = _extract_curves(rows, kind)
|
||||||
|
mean_y, std_y = _interp_mean_std(curves, grid)
|
||||||
|
if np.all(np.isnan(mean_y)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
metric_vals = _auc_list(sub_d) if kind == "roc" else _ap_list(sub_d)
|
||||||
|
m, ci = mean_ci(metric_vals)
|
||||||
|
lab = f"DeepSAD {net_for_deepsad} — semi {sn}/{sa} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||||
|
|
||||||
|
color = COLOR_REGIMES[regime]
|
||||||
|
ls = LINESTYLES[regime]
|
||||||
|
h = ax.plot(grid, mean_y, lw=2, color=color, linestyle=ls, label=lab)[0]
|
||||||
|
ax.fill_between(grid, mean_y - std_y, mean_y + std_y, alpha=0.15, color=color)
|
||||||
|
handles.append(h)
|
||||||
|
labels.append(lab)
|
||||||
|
|
||||||
|
# Chance line for ROC
|
||||||
|
if kind == "roc":
|
||||||
|
ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")
|
||||||
|
|
||||||
|
# Legend
|
||||||
|
ax.legend(loc="lower right", fontsize=9, frameon=True)
|
||||||
|
|
||||||
|
|
||||||
|
def make_figures_for_dim(
|
||||||
|
df: pl.DataFrame, eval_type: str, latent_dim: int, out_dir: Path
|
||||||
|
):
|
||||||
|
# ROC: 2×1
|
||||||
|
fig_roc, axes = plt.subplots(
|
||||||
|
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||||
|
)
|
||||||
|
fig_roc.suptitle(f"ROC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||||
|
|
||||||
|
_plot_panel(
|
||||||
|
axes[0],
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
net_for_deepsad="LeNet",
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
kind="roc",
|
||||||
|
)
|
||||||
|
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||||
|
|
||||||
|
_plot_panel(
|
||||||
|
axes[1],
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
net_for_deepsad="Efficient",
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
kind="roc",
|
||||||
|
)
|
||||||
|
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||||
|
|
||||||
|
out_roc = out_dir / f"roc_{latent_dim}_{eval_type}.png"
|
||||||
|
fig_roc.savefig(out_roc, dpi=150, bbox_inches="tight")
|
||||||
|
plt.close(fig_roc)
|
||||||
|
|
||||||
|
# PRC: 2×1
|
||||||
|
fig_prc, axes = plt.subplots(
|
||||||
|
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||||
|
)
|
||||||
|
fig_prc.suptitle(f"PRC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||||
|
|
||||||
|
_plot_panel(
|
||||||
|
axes[0],
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
net_for_deepsad="LeNet",
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
kind="prc",
|
||||||
|
)
|
||||||
|
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||||
|
|
||||||
|
_plot_panel(
|
||||||
|
axes[1],
|
||||||
|
df,
|
||||||
|
eval_type=eval_type,
|
||||||
|
net_for_deepsad="Efficient",
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
kind="prc",
|
||||||
|
)
|
||||||
|
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||||
|
|
||||||
|
out_prc = out_dir / f"prc_{latent_dim}_{eval_type}.png"
|
||||||
|
fig_prc.savefig(out_prc, dpi=150, bbox_inches="tight")
|
||||||
|
plt.close(fig_prc)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load dataframe and prep
|
||||||
|
df = load_results_dataframe(ROOT, allow_cache=True)
|
||||||
|
df = _net_label_col(df)
|
||||||
|
|
||||||
|
# Filter to relevant models/evals only once
|
||||||
|
df = df.filter(
|
||||||
|
(pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]))
|
||||||
|
& (pl.col("eval").is_in(EVALS))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output/archiving like your AE script
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
archive = OUTPUT_DIR / "archive"
|
||||||
|
archive.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts_dir = archive / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
ts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate figures
|
||||||
|
for eval_type in EVALS:
|
||||||
|
for dim in LATENT_DIMS:
|
||||||
|
make_figures_for_dim(
|
||||||
|
df, eval_type=eval_type, latent_dim=dim, out_dir=ts_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy this script for provenance
|
||||||
|
script_path = Path(__file__)
|
||||||
|
try:
|
||||||
|
shutil.copy2(script_path, ts_dir)
|
||||||
|
except Exception:
|
||||||
|
pass # best effort if running in environments where __file__ may not exist
|
||||||
|
|
||||||
|
# Update "latest"
|
||||||
|
latest = OUTPUT_DIR / "latest"
|
||||||
|
latest.mkdir(parents=True, exist_ok=True)
|
||||||
|
for f in latest.iterdir():
|
||||||
|
if f.is_file():
|
||||||
|
f.unlink()
|
||||||
|
for f in ts_dir.iterdir():
|
||||||
|
if f.is_file():
|
||||||
|
shutil.copy2(f, latest / f.name)
|
||||||
|
|
||||||
|
print(f"Saved plots to: {ts_dir}")
|
||||||
|
print(f"Also updated: {latest}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
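A quick numeric check of the mean ± CI reported in the legends above: the half-width is t(0.975, n-1) · SEM, which is what mean_ci computes. The AUC values below are made up.

import numpy as np
from scipy.stats import sem, t

aucs = [0.91, 0.93, 0.89, 0.92, 0.90]  # e.g. five folds (synthetic numbers)
n = len(aucs)
m = float(np.mean(aucs))
# Student-t 95% half-width, matching mean_ci(aucs, confidence=0.95)
half_width = sem(aucs) * t.ppf((1 + 0.95) / 2.0, n - 1)
print(f"{m:.3f} ± {half_width:.3f} (95% CI)")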
274 tools/print_results_structure.py Normal file
@@ -0,0 +1,274 @@
from __future__ import annotations

import json
import pickle
from pathlib import Path
from typing import Any, Dict, List, Set

import numpy as np

# --- CONFIG ---
ROOT = Path("/home/fedex/mt/results/done")  # <- adjust if needed
# MODELS = ["deepsad", "isoforest", "ocsvm"]
MODELS = ["ae"]

# How much to show for very large collections
MAX_KEYS = 100  # show up to this many dict keys explicitly
MAX_GROUPS = 10  # distinct element-structure groups to print for a sequence
SAMPLE_PER_GROUP = 1  # recurse into this many representative elements per group
INDENT = " "


# ---------- Signature helpers ----------
def one_line_sig(obj: Any) -> str:
    """Single-line structural signature for grouping & summary."""
    t = type(obj)
    tn = t.__name__

    # numpy arrays
    if isinstance(obj, np.ndarray):
        return f"ndarray(shape={tuple(obj.shape)}, dtype={obj.dtype})"

    # numpy scalars
    if isinstance(obj, (np.generic,)):
        return f"{tn}"

    # scalars / strings / None
    if isinstance(obj, (int, float, bool, str, type(None))):
        return tn

    # dict
    if isinstance(obj, dict):
        # do not expand values here; just list key count
        return f"dict(len={len(obj)})"

    # list / tuple
    if isinstance(obj, (list, tuple)):
        return f"{tn}(len={len(obj)})"

    # fallback
    return tn


def is_atomic(obj: Any) -> bool:
    """Atomic = we don't recurse further (except printing info)."""
    return isinstance(obj, (int, float, bool, str, type(None), np.generic, np.ndarray))


# ---------- Recursive pretty-printer ----------
def describe(obj: Any, path: str = "", indent: str = "", seen: Set[int] | None = None):
    if seen is None:
        seen = set()

    # cycle guard
    oid = id(obj)
    if oid in seen:
        print(f"{indent}{path}: <CYCLE {one_line_sig(obj)}>")
        return
    seen.add(oid)

    # base info line
    header = (
        f"{indent}{path}: {one_line_sig(obj)}"
        if path
        else f"{indent}{one_line_sig(obj)}"
    )
    print(header)

    # atomic: done
    if is_atomic(obj):
        return

    # dict: print keys and recurse on each value
    if isinstance(obj, dict):
        keys = list(obj.keys())
        show = keys[:MAX_KEYS]
        extra = len(keys) - len(show)
        for k in show:
            # format key nicely
            key_repr = repr(k) if not isinstance(k, str) else k
            next_path = f"{path}.{key_repr}" if path else f"{key_repr}"
            try:
                describe(obj[k], next_path, indent + INDENT, seen)
            except Exception as e:
                print(f"{indent}{INDENT}{next_path}: <ERROR {e}>")
        if extra > 0:
            print(f"{indent}{INDENT}... (+{extra} more keys)")
        return

    # sequence (list/tuple): group by element structure
    if isinstance(obj, (list, tuple)):
        n = len(obj)
        if n == 0:
            return

        # group by element signature
        groups: Dict[str, List[int]] = {}
        for i, el in enumerate(obj):
            s = one_line_sig(el)
            groups.setdefault(s, []).append(i)

        # If too many distinct groups, truncate
        group_items = list(groups.items())
        if len(group_items) > MAX_GROUPS:
            group_items = group_items[:MAX_GROUPS]
            print(
                f"{indent}{INDENT}... ({len(groups) - MAX_GROUPS} more groups hidden)"
            )

        # print group headers and recurse into representative(s)
        for sig, idxs in group_items:
            count = len(idxs)
            print(f"{indent}{INDENT}- elements with structure {sig}: count={count}")
            # sample a few representatives from this group
            for j in idxs[:SAMPLE_PER_GROUP]:
                rep = obj[j]
                rep_path = f"{path}[{j}]" if path else f"[{j}]"
                try:
                    if is_atomic(rep):
                        print(
                            f"{indent}{INDENT}{INDENT}{rep_path}: {one_line_sig(rep)}"
                        )
                    else:
                        describe(rep, rep_path, indent + INDENT * 2, seen)
                except Exception as e:
                    print(f"{indent}{INDENT}{INDENT}{rep_path}: <ERROR {e}>")
        return

    # fallback: nothing more to do
    return


# ---------- Traverse & aggregate path->signature (optional summary) ----------
def traverse_paths(
    obj: Any, prefix: str, out: Dict[str, Set[str]], seen: Set[int] | None = None
):
    if seen is None:
        seen = set()
    oid = id(obj)
    if oid in seen:
        return
    seen.add(oid)

    out.setdefault(prefix or "<root>", set()).add(one_line_sig(obj))

    if is_atomic(obj):
        return

    if isinstance(obj, dict):
        for k, v in obj.items():
            key = repr(k) if not isinstance(k, str) else k
            path = f"{prefix}.{key}" if prefix else key
            traverse_paths(v, path, out, seen)
        return

    if isinstance(obj, (list, tuple)):
        # record element signatures but don't descend into *all* elements; just the first of each sig
        sig_to_index = {}
        for i, el in enumerate(obj):
            s = one_line_sig(el)
            if s not in sig_to_index:
                sig_to_index[s] = i
        for s, i in sig_to_index.items():
            path = f"{prefix}[]" if prefix else "[]"
            traverse_paths(obj[i], path, out, seen)
        return


# ---------- Per-pickle entry point ----------
def inspect_pickle_file(pkl_path: Path) -> Dict[str, Set[str]]:
    with pkl_path.open("rb") as f:
        data = pickle.load(f)

    print(f"\n=== {pkl_path.name} ===")
    print("Top-level structure:")
    describe(data)

    # optional: aggregate a concise path->signature summary for later
    agg: Dict[str, Set[str]] = {}
    traverse_paths(data, "", agg)
    return agg


# ---------- Per-experiment ----------
def read_kfold_num(exp_dir: Path) -> int:
    cfg = exp_dir / "config.json"
    if not cfg.exists():
        print(f"[warn] {exp_dir.name}: missing config.json")
        return 0
    with cfg.open("r") as f:
        config = json.load(f)
    if not config.get("k_fold"):
        print(f"[warn] {exp_dir.name}: config says not k-fold")
        return 0
    return int(config["k_fold_num"])


def inspect_experiment_folder(exp_dir: Path, models: List[str]) -> Dict[str, Set[str]]:
    k = read_kfold_num(exp_dir)
    if k <= 0:
        return {}
    agg: Dict[str, Set[str]] = {}
    for model in models:
        pkl = exp_dir / f"results_{model}_0.pkl"  # fold 0 by design
        if not pkl.exists():
            print(f"[warn] Missing {pkl.name} in {exp_dir.name}")
            continue
        try:
            sigs = inspect_pickle_file(pkl)
            # merge
            for path, sigset in sigs.items():
                agg.setdefault(path, set()).update(sigset)
        except Exception as e:
            print(f"[error] Failed reading {pkl}: {e}")
    return agg


# ---------- Driver ----------
def main(root: Path, models: List[str]):
    if not root.exists():
        print(f"[error] ROOT not found: {root.resolve()}")
        return

    exp_dirs = [p for p in root.iterdir() if p.is_dir()]
    if not exp_dirs:
        print(f"[warn] No experiment subdirectories under {root}")
        return

    print(f"Found {len(exp_dirs)} experiment dirs under {root}\n")

    efficient_done, lenet_done = False, False
    global_agg: Dict[str, Set[str]] = {}

    for exp in sorted(exp_dirs):
        if efficient_done and lenet_done:
            print("\nBoth efficient and lenet done, stopping early.")
            break
        lowname = exp.name.lower()
        if efficient_done and "efficient" in lowname:
            continue
        if lenet_done and "lenet" in lowname:
            continue

        print(f"\n================ EXPERIMENT: {exp.name} ================")
        agg = inspect_experiment_folder(exp, models)
        if agg:
            if "efficient" in lowname:
                efficient_done = True
            if "lenet" in lowname:
                lenet_done = True
            for path, sigset in agg.items():
                global_agg.setdefault(path, set()).update(sigset)

    print("\n\n================ GLOBAL STRUCTURE SUMMARY ================")
    for path in sorted(global_agg.keys()):
        shapes = sorted(global_agg[path])
        print(f"\n{path}:")
        for s in shapes:
            print(f" - {s}")

    print("\nDone.")


if __name__ == "__main__":
    main(ROOT, MODELS)
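A quick way to see the output format of print_results_structure.py without pointing it at real results pickles is to call describe() on a toy object. The dict below is purely illustrative; only describe() and one_line_sig() come from the script above, the example assumes it runs inside that module, and the exact indentation depends on the INDENT constant.

import numpy as np

toy = {"scores": np.zeros(5), "folds": [1, 2, 3]}
describe(toy)
# Expected shape of the output (indentation approximate):
#   dict(len=2)
#    scores: ndarray(shape=(5,), dtype=float64)
#    folds: list(len=3)
#     - elements with structure int: count=3
#      folds[0]: int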
@@ -1,31 +1,9 @@
-[tool.poetry]
+[project]
 name = "tools"
 version = "0.1.0"
-description = ""
+description = "Add your description here"
-authors = ["Your Name <you@example.com>"]
 readme = "README.md"
-package-mode = false
-
-[tool.poetry.dependencies]
-python = ">=3.11,<3.12"
-pointcloudset = "^0.9.0"
-open3d = "^0.19.0"
-scikit-learn = "^1.4.2"
-dash = "^2.16.1"
-addict = "^2.4.0"
-pillow = "^10.3.0"
-tqdm = "^4.66.2"
-matplotlib = "^3.8.4"
-dask = "^2024.4.2"
-dask-expr = "^1.1.3"
-pandas = "^2.2.2"
-pathvalidate = "^3.2.0"
-tabulate = "^0.9.0"
-tensorflow-datasets = "^4.9.8"
-tensorflow = "^2.19.0"
-setuptools = "^79.0.1"
-umap-learn = "^0.5.7"
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+requires-python = ">=3.11.9"
+dependencies = [
+    "polars>=1.33.0",
+]
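The hunk above replaces the Poetry metadata with a PEP 621 [project] table, so any standard TOML reader can inspect it directly. A small sketch, assuming the file lives at tools/pyproject.toml (the path is not stated in the hunk) and Python 3.11+ for the stdlib tomllib module:

import tomllib
from pathlib import Path

meta = tomllib.loads(Path("tools/pyproject.toml").read_text())
print(meta["project"]["requires-python"])  # ">=3.11.9"
print(meta["project"]["dependencies"])     # ["polars>=1.33.0"]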
28
tools/uv.lock
generated
Normal file
@@ -0,0 +1,28 @@
version = 1
revision = 2
requires-python = ">=3.11.9"

[[package]]
name = "polars"
version = "1.33.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b6/3f/d8bc150b548a486f2559586ec6455c2566b9d2fb7ee1acae90ddca14eec1/polars-1.33.0.tar.gz", hash = "sha256:50ad2ab96c701be2c6ac9b584d9aa6a385f228f6c06de15b88c5d10df7990d56", size = 4811393, upload-time = "2025-09-01T16:32:46.106Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/23/8c/0c4ac34030348ed547b27db0ae7d77ccd12dc4008e91c4f8e896c3161ed8/polars-1.33.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:664ef1c0988e4098518c6acfdd5477f2e11611f4ac8a269db55b94ea4978d0e5", size = 38793275, upload-time = "2025-09-01T16:31:51.038Z" },
    { url = "https://files.pythonhosted.org/packages/95/2a/87e27ef3cb76e54f92dd177b9f4c80329d66e78f51ed968e9bdf452ccfb1/polars-1.33.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:2477b720c466914549f0f2cfc69f617a602d91e9d90205b64d795ed1ecf99b3c", size = 35238137, upload-time = "2025-09-01T16:31:55.179Z" },
    { url = "https://files.pythonhosted.org/packages/f2/e2/485c87047e8aaae8dae4e9881517697616b7f79b14132961fbccfc386b29/polars-1.33.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd9b76abc22fdb20a005c629ee8c056b0545433f18854b929fb54e351d1b98ee", size = 39341268, upload-time = "2025-09-01T16:31:58.269Z" },
    { url = "https://files.pythonhosted.org/packages/b9/3a/39d784ed547832eb6cbe86cc7f3a6353fa977803e0cec743dd5932ecf50b/polars-1.33.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:6e78026c2ece38c45c6ee0416e2594980652d89deee13a15bd9f83743ec8fa8d", size = 36262606, upload-time = "2025-09-01T16:32:01.981Z" },
    { url = "https://files.pythonhosted.org/packages/94/1b/4aea12acf2301f4d7fe78b9f4b54611ec2187439fa299e986974cfd956f2/polars-1.33.0-cp39-abi3-win_amd64.whl", hash = "sha256:7973568178117667871455d7969c1929abb890597964ca89290bfd89e4366980", size = 38919180, upload-time = "2025-09-01T16:32:05.087Z" },
    { url = "https://files.pythonhosted.org/packages/58/13/824a81b43199202fc859c24515cd5b227930d6dce0dea488e4b415edbaba/polars-1.33.0-cp39-abi3-win_arm64.whl", hash = "sha256:c7d614644eda028907965f8203ac54b9a4f5b90303de2723bf1c1087433a0914", size = 35033820, upload-time = "2025-09-01T16:32:08.116Z" },
]

[[package]]
name = "tools"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
    { name = "polars" },
]

[package.metadata]
requires-dist = [{ name = "polars", specifier = ">=1.33.0" }]
295
tools/verify_loaded_results.py
Normal file
@@ -0,0 +1,295 @@
from __future__ import annotations

from itertools import product
from pathlib import Path
from typing import Sequence

import polars as pl

from load_results import load_results_dataframe

# --- configure your intended grid here (use the *canonical* strings used in df) ---
NETWORKS_EXPECTED = ["subter_LeNet", "subter_efficient"]
LATENT_DIMS_EXPECTED = [32, 64, 128, 256, 512, 768, 1024]
SEMI_LABELS_EXPECTED = [(0, 0), (50, 10), (500, 100)]
MODELS_EXPECTED = ["deepsad", "isoforest", "ocsvm"]
EVALS_EXPECTED = ["exp_based", "manual_based"]

# If k-fold is uniform, set it. If None, we infer it *per combo* from df.
EXPECTED_K_FOLD: int | None = None  # e.g., 3
# utils/shape_checks.py


def equal_within_tolerance(lengths: Sequence[int], tol: int = 1) -> bool:
    """
    True iff max(lengths) - min(lengths) <= tol.
    Empty/one-item sequences return True.
    """
    if not lengths:
        return True
    mn = min(lengths)
    mx = max(lengths)
    return (mx - mn) <= tol


def add_shape_columns(df: pl.DataFrame) -> pl.DataFrame:
    return df.with_columns(
        # scores length
        scores_len=pl.when(pl.col("scores").is_null())
        .then(None)
        .otherwise(pl.col("scores").list.len()),
        # deepsad-only arrays (None for others)
        idxs_len=pl.when(pl.col("sample_indices").is_null())
        .then(None)
        .otherwise(pl.col("sample_indices").list.len()),
        labels_len=pl.when(pl.col("sample_labels").is_null())
        .then(None)
        .otherwise(pl.col("sample_labels").list.len()),
        vmask_len=pl.when(pl.col("valid_mask").is_null())
        .then(None)
        .otherwise(pl.col("valid_mask").list.len()),
    )


def check_grid_coverage_and_shapes(
    df: pl.DataFrame,
    networks=NETWORKS_EXPECTED,
    latent_dims=LATENT_DIMS_EXPECTED,
    semi_labels=SEMI_LABELS_EXPECTED,
    models=MODELS_EXPECTED,
    evals=EVALS_EXPECTED,
    expected_k_fold: int | None = EXPECTED_K_FOLD,
):
    dfx = add_shape_columns(df)

    # helper: get rows for a specific base combo
    def subframe(net, lat, s_norm, s_anom, mdl, ev):
        return dfx.filter(
            (pl.col("network") == net)
            & (pl.col("latent_dim") == lat)
            & (pl.col("semi_normals") == s_norm)
            & (pl.col("semi_anomalous") == s_anom)
            & (pl.col("model") == mdl)
            & (pl.col("eval") == ev)
        )

    missing = []
    incomplete = []  # combos missing folds
    shape_inconsistent = []  # within-combo, across-fold score/idx/label/vmask mismatches
    cross_model_diffs = []  # across models at fixed (net,lat,semi,eval): scores_len only

    # 1) Coverage + within-combo shapes
    for net, lat, (s_norm, s_anom), mdl, ev in product(
        networks, latent_dims, semi_labels, models, evals
    ):
        sf = subframe(net, lat, s_norm, s_anom, mdl, ev).select(
            "fold",
            "k_fold_num",
            "scores_len",
            "idxs_len",
            "labels_len",
            "vmask_len",
        )

        if sf.height == 0:
            missing.append(
                dict(
                    network=net,
                    latent_dim=lat,
                    semi_normals=s_norm,
                    semi_anomalous=s_anom,
                    model=mdl,
                    eval=ev,
                )
            )
            continue

        # folds present vs expected
        folds_present = sorted(sf.get_column("fold").unique().to_list())
        if expected_k_fold is not None:
            kexp = expected_k_fold
        else:
            kexp = int(sf.get_column("k_fold_num").max())
        all_expected_folds = list(range(kexp))
        if folds_present != all_expected_folds:
            incomplete.append(
                dict(
                    network=net,
                    latent_dim=lat,
                    semi_normals=s_norm,
                    semi_anomalous=s_anom,
                    model=mdl,
                    eval=ev,
                    expected_folds=all_expected_folds,
                    present_folds=folds_present,
                )
            )

        # shape consistency across folds (for this combo)
        shape_cols = ["scores_len", "idxs_len", "labels_len", "vmask_len"]
        for colname in shape_cols:
            vals = sf.select(colname).to_series()
            # ignore nulls so None-only columns pass (e.g., deepsad-only fields for other models)
            # and so a mix of null and int never gets compared directly
            non_null = [v for v in vals.to_list() if v is not None]
            if len(set(non_null)) > 1:
                per_fold = (
                    sf.select("fold", pl.col(colname))
                    .sort("fold")
                    .to_dict(as_series=False)
                )
                shape_inconsistent.append(
                    dict(
                        network=net,
                        latent_dim=lat,
                        semi_normals=s_norm,
                        semi_anomalous=s_anom,
                        model=mdl,
                        eval=ev,
                        metric=colname,
                        per_fold=per_fold,
                    )
                )

    # 2) Cross-model comparability at fixed (net,lat,semi,eval)
    # Only check number of test scores; ignore ROC/PRC binning entirely.
    base_keys = (
        df.select("network", "latent_dim", "semi_normals", "semi_anomalous", "eval")
        .unique()
        .iter_rows()
    )
    for net, lat, s_norm, s_anom, ev in base_keys:
        rows = (
            dfx.filter(
                (pl.col("network") == net)
                & (pl.col("latent_dim") == lat)
                & (pl.col("semi_normals") == s_norm)
                & (pl.col("semi_anomalous") == s_anom)
                & (pl.col("eval") == ev)
            )
            .group_by("model")
            .agg(
                pl.col("scores_len")
                .drop_nulls()
                .unique()
                .sort()
                .alias("scores_len_set"),
            )
            .to_dict(as_series=False)
        )
        if not rows:
            continue

        mdls = rows["model"]
        s_sets = [rows["scores_len_set"][i] for i in range(len(mdls))]
        # normalize: empty => ignore that model (no scores); single value => int; else => list
        norm = {}
        for m, vals in zip(mdls, s_sets):
            if len(vals) == 0:
                continue
            norm[m] = vals[0] if len(vals) == 1 else list(vals)

        if len(norm) > 1:
            # Compare as tuples to allow list values
            normalized_keys = [
                v if isinstance(v, int) else tuple(v) for v in norm.values()
            ]
            if len(set(normalized_keys)) > 1:
                cross_model_diffs.append(
                    dict(
                        network=net,
                        latent_dim=lat,
                        semi_normals=s_norm,
                        semi_anomalous=s_anom,
                        eval=ev,
                        metric="scores_len",
                        by_model=norm,
                    )
                )

    # --- Print a readable report ---
    print("\n=== GRID COVERAGE ===")
    print(f"Missing combos: {len(missing)}")
    for m in missing[:20]:
        print(" ", m)
    if len(missing) > 20:
        print(f" ... (+{len(missing) - 20} more)")

    print("\nIncomplete combos (folds missing):", len(incomplete))
    for inc in incomplete[:20]:
        print(
            " ",
            {
                k: inc[k]
                for k in [
                    "network",
                    "latent_dim",
                    "semi_normals",
                    "semi_anomalous",
                    "model",
                    "eval",
                ]
            },
            "expected",
            inc["expected_folds"],
            "present",
            inc["present_folds"],
        )
    if len(incomplete) > 20:
        print(f" ... (+{len(incomplete) - 20} more)")

    print("\n=== WITHIN-COMBO SHAPE CONSISTENCY (across folds) ===")
    print(f"Mismatching groups: {len(shape_inconsistent)}")
    for s in shape_inconsistent[:15]:
        hdr = {
            k: s[k]
            for k in [
                "network",
                "latent_dim",
                "semi_normals",
                "semi_anomalous",
                "model",
                "eval",
                "metric",
            ]
        }
        print(" ", hdr, "values:", s["per_fold"])
    if len(shape_inconsistent) > 15:
        print(f" ... (+{len(shape_inconsistent) - 15} more)")

    print("\n=== CROSS-MODEL COMPARABILITY (by shape) ===")
    print(
        f"Differences across models at fixed (net,lat,semi,eval): {len(cross_model_diffs)}"
    )
    for s in cross_model_diffs[:15]:
        hdr = {
            k: s[k]
            for k in [
                "network",
                "latent_dim",
                "semi_normals",
                "semi_anomalous",
                "eval",
                "metric",
            ]
        }
        print(" ", hdr, "by_model:", s["by_model"])
    if len(cross_model_diffs) > 15:
        print(f" ... (+{len(cross_model_diffs) - 15} more)")

    return {
        "missing": missing,
        "incomplete": incomplete,
        "shape_inconsistent": shape_inconsistent,
        "cross_model_diffs": cross_model_diffs,
    }


def main():
    root = Path("/home/fedex/mt/results/done")
    df = load_results_dataframe(root, allow_cache=True)
    report = check_grid_coverage_and_shapes(df)
    print(report)


if __name__ == "__main__":
    main()
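The length-tolerance helper and the list-length columns above can be exercised on a tiny synthetic frame. A minimal sketch, assuming it runs in the same module as verify_loaded_results.py (so equal_within_tolerance is in scope) and using made-up scores:

import polars as pl

toy = pl.DataFrame({"fold": [0, 1], "scores": [[0.1, 0.9, 0.4], [0.2, 0.8, 0.5, 0.3]]})
lengths = toy.select(pl.col("scores").list.len()).to_series().to_list()
print(lengths)                           # [3, 4]
print(equal_within_tolerance(lengths))   # True: a max-min spread of 1 is within the default tol=1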