added deepsad base code
This commit is contained in:
26
Deep-SAD-PyTorch/src/base/base_dataset.py
Normal file
26
Deep-SAD-PyTorch/src/base/base_dataset.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class BaseADDataset(ABC):
|
||||
"""Anomaly detection dataset base class."""
|
||||
|
||||
def __init__(self, root: str):
|
||||
super().__init__()
|
||||
self.root = root # root path to data
|
||||
|
||||
self.n_classes = 2 # 0: normal, 1: outlier
|
||||
self.normal_classes = None # tuple with original class labels that define the normal class
|
||||
self.outlier_classes = None # tuple with original class labels that define the outlier class
|
||||
|
||||
self.train_set = None # must be of type torch.utils.data.Dataset
|
||||
self.test_set = None # must be of type torch.utils.data.Dataset
|
||||
|
||||
@abstractmethod
|
||||
def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
|
||||
DataLoader, DataLoader):
|
||||
"""Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set."""
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__
|
||||
Reference in New Issue
Block a user