added deepsad base code
This commit is contained in:
54
Deep-SAD-PyTorch/src/datasets/main.py
Normal file
54
Deep-SAD-PyTorch/src/datasets/main.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from .mnist import MNIST_Dataset
|
||||
from .fmnist import FashionMNIST_Dataset
|
||||
from .cifar10 import CIFAR10_Dataset
|
||||
from .odds import ODDSADDataset
|
||||
|
||||
|
||||
def load_dataset(dataset_name, data_path, normal_class, known_outlier_class, n_known_outlier_classes: int = 0,
|
||||
ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
|
||||
random_state=None):
|
||||
"""Loads the dataset."""
|
||||
|
||||
implemented_datasets = ('mnist', 'fmnist', 'cifar10',
|
||||
'arrhythmia', 'cardio', 'satellite', 'satimage-2', 'shuttle', 'thyroid')
|
||||
assert dataset_name in implemented_datasets
|
||||
|
||||
dataset = None
|
||||
|
||||
if dataset_name == 'mnist':
|
||||
dataset = MNIST_Dataset(root=data_path,
|
||||
normal_class=normal_class,
|
||||
known_outlier_class=known_outlier_class,
|
||||
n_known_outlier_classes=n_known_outlier_classes,
|
||||
ratio_known_normal=ratio_known_normal,
|
||||
ratio_known_outlier=ratio_known_outlier,
|
||||
ratio_pollution=ratio_pollution)
|
||||
|
||||
if dataset_name == 'fmnist':
|
||||
dataset = FashionMNIST_Dataset(root=data_path,
|
||||
normal_class=normal_class,
|
||||
known_outlier_class=known_outlier_class,
|
||||
n_known_outlier_classes=n_known_outlier_classes,
|
||||
ratio_known_normal=ratio_known_normal,
|
||||
ratio_known_outlier=ratio_known_outlier,
|
||||
ratio_pollution=ratio_pollution)
|
||||
|
||||
if dataset_name == 'cifar10':
|
||||
dataset = CIFAR10_Dataset(root=data_path,
|
||||
normal_class=normal_class,
|
||||
known_outlier_class=known_outlier_class,
|
||||
n_known_outlier_classes=n_known_outlier_classes,
|
||||
ratio_known_normal=ratio_known_normal,
|
||||
ratio_known_outlier=ratio_known_outlier,
|
||||
ratio_pollution=ratio_pollution)
|
||||
|
||||
if dataset_name in ('arrhythmia', 'cardio', 'satellite', 'satimage-2', 'shuttle', 'thyroid'):
|
||||
dataset = ODDSADDataset(root=data_path,
|
||||
dataset_name=dataset_name,
|
||||
n_known_outlier_classes=n_known_outlier_classes,
|
||||
ratio_known_normal=ratio_known_normal,
|
||||
ratio_known_outlier=ratio_known_outlier,
|
||||
ratio_pollution=ratio_pollution,
|
||||
random_state=random_state)
|
||||
|
||||
return dataset
|
||||
Reference in New Issue
Block a user