implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -54,7 +54,7 @@ class DeepSADTrainer(BaseTrainer):
logger = logging.getLogger()
# Get train data loader
train_loader, _ = dataset.loaders(
train_loader, _, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
@@ -130,11 +130,49 @@ class DeepSADTrainer(BaseTrainer):
return net
def infer(self, dataset: BaseADDataset, net: BaseNet):
logger = logging.getLogger()
# Get test data loader
_, _, inference_loader = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
# Set device for network
net = net.to(self.device)
# Testing
logger.info("Starting inference...")
n_batches = 0
start_time = time.time()
scores = []
net.eval()
with torch.no_grad():
for data in inference_loader:
inputs, idx = data
inputs = inputs.to(self.device)
idx = idx.to(self.device)
outputs = net(inputs)
dist = torch.sum((outputs - self.c) ** 2, dim=1)
scores += dist.cpu().data.numpy().tolist()
n_batches += 1
self.inference_time = time.time() - start_time
# Log results
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
logger.info("Finished inference.")
return np.array(scores)
def test(self, dataset: BaseADDataset, net: BaseNet):
logger = logging.getLogger()
# Get test data loader
_, test_loader = dataset.loaders(
_, test_loader, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)