This commit is contained in:
Jan Kowalczyk
2025-06-10 09:31:28 +02:00
parent 3538b15073
commit 156b6d2ac1
8 changed files with 794 additions and 580 deletions

View File

@@ -1,6 +1,7 @@
import json
import pickle
import numpy as np
import torch
from base.base_dataset import BaseADDataset
@@ -43,10 +44,47 @@ class DeepSAD(object):
self.ae_optimizer_name = None
self.results = {
"train_time": None,
"test_auc": None,
"test_time": None,
"test_scores": None,
"train": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
},
"test": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
"exp_based": {
"auc": None,
"roc": None,
"prc": None,
"ap": None,
"scores": None,
"indices": None,
"labels": None,
"valid_mask": None,
},
"manual_based": {
"auc": None,
"roc": None,
"prc": None,
"ap": None,
"scores": None,
"indices": None,
"labels": None,
"valid_mask": None,
},
},
"inference": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
},
}
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
@@ -86,8 +124,17 @@ class DeepSAD(object):
)
# Get the model
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
self.results["train_time"] = self.trainer.train_time
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
# Store training results including indices
self.results["train"]["time"] = self.trainer.train_time
self.results["train"]["indices"] = self.trainer.train_indices
self.results["train"]["file_ids"] = self.trainer.train_file_ids
self.results["train"]["frame_ids"] = self.trainer.train_frame_ids
# Get file names mapping for training data
self.results["train"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.train_file_ids)
}
def inference(
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
@@ -99,7 +146,21 @@ class DeepSAD(object):
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
)
return self.trainer.infer(dataset, self.net)
scores, outputs = self.trainer.infer(dataset, self.net)
# Store inference indices and mappings
self.results["inference"]["time"] = self.trainer.inference_time
self.results["inference"]["indices"] = self.trainer.inference_indices
self.results["inference"]["file_ids"] = self.trainer.inference_file_ids
self.results["inference"]["frame_ids"] = self.trainer.inference_frame_ids
# Get file names mapping for inference data
self.results["inference"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.inference_file_ids)
}
return scores, outputs
def test(
self,
@@ -117,13 +178,51 @@ class DeepSAD(object):
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
# Get results
self.results["test_auc"] = self.trainer.test_auc
self.results["test_roc"] = self.trainer.test_roc
self.results["test_prc"] = self.trainer.test_prc
self.results["test_ap"] = self.trainer.test_ap
self.results["test_time"] = self.trainer.test_time
self.results["test_scores"] = self.trainer.test_scores
# Store all test indices and mappings
self.results["test"]["time"] = self.trainer.test_time
self.results["test"]["indices"] = self.trainer.test_indices
self.results["test"]["file_ids"] = self.trainer.test_file_ids
self.results["test"]["frame_ids"] = self.trainer.test_frame_ids
# Get file names mapping for test data
self.results["test"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.test_file_ids)
}
# Store experiment-based results
self.results["test"]["exp_based"]["auc"] = self.trainer.test_auc_exp_based
self.results["test"]["exp_based"]["roc"] = self.trainer.test_roc_exp_based
self.results["test"]["exp_based"]["prc"] = self.trainer.test_prc_exp_based
self.results["test"]["exp_based"]["ap"] = self.trainer.test_ap_exp_based
self.results["test"]["exp_based"]["scores"] = self.trainer.test_scores_exp_based
self.results["test"]["exp_based"]["indices"] = self.trainer.test_index_mapping[
"exp_based"
]["indices"]
self.results["test"]["exp_based"]["labels"] = self.trainer.test_index_mapping[
"exp_based"
]["labels"]
self.results["test"]["exp_based"]["valid_mask"] = (
self.trainer.test_index_mapping["exp_based"]["valid_mask"]
)
# Store manual-based results
self.results["test"]["manual_based"]["auc"] = self.trainer.test_auc_manual_based
self.results["test"]["manual_based"]["roc"] = self.trainer.test_roc_manual_based
self.results["test"]["manual_based"]["prc"] = self.trainer.test_prc_manual_based
self.results["test"]["manual_based"]["ap"] = self.trainer.test_ap_manual_based
self.results["test"]["manual_based"]["scores"] = (
self.trainer.test_scores_manual_based
)
self.results["test"]["manual_based"]["indices"] = (
self.trainer.test_index_mapping["manual_based"]["indices"]
)
self.results["test"]["manual_based"]["labels"] = (
self.trainer.test_index_mapping["manual_based"]["labels"]
)
self.results["test"]["manual_based"]["valid_mask"] = (
self.trainer.test_index_mapping["manual_based"]["valid_mask"]
)
def pretrain(
self,