full upload so not to lose anything important
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
from sklearn.model_selection import KFold
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Subset
|
||||
|
||||
from .base_dataset import BaseADDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class TorchvisionDataset(BaseADDataset):
|
||||
"""TorchvisionDataset class for datasets already implemented in torchvision.datasets."""
|
||||
|
||||
def __init__(self, root: str):
|
||||
def __init__(self, root: str, k_fold_number: int = 5):
|
||||
super().__init__(root)
|
||||
self.k_fold_number = k_fold_number
|
||||
self.fold_indices = None
|
||||
|
||||
def loaders(
|
||||
self,
|
||||
@@ -50,3 +54,43 @@ class TorchvisionDataset(BaseADDataset):
|
||||
else None
|
||||
)
|
||||
return train_loader, test_loader, inference_loader
|
||||
|
||||
def loaders_k_fold(
|
||||
self,
|
||||
fold_idx: int,
|
||||
batch_size: int,
|
||||
shuffle_train=True,
|
||||
shuffle_test=False,
|
||||
num_workers: int = 0,
|
||||
) -> (DataLoader, DataLoader):
|
||||
if self.fold_indices is None:
|
||||
# Define the K-fold Cross Validator
|
||||
kfold = KFold(n_splits=self.k_fold_number, shuffle=False)
|
||||
self.fold_indices = []
|
||||
# Generate indices for each fold and store them in a list
|
||||
for train_indices, val_indices in kfold.split(self.data_set):
|
||||
self.fold_indices.append((train_indices, val_indices))
|
||||
|
||||
train_loader = (
|
||||
DataLoader(
|
||||
dataset=Subset(self.data_set, self.fold_indices[fold_idx][0]),
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_train,
|
||||
num_workers=num_workers,
|
||||
drop_last=True,
|
||||
)
|
||||
if self.data_set is not None
|
||||
else None
|
||||
)
|
||||
test_loader = (
|
||||
DataLoader(
|
||||
dataset=Subset(self.data_set, self.fold_indices[fold_idx][1]),
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle_test,
|
||||
num_workers=num_workers,
|
||||
drop_last=False,
|
||||
)
|
||||
if self.data_set is not None
|
||||
else None
|
||||
)
|
||||
return train_loader, test_loader
|
||||
|
||||
Reference in New Issue
Block a user