wip inference
This commit is contained in:
@@ -638,57 +638,185 @@ def main(
|
||||
cfg.save_config(export_json=xp_path + "/config.json")
|
||||
|
||||
elif action == "infer":
|
||||
# Inference uses a deterministic, non-shuffled loader to preserve temporal order
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
cfg.settings["dataset_name"],
|
||||
data_path,
|
||||
normal_class,
|
||||
known_outlier_class,
|
||||
n_known_outlier_classes,
|
||||
ratio_known_normal,
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
cfg.settings["normal_class"],
|
||||
cfg.settings["known_outlier_class"],
|
||||
cfg.settings["n_known_outlier_classes"],
|
||||
cfg.settings["ratio_known_normal"],
|
||||
cfg.settings["ratio_known_outlier"],
|
||||
cfg.settings["ratio_pollution"],
|
||||
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||
k_fold_num=False,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
# Log random sample of known anomaly classes if more than 1 class
|
||||
if n_known_outlier_classes > 1:
|
||||
logger.info("Known anomaly classes: %s" % (dataset.known_outlier_classes,))
|
||||
|
||||
# Initialize DeepSAD model and set neural network phi
|
||||
deepSAD = DeepSAD(latent_space_dim, cfg.settings["eta"])
|
||||
deepSAD.set_network(net_name)
|
||||
|
||||
# If specified, load Deep SAD model (center c, network weights, and possibly autoencoder weights)
|
||||
if not load_model:
|
||||
# --- Expect a model DIRECTORY (aligned with 'retest') ---
|
||||
if (
|
||||
(not load_model)
|
||||
or (not Path(load_model).exists())
|
||||
or (not Path(load_model).is_dir())
|
||||
):
|
||||
logger.error(
|
||||
"For inference mode a model has to be loaded! Pass the --load_model option with the model path!"
|
||||
"For inference mode a model directory has to be loaded! "
|
||||
"Pass the --load_model option with the model directory path!"
|
||||
)
|
||||
return
|
||||
load_model = Path(load_model)
|
||||
|
||||
# Resolve expected model artifacts (single-model / no k-fold suffixes)
|
||||
deepsad_model_path = load_model / "model_deepsad.tar"
|
||||
ae_model_path = load_model / "model_ae.tar"
|
||||
ocsvm_model_path = load_model / "model_ocsvm.pkl"
|
||||
isoforest_model_path = load_model / "model_isoforest.pkl"
|
||||
|
||||
# Sanity check model files exist
|
||||
model_paths = [
|
||||
deepsad_model_path,
|
||||
ae_model_path,
|
||||
ocsvm_model_path,
|
||||
isoforest_model_path,
|
||||
]
|
||||
missing = [p.name for p in model_paths if not p.exists() or not p.is_file()]
|
||||
if missing:
|
||||
logger.error(
|
||||
"The following model files do not exist in the provided model directory: "
|
||||
+ ", ".join(missing)
|
||||
)
|
||||
return
|
||||
|
||||
deepSAD.load_model(model_path=load_model, load_ae=True, map_location=device)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
# Prepare output paths
|
||||
inf_dir = Path(xp_path) / "inference"
|
||||
inf_dir.mkdir(parents=True, exist_ok=True)
|
||||
base_stem = Path(Path(dataset.root).stem) # keep your previous naming
|
||||
# DeepSAD outputs (keep legacy filenames for backward compatibility)
|
||||
deepsad_scores_path = inf_dir / Path(
|
||||
base_stem.stem + "_deepsad_scores"
|
||||
).with_suffix(".npy")
|
||||
deepsad_outputs_path = inf_dir / Path(base_stem.stem + "_outputs").with_suffix(
|
||||
".npy"
|
||||
)
|
||||
# Baselines
|
||||
ocsvm_scores_path = inf_dir / Path(
|
||||
base_stem.stem + "_ocsvm_scores"
|
||||
).with_suffix(".npy")
|
||||
isoforest_scores_path = inf_dir / Path(
|
||||
base_stem.stem + "_isoforest_scores"
|
||||
).with_suffix(".npy")
|
||||
|
||||
inference_results, all_outputs = deepSAD.inference(
|
||||
dataset, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||
)
|
||||
inference_results_path = (
|
||||
Path(xp_path)
|
||||
/ "inference"
|
||||
/ Path(Path(dataset.root).stem).with_suffix(".npy")
|
||||
)
|
||||
inference_outputs_path = (
|
||||
Path(xp_path)
|
||||
/ "inference"
|
||||
/ Path(Path(dataset.root).stem + "_outputs").with_suffix(".npy")
|
||||
# Common loader settings
|
||||
_n_jobs = (
|
||||
n_jobs_dataloader
|
||||
if "n_jobs_dataloader" in locals()
|
||||
else cfg.settings.get("n_jobs_dataloader", 0)
|
||||
)
|
||||
|
||||
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.save(inference_results_path, inference_results, fix_imports=False)
|
||||
np.save(inference_outputs_path, all_outputs, fix_imports=False)
|
||||
# ----------------- DeepSAD -----------------
|
||||
|
||||
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
|
||||
deepSAD.set_network(cfg.settings["net_name"])
|
||||
deepSAD.load_model(
|
||||
model_path=deepsad_model_path, load_ae=True, map_location=device
|
||||
)
|
||||
logger.info("Loaded DeepSAD model from %s.", deepsad_model_path)
|
||||
|
||||
deepsad_scores, deepsad_all_outputs = deepSAD.inference(
|
||||
dataset, device=device, n_jobs_dataloader=_n_jobs
|
||||
)
|
||||
|
||||
np.save(deepsad_scores_path, deepsad_scores)
|
||||
# np.save(deepsad_outputs_path, deepsad_all_outputs)
|
||||
|
||||
logger.info(
|
||||
f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
|
||||
"DeepSAD inference: median=%.6f mean=%.6f min=%.6f max=%.6f",
|
||||
float(np.median(deepsad_scores)),
|
||||
float(np.mean(deepsad_scores)),
|
||||
float(np.min(deepsad_scores)),
|
||||
float(np.max(deepsad_scores)),
|
||||
)
|
||||
|
||||
# ----------------- OCSVM (hybrid) -----------------
|
||||
ocsvm_scores = None
|
||||
ocsvm = OCSVM(
|
||||
kernel=cfg.settings["ocsvm_kernel"],
|
||||
nu=cfg.settings["ocsvm_nu"],
|
||||
hybrid=True,
|
||||
latent_space_dim=cfg.settings["latent_space_dim"],
|
||||
)
|
||||
# load AE to build the feature extractor for hybrid OCSVM
|
||||
ocsvm.load_ae(
|
||||
net_name=cfg.settings["net_name"],
|
||||
model_path=ae_model_path,
|
||||
device=device,
|
||||
)
|
||||
ocsvm.load_model(import_path=ocsvm_model_path)
|
||||
|
||||
ocsvm_scores = ocsvm.inference(
|
||||
dataset, device=device, n_jobs_dataloader=_n_jobs, batch_size=32
|
||||
)
|
||||
|
||||
if ocsvm_scores is not None:
|
||||
np.save(ocsvm_scores_path, ocsvm_scores)
|
||||
logger.info(
|
||||
"OCSVM inference: median=%.6f mean=%.6f min=%.6f max=%.6f",
|
||||
float(np.median(ocsvm_scores)),
|
||||
float(np.mean(ocsvm_scores)),
|
||||
float(np.min(ocsvm_scores)),
|
||||
float(np.max(ocsvm_scores)),
|
||||
)
|
||||
else:
|
||||
logger.warning("OCSVM scores could not be determined; no array saved.")
|
||||
|
||||
# ----------------- Isolation Forest -----------------
|
||||
isoforest_scores = None
|
||||
Isoforest = IsoForest(
|
||||
hybrid=False,
|
||||
n_estimators=cfg.settings["isoforest_n_estimators"],
|
||||
max_samples=cfg.settings["isoforest_max_samples"],
|
||||
contamination=cfg.settings["isoforest_contamination"],
|
||||
n_jobs=cfg.settings["isoforest_n_jobs_model"],
|
||||
seed=cfg.settings["seed"],
|
||||
)
|
||||
Isoforest.load_model(import_path=isoforest_model_path, device=device)
|
||||
isoforest_scores = Isoforest.inference(
|
||||
dataset, device=device, n_jobs_dataloader=_n_jobs
|
||||
)
|
||||
if isoforest_scores is not None:
|
||||
np.save(isoforest_scores_path, isoforest_scores)
|
||||
logger.info(
|
||||
"IsolationForest inference: median=%.6f mean=%.6f min=%.6f max=%.6f",
|
||||
float(np.median(isoforest_scores)),
|
||||
float(np.mean(isoforest_scores)),
|
||||
float(np.min(isoforest_scores)),
|
||||
float(np.max(isoforest_scores)),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Isolation Forest scores could not be determined; no array saved."
|
||||
)
|
||||
|
||||
# Final summary (DeepSAD always runs; baselines are best-effort)
|
||||
logger.info(
|
||||
"Inference complete. Saved arrays to %s:\n"
|
||||
" DeepSAD scores: %s\n"
|
||||
" DeepSAD outputs: %s\n"
|
||||
" OCSVM scores: %s\n"
|
||||
" IsoForest scores: %s",
|
||||
inf_dir,
|
||||
deepsad_scores_path.name,
|
||||
deepsad_outputs_path.name,
|
||||
ocsvm_scores_path.name if ocsvm_scores is not None else "(not saved)",
|
||||
isoforest_scores_path.name
|
||||
if isoforest_scores is not None
|
||||
else "(not saved)",
|
||||
)
|
||||
|
||||
elif action == "ae_elbow_test":
|
||||
# Load data once
|
||||
dataset = load_dataset(
|
||||
|
||||
Reference in New Issue
Block a user