implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -14,19 +14,39 @@ class TorchvisionDataset(BaseADDataset):
shuffle_train=True,
shuffle_test=False,
num_workers: int = 0,
) -> (DataLoader, DataLoader):
train_loader = DataLoader(
dataset=self.train_set,
batch_size=batch_size,
shuffle=shuffle_train,
num_workers=num_workers,
drop_last=True,
) -> (DataLoader, DataLoader, DataLoader):
train_loader = (
DataLoader(
dataset=self.train_set,
batch_size=batch_size,
shuffle=shuffle_train,
num_workers=num_workers,
drop_last=True,
)
if self.train_set is not None
else None
)
test_loader = DataLoader(
dataset=self.test_set,
batch_size=batch_size,
shuffle=shuffle_test,
num_workers=num_workers,
drop_last=False,
test_loader = (
DataLoader(
dataset=self.test_set,
batch_size=batch_size,
shuffle=shuffle_test,
num_workers=num_workers,
drop_last=False,
)
if self.test_set is not None
else None
)
return train_loader, test_loader
inference_loader = (
DataLoader(
dataset=self.inference_set,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=False,
)
if self.inference_set is not None
else None
)
return train_loader, test_loader, inference_loader