55 lines
2.7 KiB
Python
55 lines
2.7 KiB
Python
|
|
from .mnist import MNIST_Dataset
|
||
|
|
from .fmnist import FashionMNIST_Dataset
|
||
|
|
from .cifar10 import CIFAR10_Dataset
|
||
|
|
from .odds import ODDSADDataset
|
||
|
|
|
||
|
|
|
||
|
|
def load_dataset(dataset_name, data_path, normal_class, known_outlier_class, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 random_state=None):
    """Load the dataset named ``dataset_name``.

    Dispatches to one of the dataset classes imported by this module and
    returns the constructed instance.

    Args:
        dataset_name: One of ``'mnist'``, ``'fmnist'``, ``'cifar10'`` or an
            ODDS dataset name (``'arrhythmia'``, ``'cardio'``, ``'satellite'``,
            ``'satimage-2'``, ``'shuttle'``, ``'thyroid'``).
        data_path: Passed to the dataset constructor as ``root``.
        normal_class: Forwarded to the image datasets (mnist/fmnist/cifar10)
            only; not passed to the ODDS dataset.
        known_outlier_class: Forwarded to the image datasets only; not passed
            to the ODDS dataset.
        n_known_outlier_classes: Forwarded to every dataset constructor.
        ratio_known_normal: Forwarded to every dataset constructor.
        ratio_known_outlier: Forwarded to every dataset constructor.
        ratio_pollution: Forwarded to every dataset constructor.
        random_state: Forwarded to the ODDS dataset only; the image dataset
            constructors are not given this argument.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``dataset_name`` is not an implemented dataset.
    """
    # Each group is defined once so the validation list and the dispatch
    # below cannot drift apart.
    image_datasets = ('mnist', 'fmnist', 'cifar10')
    odds_datasets = ('arrhythmia', 'cardio', 'satellite', 'satimage-2', 'shuttle', 'thyroid')
    implemented_datasets = image_datasets + odds_datasets

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would make an unknown name silently return None.
    if dataset_name not in implemented_datasets:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {implemented_datasets}"
        )

    if dataset_name in odds_datasets:
        # The ODDS loader takes the dataset name and a random_state, but no
        # normal/known-outlier class arguments.
        return ODDSADDataset(root=data_path,
                             dataset_name=dataset_name,
                             n_known_outlier_classes=n_known_outlier_classes,
                             ratio_known_normal=ratio_known_normal,
                             ratio_known_outlier=ratio_known_outlier,
                             ratio_pollution=ratio_pollution,
                             random_state=random_state)

    # The three image datasets are constructed with an identical keyword set,
    # so only the class differs.
    dataset_cls = {
        'mnist': MNIST_Dataset,
        'fmnist': FashionMNIST_Dataset,
        'cifar10': CIFAR10_Dataset,
    }[dataset_name]
    return dataset_cls(root=data_path,
                       normal_class=normal_class,
                       known_outlier_class=known_outlier_class,
                       n_known_outlier_classes=n_known_outlier_classes,
                       ratio_known_normal=ratio_known_normal,
                       ratio_known_outlier=ratio_known_outlier,
                       ratio_pollution=ratio_pollution)
|