retest implemented and fixed missing center in save data

This commit is contained in:
Jan Kowalczyk
2025-07-01 17:22:29 +02:00
parent 24c6771576
commit 4863b91127
6 changed files with 189 additions and 1307 deletions

View File

@@ -1,192 +0,0 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_2": {
"inputs": {
"systems": "systems_2"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nix-github-actions": {
"inputs": {
"nixpkgs": [
"poetry2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1703863825,
"narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
"owner": "nix-community",
"repo": "nix-github-actions",
"rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "nix-github-actions",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1719327525,
"narHash": "sha256-fPWiFM4aYbK9zGTt3KJ9CwX//iyElRiNHWNj2hk3i0E=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "191a3fd9786d09c8d82e89ed68c4463e7be09b3e",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable-small",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-newest": {
"locked": {
"lastModified": 1749285348,
"narHash": "sha256-frdhQvPbmDYaScPFiCnfdh3B/Vh81Uuoo0w5TkWmmjU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "3e3afe5174c561dee0df6f2c2b2236990146329f",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"poetry2nix": {
"inputs": {
"flake-utils": "flake-utils_2",
"nix-github-actions": "nix-github-actions",
"nixpkgs": [
"nixpkgs"
],
"systems": "systems_3",
"treefmt-nix": "treefmt-nix"
},
"locked": {
"lastModified": 1719358925,
"narHash": "sha256-ZV/2YB7nyeYCsDm6EMH0EKtlpxuu2ImEd5WrlceNwRE=",
"owner": "nix-community",
"repo": "poetry2nix",
"rev": "bbc1ee74fc1ac4082f617bf32f1c927e759717d2",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "poetry2nix",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs",
"nixpkgs-newest": "nixpkgs-newest",
"poetry2nix": "poetry2nix"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_2": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_3": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"id": "systems",
"type": "indirect"
}
},
"treefmt-nix": {
"inputs": {
"nixpkgs": [
"poetry2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1718522839,
"narHash": "sha256-ULzoKzEaBOiLRtjeY3YoGFJMwWSKRYOic6VNw2UyTls=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "68eb1dc333ce82d0ab0c0357363ea17c31ea1f81",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "treefmt-nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

View File

@@ -1,83 +0,0 @@
{
description = "Deepsad devenv with python 3.11";
inputs = {
flake-utils.url = "github:numtide/flake-utils";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
# Added newest nixpkgs for an updated poetry package.
nixpkgs-newest.url = "github:NixOS/nixpkgs/nixos-unstable";
poetry2nix = {
url = "github:nix-community/poetry2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
};
outputs =
{
self,
nixpkgs,
nixpkgs-newest,
flake-utils,
poetry2nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
# see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
config.cudaSupport = true;
};
pkgsNew = nixpkgs-newest.legacyPackages.${system};
thundersvm = import ./nix/thundersvm.nix {
inherit pkgs;
inherit (pkgs) fetchFromGitHub cmake gcc12Stdenv;
cudaPackages = pkgs.cudaPackages;
};
thundersvm-python = import ./nix/thundersvm-python.nix {
inherit pkgs;
pythonPackages = pkgs.python311Packages;
thundersvm = thundersvm;
};
inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication defaultPoetryOverrides;
in
{
packages = {
deepsad = mkPoetryApplication {
projectDir = self;
preferWheels = true;
python = pkgs.python311;
overrides = defaultPoetryOverrides.extend (
final: prev: {
torch-receptive-field = prev.torch-receptive-field.overridePythonAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [ prev.setuptools ];
});
}
);
};
default = self.packages.${system}.deepsad;
};
devShells.default = pkgs.mkShell {
inputsFrom = [ self.packages.${system}.deepsad ];
buildInputs = with pkgs.python311Packages; [
torch-bin
torchvision-bin
thundersvm-python
];
#LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
#pkgs.stdenv.cc.cc
#];
};
devShells.poetry = pkgs.mkShell {
packages = [
pkgsNew.poetry
pkgs.python311
];
};
}
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,32 +1,29 @@
[tool.poetry] [project]
name = "deep-sad-pytorch" name = "deep-sad-pytorch"
version = "0.1.0" version = "0.1.0"
description = "" description = "Add your description here"
authors = ["Your Name <you@example.com>"]
readme = "README.md" readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"click>=8.2.1",
"cvxopt>=1.3.2",
"cycler>=0.12.1",
"joblib>=1.5.1",
"kiwisolver>=1.4.8",
"matplotlib>=3.10.3",
"numpy>=2.3.1",
"pandas>=2.3.0",
"pillow>=11.2.1",
"pyparsing>=3.2.3",
"python-dateutil>=2.9.0.post0",
"pytz>=2025.2",
"scikit-learn>=1.7.0",
"scipy>=1.16.0",
"seaborn>=0.13.2",
"six>=1.17.0",
"torch-receptive-field",
"torchscan>=0.1.1",
]
[tool.poetry.dependencies] [tool.uv.sources]
python = ">=3.11,<3.12" torch-receptive-field = { git = "https://github.com/Fangyh09/pytorch-receptive-field.git" }
click = "^8.1.7"
matplotlib = "^3.9.0"
numpy = "^2.0.0"
pandas = "^2.2.2"
cvxopt = "^1.3.2"
cycler = "^0.12.1"
joblib = "^1.4.2"
kiwisolver = "^1.4.5"
pillow = "^10.3.0"
pyparsing = "^3.1.2"
python-dateutil = "^2.9.0.post0"
pytz = "^2024.1"
scikit-learn = "^1.5.0"
scipy = "^1.14.0"
seaborn = "^0.13.2"
six = "^1.16.0"
torchscan = "^0.1.2"
torch-receptive-field = {git = "https://github.com/Fangyh09/pytorch-receptive-field.git"}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -126,6 +126,8 @@ class DeepSAD(object):
) )
# Get the model # Get the model
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx) self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
# Store training results including indices # Store training results including indices
self.results["train"]["time"] = self.trainer.train_time self.results["train"]["time"] = self.trainer.train_time
self.results["train"]["indices"] = self.trainer.train_indices self.results["train"]["indices"] = self.trainer.train_indices
@@ -333,7 +335,7 @@ class DeepSAD(object):
# load autoencoder parameters if specified # load autoencoder parameters if specified
if load_ae: if load_ae:
if self.ae_net is None: if self.ae_net is None:
self.ae_net = build_autoencoder(self.net_name) self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
self.ae_net.load_state_dict(model_dict["ae_net_dict"]) self.ae_net.load_state_dict(model_dict["ae_net_dict"])
def save_results(self, export_pkl): def save_results(self, export_pkl):

View File

@@ -25,7 +25,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
[ [
"train", "train",
"infer", "infer",
"ae_elbow_test", # Add new action "ae_elbow_test",
"retest",
] ]
), ),
) )
@@ -773,6 +774,165 @@ def main(
) )
else: else:
logger.error(f"Unknown action: {action}") logger.error(f"Unknown action: {action}")
elif action == "retest":
if (
not load_model
or not Path(load_model).exists()
or not Path(load_model).is_dir()
):
logger.error(
"For retest mode a model directory has to be loaded! Pass the --load_model option with the model directory path!"
)
return
load_model = Path(load_model)
if not load_config:
logger.error(
"For retest mode a config has to be loaded! Pass the --load_config option with the config path!"
)
return
dataset = load_dataset(
cfg.settings["dataset_name"],
data_path,
cfg.settings["normal_class"],
cfg.settings["known_outlier_class"],
cfg.settings["n_known_outlier_classes"],
cfg.settings["ratio_known_normal"],
cfg.settings["ratio_known_outlier"],
cfg.settings["ratio_pollution"],
random_state=np.random.RandomState(cfg.settings["seed"]),
k_fold_num=cfg.settings["k_fold_num"],
num_known_normal=cfg.settings["num_known_normal"],
num_known_outlier=cfg.settings["num_known_outlier"],
)
train_passes = (
range(cfg.settings["k_fold_num"]) if cfg.settings["k_fold"] else [None]
)
retest_isoforest = True
retest_ocsvm = True
retest_deepsad = True
for fold_idx in train_passes:
if fold_idx is None:
logger.info("Single train re-testing without k-fold")
deepsad_model_path = load_model / "model_deepsad.tar"
isoforest_model_path = load_model / "model_ocsvm.pkl"
ocsvm_model_path = load_model / "model_isoforest.pkl"
ae_model_path = load_model / "model_ae.tar"
else:
logger.info(f"Fold {fold_idx + 1}/{cfg.settings['k_fold_num']}")
deepsad_model_path = load_model / f"model_deepsad_{fold_idx}.tar"
isoforest_model_path = load_model / f"model_isoforest_{fold_idx}.pkl"
ocsvm_model_path = load_model / f"model_ocsvm_{fold_idx}.pkl"
ae_model_path = load_model / f"model_ae_{fold_idx}.tar"
# Check which model files exist and which do not
model_paths = [
deepsad_model_path,
isoforest_model_path,
ocsvm_model_path,
ae_model_path,
]
missing_models = [
path.name
for path in model_paths
if not path.exists() or not path.is_file()
]
if missing_models:
logger.error(
f"The following model files do not exist: {', '.join(missing_models)}. Please check the load_model path."
)
return
# Initialize Isolation Forest model
if retest_isoforest:
Isoforest = IsoForest(
hybrid=False,
n_estimators=cfg.settings["isoforest_n_estimators"],
max_samples=cfg.settings["isoforest_max_samples"],
contamination=cfg.settings["isoforest_contamination"],
n_jobs=cfg.settings["isoforest_n_jobs_model"],
seed=cfg.settings["seed"],
)
Isoforest.load_model(import_path=isoforest_model_path, device=device)
Isoforest.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
)
# Initialize DeepSAD model and set neural network phi
if retest_deepsad:
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
deepSAD.set_network(cfg.settings["net_name"])
deepSAD.load_model(
model_path=deepsad_model_path, load_ae=True, map_location=device
)
logger.info("Loading model from %s." % load_model)
deepSAD.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
)
if retest_ocsvm:
ocsvm = OCSVM(
kernel=cfg.settings["ocsvm_kernel"],
nu=cfg.settings["ocsvm_nu"],
hybrid=True,
latent_space_dim=cfg.settings["latent_space_dim"],
)
ocsvm.load_ae(
net_name=cfg.settings["net_name"],
model_path=ae_model_path,
device=device,
)
ocsvm.load_model(import_path=ocsvm_model_path)
ocsvm.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
batch_size=256,
)
retest_output_path = load_model / "retest_output"
retest_output_path.mkdir(parents=True, exist_ok=True)
# Save results, model, and configuration
if fold_idx is None:
if retest_deepsad:
deepSAD.save_results(
export_pkl=retest_output_path / "results_deepsad.pkl"
)
if retest_ocsvm:
ocsvm.save_results(
export_pkl=retest_output_path / "results_ocsvm.pkl"
)
if retest_isoforest:
Isoforest.save_results(
export_pkl=retest_output_path / "results_isoforest.pkl"
)
else:
if retest_deepsad:
deepSAD.save_results(
export_pkl=retest_output_path
/ f"results_deepsad_{fold_idx}.pkl"
)
if retest_ocsvm:
ocsvm.save_results(
export_pkl=retest_output_path / f"/results_ocsvm_{fold_idx}.pkl"
)
if retest_isoforest:
Isoforest.save_results(
export_pkl=retest_output_path
/ f"/results_isoforest_{fold_idx}.pkl"
)
if __name__ == "__main__": if __name__ == "__main__":