implemented inference
This commit is contained in:
@@ -86,6 +86,18 @@ class DeepSAD(object):
|
||||
self.results["train_time"] = self.trainer.train_time
|
||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||
|
||||
def inference(
|
||||
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
||||
):
|
||||
"""Tests the Deep SAD model on the test data."""
|
||||
|
||||
if self.trainer is None:
|
||||
self.trainer = DeepSADTrainer(
|
||||
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||
)
|
||||
|
||||
return self.trainer.infer(dataset, self.net)
|
||||
|
||||
def test(
|
||||
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user