implemented inference
This commit is contained in:
@@ -54,7 +54,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Get train data loader
|
||||
train_loader, _ = dataset.loaders(
|
||||
train_loader, _, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
|
||||
@@ -130,11 +130,49 @@ class DeepSADTrainer(BaseTrainer):
|
||||
|
||||
return net
|
||||
|
||||
def infer(self, dataset: BaseADDataset, net: BaseNet):
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Get test data loader
|
||||
_, _, inference_loader = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Set device for network
|
||||
net = net.to(self.device)
|
||||
|
||||
# Testing
|
||||
logger.info("Starting inference...")
|
||||
n_batches = 0
|
||||
start_time = time.time()
|
||||
scores = []
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
for data in inference_loader:
|
||||
inputs, idx = data
|
||||
|
||||
inputs = inputs.to(self.device)
|
||||
idx = idx.to(self.device)
|
||||
|
||||
outputs = net(inputs)
|
||||
dist = torch.sum((outputs - self.c) ** 2, dim=1)
|
||||
scores += dist.cpu().data.numpy().tolist()
|
||||
|
||||
n_batches += 1
|
||||
|
||||
self.inference_time = time.time() - start_time
|
||||
|
||||
# Log results
|
||||
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
|
||||
logger.info("Finished inference.")
|
||||
|
||||
return np.array(scores)
|
||||
|
||||
def test(self, dataset: BaseADDataset, net: BaseNet):
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Get test data loader
|
||||
_, test_loader = dataset.loaders(
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user