full upload so not to lose anything important

This commit is contained in:
Jan Kowalczyk
2025-03-14 18:02:23 +01:00
parent 35fcfb7d5a
commit b824ff7482
33 changed files with 3539 additions and 353 deletions

View File

@@ -1,12 +1,16 @@
from sklearn.model_selection import KFold
from torch.utils.data import ConcatDataset, DataLoader, Subset
from .base_dataset import BaseADDataset
from torch.utils.data import DataLoader
class TorchvisionDataset(BaseADDataset):
"""TorchvisionDataset class for datasets already implemented in torchvision.datasets."""
def __init__(self, root: str):
def __init__(self, root: str, k_fold_number: int = 5):
super().__init__(root)
self.k_fold_number = k_fold_number
self.fold_indices = None
def loaders(
self,
@@ -50,3 +54,43 @@ class TorchvisionDataset(BaseADDataset):
else None
)
return train_loader, test_loader, inference_loader
def loaders_k_fold(
self,
fold_idx: int,
batch_size: int,
shuffle_train=True,
shuffle_test=False,
num_workers: int = 0,
) -> (DataLoader, DataLoader):
if self.fold_indices is None:
# Define the K-fold Cross Validator
kfold = KFold(n_splits=self.k_fold_number, shuffle=False)
self.fold_indices = []
# Generate indices for each fold and store them in a list
for train_indices, val_indices in kfold.split(self.data_set):
self.fold_indices.append((train_indices, val_indices))
train_loader = (
DataLoader(
dataset=Subset(self.data_set, self.fold_indices[fold_idx][0]),
batch_size=batch_size,
shuffle=shuffle_train,
num_workers=num_workers,
drop_last=True,
)
if self.data_set is not None
else None
)
test_loader = (
DataLoader(
dataset=Subset(self.data_set, self.fold_indices[fold_idx][1]),
batch_size=batch_size,
shuffle=shuffle_test,
num_workers=num_workers,
drop_last=False,
)
if self.data_set is not None
else None
)
return train_loader, test_loader