from .cifar10 import CIFAR10_Dataset
from .elpv import ELPV_Dataset
from .fmnist import FashionMNIST_Dataset
from .mnist import MNIST_Dataset
from .odds import ODDSADDataset
from .subter import SubTer_Dataset
from .subtersplit import SubTerSplit_Dataset


class _UnknownDatasetError(ValueError, AssertionError):
    """Raised by ``load_dataset`` for a dataset name it does not implement.

    It derives from ``AssertionError`` as well as ``ValueError`` because the
    name check used to be a bare ``assert``; callers that catch
    ``AssertionError`` keep working.
    """


def load_dataset(
    dataset_name,
    data_path,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes: int = 0,
    ratio_known_normal: float = 0.0,
    ratio_known_outlier: float = 0.0,
    ratio_pollution: float = 0.0,
    random_state=None,
    inference: bool = False,
    k_fold: bool = False,
    num_known_normal: int = 0,
    num_known_outlier: int = 0,
):
    """Build and return the dataset object selected by ``dataset_name``.

    Every dataset class receives ``root=data_path`` plus the three ratio
    arguments. The remaining arguments are only forwarded to some of them:

    * ``"subter"``: ``inference``, ``k_fold``, ``num_known_normal`` and
      ``num_known_outlier``.
    * ``"subtersplit"``: ``inference``.
    * ``"elpv"``: nothing beyond the shared arguments.
    * ``"mnist"``, ``"fmnist"``, ``"cifar10"``: ``normal_class``,
      ``known_outlier_class`` and ``n_known_outlier_classes``.
    * ``"arrhythmia"``, ``"cardio"``, ``"satellite"``, ``"satimage-2"``,
      ``"shuttle"``, ``"thyroid"``: ``dataset_name``,
      ``n_known_outlier_classes`` and ``random_state``.

    Arguments not listed for a dataset are ignored for that dataset.

    Args:
        dataset_name: Name of the dataset to load; must be one of the names
            listed above.
        data_path: Passed to the dataset class as ``root``.
        normal_class: Forwarded to the MNIST / Fashion-MNIST / CIFAR-10
            classes only.
        known_outlier_class: Forwarded to the MNIST / Fashion-MNIST /
            CIFAR-10 classes only.
        n_known_outlier_classes: Forwarded to the MNIST / Fashion-MNIST /
            CIFAR-10 classes and to ``ODDSADDataset``.
        ratio_known_normal: Forwarded to every dataset class.
        ratio_known_outlier: Forwarded to every dataset class.
        ratio_pollution: Forwarded to every dataset class.
        random_state: Forwarded to ``ODDSADDataset`` only.
        inference: Forwarded to ``SubTer_Dataset`` and
            ``SubTerSplit_Dataset`` only.
        k_fold: Forwarded to ``SubTer_Dataset`` only.
        num_known_normal: Forwarded to ``SubTer_Dataset`` only.
        num_known_outlier: Forwarded to ``SubTer_Dataset`` only.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``dataset_name`` is not an implemented dataset. The
            raised exception is also an ``AssertionError`` for backward
            compatibility.
    """
    image_datasets = {
        "mnist": MNIST_Dataset,
        "fmnist": FashionMNIST_Dataset,
        "cifar10": CIFAR10_Dataset,
    }
    odds_datasets = (
        "arrhythmia",
        "cardio",
        "satellite",
        "satimage-2",
        "shuttle",
        "thyroid",
    )
    implemented_datasets = (
        ("subter", "subtersplit", "elpv") + tuple(image_datasets) + odds_datasets
    )

    # Explicit raise instead of ``assert``: asserts are removed under
    # ``python -O``, which would let an unknown name fall through and make
    # this function return None.
    if dataset_name not in implemented_datasets:
        raise _UnknownDatasetError(
            f"Unknown dataset {dataset_name!r}; "
            f"expected one of: {', '.join(implemented_datasets)}"
        )

    # Keyword arguments accepted by every dataset class.
    shared_kwargs = {
        "root": data_path,
        "ratio_known_normal": ratio_known_normal,
        "ratio_known_outlier": ratio_known_outlier,
        "ratio_pollution": ratio_pollution,
    }

    if dataset_name == "subter":
        return SubTer_Dataset(
            inference=inference,
            k_fold=k_fold,
            num_known_normal=num_known_normal,
            num_known_outlier=num_known_outlier,
            **shared_kwargs,
        )

    if dataset_name == "subtersplit":
        return SubTerSplit_Dataset(inference=inference, **shared_kwargs)

    if dataset_name == "elpv":
        return ELPV_Dataset(**shared_kwargs)

    if dataset_name in image_datasets:
        dataset_cls = image_datasets[dataset_name]
        return dataset_cls(
            normal_class=normal_class,
            known_outlier_class=known_outlier_class,
            n_known_outlier_classes=n_known_outlier_classes,
            **shared_kwargs,
        )

    # Only the ODDS tabular datasets remain after the validation above.
    return ODDSADDataset(
        dataset_name=dataset_name,
        n_known_outlier_classes=n_known_outlier_classes,
        random_state=random_state,
        **shared_kwargs,
    )