implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -86,6 +86,18 @@ class DeepSAD(object):
self.results["train_time"] = self.trainer.train_time
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
def inference(
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
):
"""Tests the Deep SAD model on the test data."""
if self.trainer is None:
self.trainer = DeepSADTrainer(
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
)
return self.trainer.infer(dataset, self.net)
def test(
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
):