wip
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from base.base_dataset import BaseADDataset
|
||||
@@ -43,10 +44,47 @@ class DeepSAD(object):
|
||||
self.ae_optimizer_name = None
|
||||
|
||||
self.results = {
|
||||
"train_time": None,
|
||||
"test_auc": None,
|
||||
"test_time": None,
|
||||
"test_scores": None,
|
||||
"train": {
|
||||
"time": None,
|
||||
"indices": None,
|
||||
"file_ids": None,
|
||||
"frame_ids": None,
|
||||
"file_names": None, # mapping of file_ids to file names
|
||||
},
|
||||
"test": {
|
||||
"time": None,
|
||||
"indices": None,
|
||||
"file_ids": None,
|
||||
"frame_ids": None,
|
||||
"file_names": None, # mapping of file_ids to file names
|
||||
"exp_based": {
|
||||
"auc": None,
|
||||
"roc": None,
|
||||
"prc": None,
|
||||
"ap": None,
|
||||
"scores": None,
|
||||
"indices": None,
|
||||
"labels": None,
|
||||
"valid_mask": None,
|
||||
},
|
||||
"manual_based": {
|
||||
"auc": None,
|
||||
"roc": None,
|
||||
"prc": None,
|
||||
"ap": None,
|
||||
"scores": None,
|
||||
"indices": None,
|
||||
"labels": None,
|
||||
"valid_mask": None,
|
||||
},
|
||||
},
|
||||
"inference": {
|
||||
"time": None,
|
||||
"indices": None,
|
||||
"file_ids": None,
|
||||
"frame_ids": None,
|
||||
"file_names": None, # mapping of file_ids to file names
|
||||
},
|
||||
}
|
||||
|
||||
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
|
||||
@@ -86,8 +124,17 @@ class DeepSAD(object):
|
||||
)
|
||||
# Get the model
|
||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
self.results["train_time"] = self.trainer.train_time
|
||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||
# Store training results including indices
|
||||
self.results["train"]["time"] = self.trainer.train_time
|
||||
self.results["train"]["indices"] = self.trainer.train_indices
|
||||
self.results["train"]["file_ids"] = self.trainer.train_file_ids
|
||||
self.results["train"]["frame_ids"] = self.trainer.train_frame_ids
|
||||
|
||||
# Get file names mapping for training data
|
||||
self.results["train"]["file_names"] = {
|
||||
file_id: dataset.get_file_name_from_idx(file_id)
|
||||
for file_id in np.unique(self.trainer.train_file_ids)
|
||||
}
|
||||
|
||||
def inference(
|
||||
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
||||
@@ -99,7 +146,21 @@ class DeepSAD(object):
|
||||
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||
)
|
||||
|
||||
return self.trainer.infer(dataset, self.net)
|
||||
scores, outputs = self.trainer.infer(dataset, self.net)
|
||||
|
||||
# Store inference indices and mappings
|
||||
self.results["inference"]["time"] = self.trainer.inference_time
|
||||
self.results["inference"]["indices"] = self.trainer.inference_indices
|
||||
self.results["inference"]["file_ids"] = self.trainer.inference_file_ids
|
||||
self.results["inference"]["frame_ids"] = self.trainer.inference_frame_ids
|
||||
|
||||
# Get file names mapping for inference data
|
||||
self.results["inference"]["file_names"] = {
|
||||
file_id: dataset.get_file_name_from_idx(file_id)
|
||||
for file_id in np.unique(self.trainer.inference_file_ids)
|
||||
}
|
||||
|
||||
return scores, outputs
|
||||
|
||||
def test(
|
||||
self,
|
||||
@@ -117,13 +178,51 @@ class DeepSAD(object):
|
||||
|
||||
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get results
|
||||
self.results["test_auc"] = self.trainer.test_auc
|
||||
self.results["test_roc"] = self.trainer.test_roc
|
||||
self.results["test_prc"] = self.trainer.test_prc
|
||||
self.results["test_ap"] = self.trainer.test_ap
|
||||
self.results["test_time"] = self.trainer.test_time
|
||||
self.results["test_scores"] = self.trainer.test_scores
|
||||
# Store all test indices and mappings
|
||||
self.results["test"]["time"] = self.trainer.test_time
|
||||
self.results["test"]["indices"] = self.trainer.test_indices
|
||||
self.results["test"]["file_ids"] = self.trainer.test_file_ids
|
||||
self.results["test"]["frame_ids"] = self.trainer.test_frame_ids
|
||||
|
||||
# Get file names mapping for test data
|
||||
self.results["test"]["file_names"] = {
|
||||
file_id: dataset.get_file_name_from_idx(file_id)
|
||||
for file_id in np.unique(self.trainer.test_file_ids)
|
||||
}
|
||||
|
||||
# Store experiment-based results
|
||||
self.results["test"]["exp_based"]["auc"] = self.trainer.test_auc_exp_based
|
||||
self.results["test"]["exp_based"]["roc"] = self.trainer.test_roc_exp_based
|
||||
self.results["test"]["exp_based"]["prc"] = self.trainer.test_prc_exp_based
|
||||
self.results["test"]["exp_based"]["ap"] = self.trainer.test_ap_exp_based
|
||||
self.results["test"]["exp_based"]["scores"] = self.trainer.test_scores_exp_based
|
||||
self.results["test"]["exp_based"]["indices"] = self.trainer.test_index_mapping[
|
||||
"exp_based"
|
||||
]["indices"]
|
||||
self.results["test"]["exp_based"]["labels"] = self.trainer.test_index_mapping[
|
||||
"exp_based"
|
||||
]["labels"]
|
||||
self.results["test"]["exp_based"]["valid_mask"] = (
|
||||
self.trainer.test_index_mapping["exp_based"]["valid_mask"]
|
||||
)
|
||||
|
||||
# Store manual-based results
|
||||
self.results["test"]["manual_based"]["auc"] = self.trainer.test_auc_manual_based
|
||||
self.results["test"]["manual_based"]["roc"] = self.trainer.test_roc_manual_based
|
||||
self.results["test"]["manual_based"]["prc"] = self.trainer.test_prc_manual_based
|
||||
self.results["test"]["manual_based"]["ap"] = self.trainer.test_ap_manual_based
|
||||
self.results["test"]["manual_based"]["scores"] = (
|
||||
self.trainer.test_scores_manual_based
|
||||
)
|
||||
self.results["test"]["manual_based"]["indices"] = (
|
||||
self.trainer.test_index_mapping["manual_based"]["indices"]
|
||||
)
|
||||
self.results["test"]["manual_based"]["labels"] = (
|
||||
self.trainer.test_index_mapping["manual_based"]["labels"]
|
||||
)
|
||||
self.results["test"]["manual_based"]["valid_mask"] = (
|
||||
self.trainer.test_index_mapping["manual_based"]["valid_mask"]
|
||||
)
|
||||
|
||||
def pretrain(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user