implemented inference
This commit is contained in:
@@ -14,19 +14,39 @@ class TorchvisionDataset(BaseADDataset):
|
||||
shuffle_train=True,
|
||||
shuffle_test=False,
|
||||
num_workers: int = 0,
|
||||
) -> (DataLoader, DataLoader):
|
||||
train_loader = DataLoader(
|
||||
dataset=self.train_set,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_train,
|
||||
num_workers=num_workers,
|
||||
drop_last=True,
|
||||
) -> (DataLoader, DataLoader, DataLoader):
|
||||
train_loader = (
|
||||
DataLoader(
|
||||
dataset=self.train_set,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_train,
|
||||
num_workers=num_workers,
|
||||
drop_last=True,
|
||||
)
|
||||
if self.train_set is not None
|
||||
else None
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
dataset=self.test_set,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_test,
|
||||
num_workers=num_workers,
|
||||
drop_last=False,
|
||||
test_loader = (
|
||||
DataLoader(
|
||||
dataset=self.test_set,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_test,
|
||||
num_workers=num_workers,
|
||||
drop_last=False,
|
||||
)
|
||||
if self.test_set is not None
|
||||
else None
|
||||
)
|
||||
return train_loader, test_loader
|
||||
|
||||
inference_loader = (
|
||||
DataLoader(
|
||||
dataset=self.inference_set,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
drop_last=False,
|
||||
)
|
||||
if self.inference_set is not None
|
||||
else None
|
||||
)
|
||||
return train_loader, test_loader, inference_loader
|
||||
|
||||
Reference in New Issue
Block a user