retest implemented and fixed missing center in save data
This commit is contained in:
@@ -126,6 +126,8 @@ class DeepSAD(object):
|
||||
)
|
||||
# Get the model
|
||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
|
||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||
# Store training results including indices
|
||||
self.results["train"]["time"] = self.trainer.train_time
|
||||
self.results["train"]["indices"] = self.trainer.train_indices
|
||||
@@ -333,7 +335,7 @@ class DeepSAD(object):
|
||||
# load autoencoder parameters if specified
|
||||
if load_ae:
|
||||
if self.ae_net is None:
|
||||
self.ae_net = build_autoencoder(self.net_name)
|
||||
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
|
||||
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
|
||||
|
||||
def save_results(self, export_pkl):
|
||||
|
||||
@@ -25,7 +25,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
||||
[
|
||||
"train",
|
||||
"infer",
|
||||
"ae_elbow_test", # Add new action
|
||||
"ae_elbow_test",
|
||||
"retest",
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -773,6 +774,165 @@ def main(
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown action: {action}")
|
||||
elif action == "retest":
|
||||
if (
|
||||
not load_model
|
||||
or not Path(load_model).exists()
|
||||
or not Path(load_model).is_dir()
|
||||
):
|
||||
logger.error(
|
||||
"For retest mode a model directory has to be loaded! Pass the --load_model option with the model directory path!"
|
||||
)
|
||||
return
|
||||
load_model = Path(load_model)
|
||||
if not load_config:
|
||||
logger.error(
|
||||
"For retest mode a config has to be loaded! Pass the --load_config option with the config path!"
|
||||
)
|
||||
return
|
||||
|
||||
dataset = load_dataset(
|
||||
cfg.settings["dataset_name"],
|
||||
data_path,
|
||||
cfg.settings["normal_class"],
|
||||
cfg.settings["known_outlier_class"],
|
||||
cfg.settings["n_known_outlier_classes"],
|
||||
cfg.settings["ratio_known_normal"],
|
||||
cfg.settings["ratio_known_outlier"],
|
||||
cfg.settings["ratio_pollution"],
|
||||
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||
k_fold_num=cfg.settings["k_fold_num"],
|
||||
num_known_normal=cfg.settings["num_known_normal"],
|
||||
num_known_outlier=cfg.settings["num_known_outlier"],
|
||||
)
|
||||
|
||||
train_passes = (
|
||||
range(cfg.settings["k_fold_num"]) if cfg.settings["k_fold"] else [None]
|
||||
)
|
||||
|
||||
retest_isoforest = True
|
||||
retest_ocsvm = True
|
||||
retest_deepsad = True
|
||||
|
||||
for fold_idx in train_passes:
|
||||
if fold_idx is None:
|
||||
logger.info("Single train re-testing without k-fold")
|
||||
deepsad_model_path = load_model / "model_deepsad.tar"
|
||||
isoforest_model_path = load_model / "model_ocsvm.pkl"
|
||||
ocsvm_model_path = load_model / "model_isoforest.pkl"
|
||||
ae_model_path = load_model / "model_ae.tar"
|
||||
else:
|
||||
logger.info(f"Fold {fold_idx + 1}/{cfg.settings['k_fold_num']}")
|
||||
|
||||
deepsad_model_path = load_model / f"model_deepsad_{fold_idx}.tar"
|
||||
isoforest_model_path = load_model / f"model_isoforest_{fold_idx}.pkl"
|
||||
ocsvm_model_path = load_model / f"model_ocsvm_{fold_idx}.pkl"
|
||||
ae_model_path = load_model / f"model_ae_{fold_idx}.tar"
|
||||
|
||||
# Check which model files exist and which do not
|
||||
model_paths = [
|
||||
deepsad_model_path,
|
||||
isoforest_model_path,
|
||||
ocsvm_model_path,
|
||||
ae_model_path,
|
||||
]
|
||||
missing_models = [
|
||||
path.name
|
||||
for path in model_paths
|
||||
if not path.exists() or not path.is_file()
|
||||
]
|
||||
if missing_models:
|
||||
logger.error(
|
||||
f"The following model files do not exist: {', '.join(missing_models)}. Please check the load_model path."
|
||||
)
|
||||
return
|
||||
|
||||
# Initialize Isolation Forest model
|
||||
if retest_isoforest:
|
||||
Isoforest = IsoForest(
|
||||
hybrid=False,
|
||||
n_estimators=cfg.settings["isoforest_n_estimators"],
|
||||
max_samples=cfg.settings["isoforest_max_samples"],
|
||||
contamination=cfg.settings["isoforest_contamination"],
|
||||
n_jobs=cfg.settings["isoforest_n_jobs_model"],
|
||||
seed=cfg.settings["seed"],
|
||||
)
|
||||
Isoforest.load_model(import_path=isoforest_model_path, device=device)
|
||||
Isoforest.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Initialize DeepSAD model and set neural network phi
|
||||
if retest_deepsad:
|
||||
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
|
||||
deepSAD.set_network(cfg.settings["net_name"])
|
||||
deepSAD.load_model(
|
||||
model_path=deepsad_model_path, load_ae=True, map_location=device
|
||||
)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
deepSAD.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
if retest_ocsvm:
|
||||
ocsvm = OCSVM(
|
||||
kernel=cfg.settings["ocsvm_kernel"],
|
||||
nu=cfg.settings["ocsvm_nu"],
|
||||
hybrid=True,
|
||||
latent_space_dim=cfg.settings["latent_space_dim"],
|
||||
)
|
||||
ocsvm.load_ae(
|
||||
net_name=cfg.settings["net_name"],
|
||||
model_path=ae_model_path,
|
||||
device=device,
|
||||
)
|
||||
ocsvm.load_model(import_path=ocsvm_model_path)
|
||||
ocsvm.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
batch_size=256,
|
||||
)
|
||||
|
||||
retest_output_path = load_model / "retest_output"
|
||||
retest_output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save results, model, and configuration
|
||||
if fold_idx is None:
|
||||
if retest_deepsad:
|
||||
deepSAD.save_results(
|
||||
export_pkl=retest_output_path / "results_deepsad.pkl"
|
||||
)
|
||||
if retest_ocsvm:
|
||||
ocsvm.save_results(
|
||||
export_pkl=retest_output_path / "results_ocsvm.pkl"
|
||||
)
|
||||
if retest_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=retest_output_path / "results_isoforest.pkl"
|
||||
)
|
||||
else:
|
||||
if retest_deepsad:
|
||||
deepSAD.save_results(
|
||||
export_pkl=retest_output_path
|
||||
/ f"results_deepsad_{fold_idx}.pkl"
|
||||
)
|
||||
if retest_ocsvm:
|
||||
ocsvm.save_results(
|
||||
export_pkl=retest_output_path / f"/results_ocsvm_{fold_idx}.pkl"
|
||||
)
|
||||
if retest_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=retest_output_path
|
||||
/ f"/results_isoforest_{fold_idx}.pkl"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user