retest implemented and fixed missing center in save data

This commit is contained in:
Jan Kowalczyk
2025-07-01 17:22:29 +02:00
parent 24c6771576
commit 4863b91127
6 changed files with 189 additions and 1307 deletions

View File

@@ -126,6 +126,8 @@ class DeepSAD(object):
)
# Get the model
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
# Store training results including indices
self.results["train"]["time"] = self.trainer.train_time
self.results["train"]["indices"] = self.trainer.train_indices
@@ -333,7 +335,7 @@ class DeepSAD(object):
# load autoencoder parameters if specified
if load_ae:
if self.ae_net is None:
self.ae_net = build_autoencoder(self.net_name)
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
def save_results(self, export_pkl):

View File

@@ -25,7 +25,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
[
"train",
"infer",
"ae_elbow_test", # Add new action
"ae_elbow_test",
"retest",
]
),
)
@@ -773,6 +774,165 @@ def main(
)
else:
logger.error(f"Unknown action: {action}")
elif action == "retest":
if (
not load_model
or not Path(load_model).exists()
or not Path(load_model).is_dir()
):
logger.error(
"For retest mode a model directory has to be loaded! Pass the --load_model option with the model directory path!"
)
return
load_model = Path(load_model)
if not load_config:
logger.error(
"For retest mode a config has to be loaded! Pass the --load_config option with the config path!"
)
return
dataset = load_dataset(
cfg.settings["dataset_name"],
data_path,
cfg.settings["normal_class"],
cfg.settings["known_outlier_class"],
cfg.settings["n_known_outlier_classes"],
cfg.settings["ratio_known_normal"],
cfg.settings["ratio_known_outlier"],
cfg.settings["ratio_pollution"],
random_state=np.random.RandomState(cfg.settings["seed"]),
k_fold_num=cfg.settings["k_fold_num"],
num_known_normal=cfg.settings["num_known_normal"],
num_known_outlier=cfg.settings["num_known_outlier"],
)
train_passes = (
range(cfg.settings["k_fold_num"]) if cfg.settings["k_fold"] else [None]
)
retest_isoforest = True
retest_ocsvm = True
retest_deepsad = True
for fold_idx in train_passes:
if fold_idx is None:
logger.info("Single train re-testing without k-fold")
deepsad_model_path = load_model / "model_deepsad.tar"
isoforest_model_path = load_model / "model_ocsvm.pkl"
ocsvm_model_path = load_model / "model_isoforest.pkl"
ae_model_path = load_model / "model_ae.tar"
else:
logger.info(f"Fold {fold_idx + 1}/{cfg.settings['k_fold_num']}")
deepsad_model_path = load_model / f"model_deepsad_{fold_idx}.tar"
isoforest_model_path = load_model / f"model_isoforest_{fold_idx}.pkl"
ocsvm_model_path = load_model / f"model_ocsvm_{fold_idx}.pkl"
ae_model_path = load_model / f"model_ae_{fold_idx}.tar"
# Check which model files exist and which do not
model_paths = [
deepsad_model_path,
isoforest_model_path,
ocsvm_model_path,
ae_model_path,
]
missing_models = [
path.name
for path in model_paths
if not path.exists() or not path.is_file()
]
if missing_models:
logger.error(
f"The following model files do not exist: {', '.join(missing_models)}. Please check the load_model path."
)
return
# Initialize Isolation Forest model
if retest_isoforest:
Isoforest = IsoForest(
hybrid=False,
n_estimators=cfg.settings["isoforest_n_estimators"],
max_samples=cfg.settings["isoforest_max_samples"],
contamination=cfg.settings["isoforest_contamination"],
n_jobs=cfg.settings["isoforest_n_jobs_model"],
seed=cfg.settings["seed"],
)
Isoforest.load_model(import_path=isoforest_model_path, device=device)
Isoforest.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
)
# Initialize DeepSAD model and set neural network phi
if retest_deepsad:
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
deepSAD.set_network(cfg.settings["net_name"])
deepSAD.load_model(
model_path=deepsad_model_path, load_ae=True, map_location=device
)
logger.info("Loading model from %s." % load_model)
deepSAD.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
)
if retest_ocsvm:
ocsvm = OCSVM(
kernel=cfg.settings["ocsvm_kernel"],
nu=cfg.settings["ocsvm_nu"],
hybrid=True,
latent_space_dim=cfg.settings["latent_space_dim"],
)
ocsvm.load_ae(
net_name=cfg.settings["net_name"],
model_path=ae_model_path,
device=device,
)
ocsvm.load_model(import_path=ocsvm_model_path)
ocsvm.test(
dataset,
device=device,
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
k_fold_idx=fold_idx,
batch_size=256,
)
retest_output_path = load_model / "retest_output"
retest_output_path.mkdir(parents=True, exist_ok=True)
# Save results, model, and configuration
if fold_idx is None:
if retest_deepsad:
deepSAD.save_results(
export_pkl=retest_output_path / "results_deepsad.pkl"
)
if retest_ocsvm:
ocsvm.save_results(
export_pkl=retest_output_path / "results_ocsvm.pkl"
)
if retest_isoforest:
Isoforest.save_results(
export_pkl=retest_output_path / "results_isoforest.pkl"
)
else:
if retest_deepsad:
deepSAD.save_results(
export_pkl=retest_output_path
/ f"results_deepsad_{fold_idx}.pkl"
)
if retest_ocsvm:
ocsvm.save_results(
export_pkl=retest_output_path / f"/results_ocsvm_{fold_idx}.pkl"
)
if retest_isoforest:
Isoforest.save_results(
export_pkl=retest_output_path
/ f"/results_isoforest_{fold_idx}.pkl"
)
if __name__ == "__main__":