36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
from abc import ABC, abstractmethod
from typing import Tuple

from torch.utils.data import DataLoader
|
|
|
|
|
|
class BaseADDataset(ABC):
    """Anomaly detection dataset base class.

    Concrete datasets subclass this, populate ``train_set`` / ``test_set``
    and implement :meth:`loaders`. The class cannot be instantiated directly
    because :meth:`loaders` is abstract.
    """

    def __init__(self, root: str) -> None:
        """Store the data root and initialise all dataset fields to defaults.

        Args:
            root: Root path to the data.
        """
        super().__init__()
        self.root = root  # root path to data

        self.n_classes = 2  # 0: normal, 1: outlier
        # tuple with original class labels that define the normal class
        self.normal_classes = None
        # tuple with original class labels that define the outlier class
        self.outlier_classes = None

        # Both are left unset here; subclasses are expected to assign them.
        self.train_set = None  # must be of type torch.utils.data.Dataset
        self.test_set = None  # must be of type torch.utils.data.Dataset

    @abstractmethod
    def loaders(
        self,
        batch_size: int,
        shuffle_train: bool = True,
        shuffle_test: bool = False,
        num_workers: int = 0,
    ) -> Tuple[DataLoader, DataLoader]:
        """Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set.

        Args:
            batch_size: Batch size for the loaders.
            shuffle_train: Whether the training loader should shuffle.
            shuffle_test: Whether the test loader should shuffle.
            num_workers: Number of worker processes for the loaders.

        Returns:
            A pair of ``DataLoader`` objects, one for ``train_set`` and one
            for ``test_set``.
        """

    def __repr__(self) -> str:
        """Return the name of the concrete dataset class."""
        return self.__class__.__name__
|