wip
This commit is contained in:
@@ -63,6 +63,8 @@ class TorchvisionDataset(BaseADDataset):
|
||||
shuffle_test=False,
|
||||
num_workers: int = 0,
|
||||
) -> (DataLoader, DataLoader):
|
||||
if self.k_fold_number is None:
|
||||
raise ValueError("k_fold_number must be set to a positive integer.")
|
||||
if self.fold_indices is None:
|
||||
# Define the K-fold Cross Validator
|
||||
kfold = KFold(n_splits=self.k_fold_number, shuffle=False)
|
||||
|
||||
Reference in New Issue
Block a user