2024-06-28 07:42:12 +02:00
|
|
|
from .mnist import MNIST_Dataset
|
2024-06-28 11:40:19 +02:00
|
|
|
from .elpv import ELPV_Dataset
|
|
|
|
|
from .subter import SubTer_Dataset
|
2024-06-28 07:42:12 +02:00
|
|
|
from .fmnist import FashionMNIST_Dataset
|
|
|
|
|
from .cifar10 import CIFAR10_Dataset
|
|
|
|
|
from .odds import ODDSADDataset
|
|
|
|
|
|
|
|
|
|
|
2024-06-28 11:36:46 +02:00
|
|
|
def load_dataset(
    dataset_name,
    data_path,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes: int = 0,
    ratio_known_normal: float = 0.0,
    ratio_known_outlier: float = 0.0,
    ratio_pollution: float = 0.0,
    random_state=None,
    inference: bool = False,
):
    """Load the dataset selected by ``dataset_name``.

    Args:
        dataset_name: One of "mnist", "elpv", "subter", "fmnist", "cifar10",
            or an ODDS dataset ("arrhythmia", "cardio", "satellite",
            "satimage-2", "shuttle", "thyroid").
        data_path: Passed as ``root`` to the selected dataset class.
        normal_class: Forwarded to the MNIST, FashionMNIST and CIFAR10 loaders.
        known_outlier_class: Forwarded to the MNIST, FashionMNIST and
            CIFAR10 loaders.
        n_known_outlier_classes: Forwarded to the MNIST, FashionMNIST,
            CIFAR10 and ODDS loaders.
        ratio_known_normal: Forwarded to every loader.
        ratio_known_outlier: Forwarded to every loader.
        ratio_pollution: Forwarded to every loader.
        random_state: Forwarded only to the ODDS loader.
        inference: Forwarded only to the SubTer loader.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset. This is an
            explicit check (previously an ``assert``) so it still fires when
            Python runs with ``-O``.
    """
    # Single source of truth for the ODDS names: used both for validation
    # and for routing to ODDSADDataset.
    odds_datasets = (
        "arrhythmia",
        "cardio",
        "satellite",
        "satimage-2",
        "shuttle",
        "thyroid",
    )
    implemented_datasets = (
        "mnist",
        "elpv",
        "subter",
        "fmnist",
        "cifar10",
    ) + odds_datasets

    if dataset_name not in implemented_datasets:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of: "
            + ", ".join(implemented_datasets)
        )

    # Keyword arguments every loader below receives unchanged.
    ratios = dict(
        ratio_known_normal=ratio_known_normal,
        ratio_known_outlier=ratio_known_outlier,
        ratio_pollution=ratio_pollution,
    )

    if dataset_name == "subter":
        dataset = SubTer_Dataset(root=data_path, inference=inference, **ratios)
    elif dataset_name == "elpv":
        dataset = ELPV_Dataset(root=data_path, **ratios)
    elif dataset_name in ("mnist", "fmnist", "cifar10"):
        # These three loaders share the same class-selection parameters.
        dataset_cls = {
            "mnist": MNIST_Dataset,
            "fmnist": FashionMNIST_Dataset,
            "cifar10": CIFAR10_Dataset,
        }[dataset_name]
        dataset = dataset_cls(
            root=data_path,
            normal_class=normal_class,
            known_outlier_class=known_outlier_class,
            n_known_outlier_classes=n_known_outlier_classes,
            **ratios,
        )
    else:
        # Validation above guarantees this is one of the ODDS datasets.
        dataset = ODDSADDataset(
            root=data_path,
            dataset_name=dataset_name,
            n_known_outlier_classes=n_known_outlier_classes,
            random_state=random_state,
            **ratios,
        )

    return dataset
|