full upload so not to lose anything important
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
|
||||
from base.base_dataset import BaseADDataset
|
||||
from networks.main import build_network, build_autoencoder
|
||||
from optim.DeepSAD_trainer import DeepSADTrainer
|
||||
from networks.main import build_autoencoder, build_network
|
||||
from optim.ae_trainer import AETrainer
|
||||
from optim.DeepSAD_trainer import DeepSADTrainer
|
||||
|
||||
|
||||
class DeepSAD(object):
|
||||
@@ -65,6 +67,7 @@ class DeepSAD(object):
|
||||
weight_decay: float = 1e-6,
|
||||
device: str = "cuda",
|
||||
n_jobs_dataloader: int = 0,
|
||||
k_fold_idx: int = None,
|
||||
):
|
||||
"""Trains the Deep SAD model on the training data."""
|
||||
|
||||
@@ -82,7 +85,7 @@ class DeepSAD(object):
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
)
|
||||
# Get the model
|
||||
self.net = self.trainer.train(dataset, self.net)
|
||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
self.results["train_time"] = self.trainer.train_time
|
||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||
|
||||
@@ -99,7 +102,11 @@ class DeepSAD(object):
|
||||
return self.trainer.infer(dataset, self.net)
|
||||
|
||||
def test(
|
||||
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
||||
self,
|
||||
dataset: BaseADDataset,
|
||||
device: str = "cuda",
|
||||
n_jobs_dataloader: int = 0,
|
||||
k_fold_idx: int = None,
|
||||
):
|
||||
"""Tests the Deep SAD model on the test data."""
|
||||
|
||||
@@ -108,10 +115,13 @@ class DeepSAD(object):
|
||||
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||
)
|
||||
|
||||
self.trainer.test(dataset, self.net)
|
||||
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get results
|
||||
self.results["test_auc"] = self.trainer.test_auc
|
||||
self.results["test_roc"] = self.trainer.test_roc
|
||||
self.results["test_prc"] = self.trainer.test_prc
|
||||
self.results["test_ap"] = self.trainer.test_ap
|
||||
self.results["test_time"] = self.trainer.test_time
|
||||
self.results["test_scores"] = self.trainer.test_scores
|
||||
|
||||
@@ -126,6 +136,7 @@ class DeepSAD(object):
|
||||
weight_decay: float = 1e-6,
|
||||
device: str = "cuda",
|
||||
n_jobs_dataloader: int = 0,
|
||||
k_fold_idx: int = None,
|
||||
):
|
||||
"""Pretrains the weights for the Deep SAD network phi via autoencoder."""
|
||||
|
||||
@@ -144,13 +155,13 @@ class DeepSAD(object):
|
||||
device=device,
|
||||
n_jobs_dataloader=n_jobs_dataloader,
|
||||
)
|
||||
self.ae_net = self.ae_trainer.train(dataset, self.ae_net)
|
||||
self.ae_net = self.ae_trainer.train(dataset, self.ae_net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get train results
|
||||
self.ae_results["train_time"] = self.ae_trainer.train_time
|
||||
|
||||
# Test
|
||||
self.ae_trainer.test(dataset, self.ae_net)
|
||||
self.ae_trainer.test(dataset, self.ae_net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get test results
|
||||
self.ae_results["test_auc"] = self.ae_trainer.test_auc
|
||||
@@ -197,10 +208,11 @@ class DeepSAD(object):
|
||||
self.ae_net = build_autoencoder(self.net_name)
|
||||
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
|
||||
|
||||
def save_results(self, export_json):
|
||||
def save_results(self, export_pkl):
|
||||
"""Save results dict to a JSON-file."""
|
||||
with open(export_json, "w") as fp:
|
||||
json.dump(self.results, fp)
|
||||
with open(export_pkl, "wb") as fp:
|
||||
# json.dump(self.results, fp)
|
||||
pickle.dump(self.results, fp)
|
||||
|
||||
def save_ae_results(self, export_json):
|
||||
"""Save autoencoder results dict to a JSON-file."""
|
||||
|
||||
Reference in New Issue
Block a user