retest implemented and fixed missing center in save data
This commit is contained in:
192
Deep-SAD-PyTorch/flake.lock
generated
192
Deep-SAD-PyTorch/flake.lock
generated
@@ -1,192 +0,0 @@
|
|||||||
{
|
|
||||||
"nodes": {
|
|
||||||
"flake-utils": {
|
|
||||||
"inputs": {
|
|
||||||
"systems": "systems"
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1710146030,
|
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "flake-utils",
|
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "flake-utils",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"flake-utils_2": {
|
|
||||||
"inputs": {
|
|
||||||
"systems": "systems_2"
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1710146030,
|
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "flake-utils",
|
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "flake-utils",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nix-github-actions": {
|
|
||||||
"inputs": {
|
|
||||||
"nixpkgs": [
|
|
||||||
"poetry2nix",
|
|
||||||
"nixpkgs"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1703863825,
|
|
||||||
"narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
|
|
||||||
"owner": "nix-community",
|
|
||||||
"repo": "nix-github-actions",
|
|
||||||
"rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "nix-community",
|
|
||||||
"repo": "nix-github-actions",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nixpkgs": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1719327525,
|
|
||||||
"narHash": "sha256-fPWiFM4aYbK9zGTt3KJ9CwX//iyElRiNHWNj2hk3i0E=",
|
|
||||||
"owner": "NixOS",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"rev": "191a3fd9786d09c8d82e89ed68c4463e7be09b3e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "NixOS",
|
|
||||||
"ref": "nixos-unstable-small",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nixpkgs-newest": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1749285348,
|
|
||||||
"narHash": "sha256-frdhQvPbmDYaScPFiCnfdh3B/Vh81Uuoo0w5TkWmmjU=",
|
|
||||||
"owner": "NixOS",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"rev": "3e3afe5174c561dee0df6f2c2b2236990146329f",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "NixOS",
|
|
||||||
"ref": "nixos-unstable",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"poetry2nix": {
|
|
||||||
"inputs": {
|
|
||||||
"flake-utils": "flake-utils_2",
|
|
||||||
"nix-github-actions": "nix-github-actions",
|
|
||||||
"nixpkgs": [
|
|
||||||
"nixpkgs"
|
|
||||||
],
|
|
||||||
"systems": "systems_3",
|
|
||||||
"treefmt-nix": "treefmt-nix"
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1719358925,
|
|
||||||
"narHash": "sha256-ZV/2YB7nyeYCsDm6EMH0EKtlpxuu2ImEd5WrlceNwRE=",
|
|
||||||
"owner": "nix-community",
|
|
||||||
"repo": "poetry2nix",
|
|
||||||
"rev": "bbc1ee74fc1ac4082f617bf32f1c927e759717d2",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "nix-community",
|
|
||||||
"repo": "poetry2nix",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"root": {
|
|
||||||
"inputs": {
|
|
||||||
"flake-utils": "flake-utils",
|
|
||||||
"nixpkgs": "nixpkgs",
|
|
||||||
"nixpkgs-newest": "nixpkgs-newest",
|
|
||||||
"poetry2nix": "poetry2nix"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"systems": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1681028828,
|
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"systems_2": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1681028828,
|
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"systems_3": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1681028828,
|
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"id": "systems",
|
|
||||||
"type": "indirect"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"treefmt-nix": {
|
|
||||||
"inputs": {
|
|
||||||
"nixpkgs": [
|
|
||||||
"poetry2nix",
|
|
||||||
"nixpkgs"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1718522839,
|
|
||||||
"narHash": "sha256-ULzoKzEaBOiLRtjeY3YoGFJMwWSKRYOic6VNw2UyTls=",
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "treefmt-nix",
|
|
||||||
"rev": "68eb1dc333ce82d0ab0c0357363ea17c31ea1f81",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "treefmt-nix",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"root": "root",
|
|
||||||
"version": 7
|
|
||||||
}
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
{
|
|
||||||
description = "Deepsad devenv with python 3.11";
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
|
||||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
|
|
||||||
# Added newest nixpkgs for an updated poetry package.
|
|
||||||
nixpkgs-newest.url = "github:NixOS/nixpkgs/nixos-unstable";
|
|
||||||
poetry2nix = {
|
|
||||||
url = "github:nix-community/poetry2nix";
|
|
||||||
inputs.nixpkgs.follows = "nixpkgs";
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
outputs =
|
|
||||||
{
|
|
||||||
self,
|
|
||||||
nixpkgs,
|
|
||||||
nixpkgs-newest,
|
|
||||||
flake-utils,
|
|
||||||
poetry2nix,
|
|
||||||
}:
|
|
||||||
flake-utils.lib.eachDefaultSystem (
|
|
||||||
system:
|
|
||||||
let
|
|
||||||
# see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
|
|
||||||
pkgs = import nixpkgs {
|
|
||||||
inherit system;
|
|
||||||
config.allowUnfree = true;
|
|
||||||
config.cudaSupport = true;
|
|
||||||
};
|
|
||||||
pkgsNew = nixpkgs-newest.legacyPackages.${system};
|
|
||||||
thundersvm = import ./nix/thundersvm.nix {
|
|
||||||
inherit pkgs;
|
|
||||||
inherit (pkgs) fetchFromGitHub cmake gcc12Stdenv;
|
|
||||||
cudaPackages = pkgs.cudaPackages;
|
|
||||||
};
|
|
||||||
|
|
||||||
thundersvm-python = import ./nix/thundersvm-python.nix {
|
|
||||||
inherit pkgs;
|
|
||||||
pythonPackages = pkgs.python311Packages;
|
|
||||||
thundersvm = thundersvm;
|
|
||||||
};
|
|
||||||
inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication defaultPoetryOverrides;
|
|
||||||
in
|
|
||||||
{
|
|
||||||
packages = {
|
|
||||||
deepsad = mkPoetryApplication {
|
|
||||||
projectDir = self;
|
|
||||||
preferWheels = true;
|
|
||||||
python = pkgs.python311;
|
|
||||||
overrides = defaultPoetryOverrides.extend (
|
|
||||||
final: prev: {
|
|
||||||
torch-receptive-field = prev.torch-receptive-field.overridePythonAttrs (old: {
|
|
||||||
buildInputs = (old.buildInputs or [ ]) ++ [ prev.setuptools ];
|
|
||||||
});
|
|
||||||
}
|
|
||||||
);
|
|
||||||
};
|
|
||||||
default = self.packages.${system}.deepsad;
|
|
||||||
};
|
|
||||||
|
|
||||||
devShells.default = pkgs.mkShell {
|
|
||||||
inputsFrom = [ self.packages.${system}.deepsad ];
|
|
||||||
buildInputs = with pkgs.python311Packages; [
|
|
||||||
torch-bin
|
|
||||||
torchvision-bin
|
|
||||||
thundersvm-python
|
|
||||||
];
|
|
||||||
#LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
|
|
||||||
#pkgs.stdenv.cc.cc
|
|
||||||
#];
|
|
||||||
};
|
|
||||||
|
|
||||||
devShells.poetry = pkgs.mkShell {
|
|
||||||
packages = [
|
|
||||||
pkgsNew.poetry
|
|
||||||
pkgs.python311
|
|
||||||
];
|
|
||||||
};
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
1002
Deep-SAD-PyTorch/poetry.lock
generated
1002
Deep-SAD-PyTorch/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,32 +1,29 @@
|
|||||||
[tool.poetry]
|
[project]
|
||||||
name = "deep-sad-pytorch"
|
name = "deep-sad-pytorch"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = ""
|
description = "Add your description here"
|
||||||
authors = ["Your Name <you@example.com>"]
|
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"click>=8.2.1",
|
||||||
|
"cvxopt>=1.3.2",
|
||||||
|
"cycler>=0.12.1",
|
||||||
|
"joblib>=1.5.1",
|
||||||
|
"kiwisolver>=1.4.8",
|
||||||
|
"matplotlib>=3.10.3",
|
||||||
|
"numpy>=2.3.1",
|
||||||
|
"pandas>=2.3.0",
|
||||||
|
"pillow>=11.2.1",
|
||||||
|
"pyparsing>=3.2.3",
|
||||||
|
"python-dateutil>=2.9.0.post0",
|
||||||
|
"pytz>=2025.2",
|
||||||
|
"scikit-learn>=1.7.0",
|
||||||
|
"scipy>=1.16.0",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
|
"six>=1.17.0",
|
||||||
|
"torch-receptive-field",
|
||||||
|
"torchscan>=0.1.1",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.uv.sources]
|
||||||
python = ">=3.11,<3.12"
|
torch-receptive-field = { git = "https://github.com/Fangyh09/pytorch-receptive-field.git" }
|
||||||
click = "^8.1.7"
|
|
||||||
matplotlib = "^3.9.0"
|
|
||||||
numpy = "^2.0.0"
|
|
||||||
pandas = "^2.2.2"
|
|
||||||
cvxopt = "^1.3.2"
|
|
||||||
cycler = "^0.12.1"
|
|
||||||
joblib = "^1.4.2"
|
|
||||||
kiwisolver = "^1.4.5"
|
|
||||||
pillow = "^10.3.0"
|
|
||||||
pyparsing = "^3.1.2"
|
|
||||||
python-dateutil = "^2.9.0.post0"
|
|
||||||
pytz = "^2024.1"
|
|
||||||
scikit-learn = "^1.5.0"
|
|
||||||
scipy = "^1.14.0"
|
|
||||||
seaborn = "^0.13.2"
|
|
||||||
six = "^1.16.0"
|
|
||||||
torchscan = "^0.1.2"
|
|
||||||
torch-receptive-field = {git = "https://github.com/Fangyh09/pytorch-receptive-field.git"}
|
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["poetry-core"]
|
|
||||||
build-backend = "poetry.core.masonry.api"
|
|
||||||
|
|||||||
@@ -126,6 +126,8 @@ class DeepSAD(object):
|
|||||||
)
|
)
|
||||||
# Get the model
|
# Get the model
|
||||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||||
|
|
||||||
|
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||||
# Store training results including indices
|
# Store training results including indices
|
||||||
self.results["train"]["time"] = self.trainer.train_time
|
self.results["train"]["time"] = self.trainer.train_time
|
||||||
self.results["train"]["indices"] = self.trainer.train_indices
|
self.results["train"]["indices"] = self.trainer.train_indices
|
||||||
@@ -333,7 +335,7 @@ class DeepSAD(object):
|
|||||||
# load autoencoder parameters if specified
|
# load autoencoder parameters if specified
|
||||||
if load_ae:
|
if load_ae:
|
||||||
if self.ae_net is None:
|
if self.ae_net is None:
|
||||||
self.ae_net = build_autoencoder(self.net_name)
|
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
|
||||||
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
|
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
|
||||||
|
|
||||||
def save_results(self, export_pkl):
|
def save_results(self, export_pkl):
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
|||||||
[
|
[
|
||||||
"train",
|
"train",
|
||||||
"infer",
|
"infer",
|
||||||
"ae_elbow_test", # Add new action
|
"ae_elbow_test",
|
||||||
|
"retest",
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -773,6 +774,165 @@ def main(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown action: {action}")
|
logger.error(f"Unknown action: {action}")
|
||||||
|
elif action == "retest":
|
||||||
|
if (
|
||||||
|
not load_model
|
||||||
|
or not Path(load_model).exists()
|
||||||
|
or not Path(load_model).is_dir()
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
"For retest mode a model directory has to be loaded! Pass the --load_model option with the model directory path!"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
load_model = Path(load_model)
|
||||||
|
if not load_config:
|
||||||
|
logger.error(
|
||||||
|
"For retest mode a config has to be loaded! Pass the --load_config option with the config path!"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset = load_dataset(
|
||||||
|
cfg.settings["dataset_name"],
|
||||||
|
data_path,
|
||||||
|
cfg.settings["normal_class"],
|
||||||
|
cfg.settings["known_outlier_class"],
|
||||||
|
cfg.settings["n_known_outlier_classes"],
|
||||||
|
cfg.settings["ratio_known_normal"],
|
||||||
|
cfg.settings["ratio_known_outlier"],
|
||||||
|
cfg.settings["ratio_pollution"],
|
||||||
|
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||||
|
k_fold_num=cfg.settings["k_fold_num"],
|
||||||
|
num_known_normal=cfg.settings["num_known_normal"],
|
||||||
|
num_known_outlier=cfg.settings["num_known_outlier"],
|
||||||
|
)
|
||||||
|
|
||||||
|
train_passes = (
|
||||||
|
range(cfg.settings["k_fold_num"]) if cfg.settings["k_fold"] else [None]
|
||||||
|
)
|
||||||
|
|
||||||
|
retest_isoforest = True
|
||||||
|
retest_ocsvm = True
|
||||||
|
retest_deepsad = True
|
||||||
|
|
||||||
|
for fold_idx in train_passes:
|
||||||
|
if fold_idx is None:
|
||||||
|
logger.info("Single train re-testing without k-fold")
|
||||||
|
deepsad_model_path = load_model / "model_deepsad.tar"
|
||||||
|
isoforest_model_path = load_model / "model_ocsvm.pkl"
|
||||||
|
ocsvm_model_path = load_model / "model_isoforest.pkl"
|
||||||
|
ae_model_path = load_model / "model_ae.tar"
|
||||||
|
else:
|
||||||
|
logger.info(f"Fold {fold_idx + 1}/{cfg.settings['k_fold_num']}")
|
||||||
|
|
||||||
|
deepsad_model_path = load_model / f"model_deepsad_{fold_idx}.tar"
|
||||||
|
isoforest_model_path = load_model / f"model_isoforest_{fold_idx}.pkl"
|
||||||
|
ocsvm_model_path = load_model / f"model_ocsvm_{fold_idx}.pkl"
|
||||||
|
ae_model_path = load_model / f"model_ae_{fold_idx}.tar"
|
||||||
|
|
||||||
|
# Check which model files exist and which do not
|
||||||
|
model_paths = [
|
||||||
|
deepsad_model_path,
|
||||||
|
isoforest_model_path,
|
||||||
|
ocsvm_model_path,
|
||||||
|
ae_model_path,
|
||||||
|
]
|
||||||
|
missing_models = [
|
||||||
|
path.name
|
||||||
|
for path in model_paths
|
||||||
|
if not path.exists() or not path.is_file()
|
||||||
|
]
|
||||||
|
if missing_models:
|
||||||
|
logger.error(
|
||||||
|
f"The following model files do not exist: {', '.join(missing_models)}. Please check the load_model path."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize Isolation Forest model
|
||||||
|
if retest_isoforest:
|
||||||
|
Isoforest = IsoForest(
|
||||||
|
hybrid=False,
|
||||||
|
n_estimators=cfg.settings["isoforest_n_estimators"],
|
||||||
|
max_samples=cfg.settings["isoforest_max_samples"],
|
||||||
|
contamination=cfg.settings["isoforest_contamination"],
|
||||||
|
n_jobs=cfg.settings["isoforest_n_jobs_model"],
|
||||||
|
seed=cfg.settings["seed"],
|
||||||
|
)
|
||||||
|
Isoforest.load_model(import_path=isoforest_model_path, device=device)
|
||||||
|
Isoforest.test(
|
||||||
|
dataset,
|
||||||
|
device=device,
|
||||||
|
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||||
|
k_fold_idx=fold_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize DeepSAD model and set neural network phi
|
||||||
|
if retest_deepsad:
|
||||||
|
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
|
||||||
|
deepSAD.set_network(cfg.settings["net_name"])
|
||||||
|
deepSAD.load_model(
|
||||||
|
model_path=deepsad_model_path, load_ae=True, map_location=device
|
||||||
|
)
|
||||||
|
logger.info("Loading model from %s." % load_model)
|
||||||
|
deepSAD.test(
|
||||||
|
dataset,
|
||||||
|
device=device,
|
||||||
|
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||||
|
k_fold_idx=fold_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
if retest_ocsvm:
|
||||||
|
ocsvm = OCSVM(
|
||||||
|
kernel=cfg.settings["ocsvm_kernel"],
|
||||||
|
nu=cfg.settings["ocsvm_nu"],
|
||||||
|
hybrid=True,
|
||||||
|
latent_space_dim=cfg.settings["latent_space_dim"],
|
||||||
|
)
|
||||||
|
ocsvm.load_ae(
|
||||||
|
net_name=cfg.settings["net_name"],
|
||||||
|
model_path=ae_model_path,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
ocsvm.load_model(import_path=ocsvm_model_path)
|
||||||
|
ocsvm.test(
|
||||||
|
dataset,
|
||||||
|
device=device,
|
||||||
|
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||||
|
k_fold_idx=fold_idx,
|
||||||
|
batch_size=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
retest_output_path = load_model / "retest_output"
|
||||||
|
retest_output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save results, model, and configuration
|
||||||
|
if fold_idx is None:
|
||||||
|
if retest_deepsad:
|
||||||
|
deepSAD.save_results(
|
||||||
|
export_pkl=retest_output_path / "results_deepsad.pkl"
|
||||||
|
)
|
||||||
|
if retest_ocsvm:
|
||||||
|
ocsvm.save_results(
|
||||||
|
export_pkl=retest_output_path / "results_ocsvm.pkl"
|
||||||
|
)
|
||||||
|
if retest_isoforest:
|
||||||
|
Isoforest.save_results(
|
||||||
|
export_pkl=retest_output_path / "results_isoforest.pkl"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if retest_deepsad:
|
||||||
|
deepSAD.save_results(
|
||||||
|
export_pkl=retest_output_path
|
||||||
|
/ f"results_deepsad_{fold_idx}.pkl"
|
||||||
|
)
|
||||||
|
if retest_ocsvm:
|
||||||
|
ocsvm.save_results(
|
||||||
|
export_pkl=retest_output_path / f"/results_ocsvm_{fold_idx}.pkl"
|
||||||
|
)
|
||||||
|
if retest_isoforest:
|
||||||
|
Isoforest.save_results(
|
||||||
|
export_pkl=retest_output_path
|
||||||
|
/ f"/results_isoforest_{fold_idx}.pkl"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user