Files
mt/Deep-SAD-PyTorch/src/datasets/main.py
2024-10-17 08:36:18 +02:00

120 lines
3.4 KiB
Python

from .cifar10 import CIFAR10_Dataset
from .elpv import ELPV_Dataset
from .fmnist import FashionMNIST_Dataset
from .mnist import MNIST_Dataset
from .odds import ODDSADDataset
from .subter import SubTer_Dataset
from .subtersplit import SubTerSplit_Dataset
class UnknownDatasetError(ValueError, AssertionError):
    """Raised by ``load_dataset`` when ``dataset_name`` is not supported.

    Also derives from ``AssertionError`` because the name check used to be an
    ``assert`` statement; callers that caught ``AssertionError`` keep working.
    """


# Benchmark names handled by ODDSADDataset (it receives the name itself).
_ODDS_DATASETS = (
    "arrhythmia",
    "cardio",
    "satellite",
    "satimage-2",
    "shuttle",
    "thyroid",
)


def load_dataset(
    dataset_name,
    data_path,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes: int = 0,
    ratio_known_normal: float = 0.0,
    ratio_known_outlier: float = 0.0,
    ratio_pollution: float = 0.0,
    random_state=None,
    inference: bool = False,
):
    """Loads the dataset.

    Builds and returns the dataset object for ``dataset_name``. Each dataset
    class receives only the arguments it accepts.

    Args:
        dataset_name: One of "mnist", "elpv", "subter", "subtersplit",
            "fmnist", "cifar10", or an ODDS benchmark name ("arrhythmia",
            "cardio", "satellite", "satimage-2", "shuttle", "thyroid").
        data_path: Passed to the dataset class as ``root``.
        normal_class: Used only by the mnist, fmnist and cifar10 datasets.
        known_outlier_class: Used only by the mnist, fmnist and cifar10
            datasets.
        n_known_outlier_classes: Used by mnist, fmnist, cifar10 and the ODDS
            datasets.
        ratio_known_normal: Forwarded to every dataset class.
        ratio_known_outlier: Forwarded to every dataset class.
        ratio_pollution: Forwarded to every dataset class.
        random_state: Used only by the ODDS datasets.
        inference: Used only by the subter and subtersplit datasets.

    Returns:
        The constructed dataset instance.

    Raises:
        UnknownDatasetError: ``dataset_name`` is not one of the names above
            (a subclass of both ``ValueError`` and ``AssertionError``).
    """
    implemented_datasets = (
        "mnist",
        "elpv",
        "subter",
        "subtersplit",
        "fmnist",
        "cifar10",
    ) + _ODDS_DATASETS
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which previously made an unknown name return None silently.
    if dataset_name not in implemented_datasets:
        raise UnknownDatasetError(
            f"Unknown dataset {dataset_name!r}; expected one of: "
            + ", ".join(implemented_datasets)
        )

    # The labeling / pollution ratios are passed identically to every class.
    ratio_kwargs = {
        "ratio_known_normal": ratio_known_normal,
        "ratio_known_outlier": ratio_known_outlier,
        "ratio_pollution": ratio_pollution,
    }

    if dataset_name == "subter":
        return SubTer_Dataset(root=data_path, inference=inference, **ratio_kwargs)
    if dataset_name == "subtersplit":
        return SubTerSplit_Dataset(
            root=data_path, inference=inference, **ratio_kwargs
        )
    if dataset_name == "elpv":
        return ELPV_Dataset(root=data_path, **ratio_kwargs)

    if dataset_name in ("mnist", "fmnist", "cifar10"):
        if dataset_name == "mnist":
            dataset_cls = MNIST_Dataset
        elif dataset_name == "fmnist":
            dataset_cls = FashionMNIST_Dataset
        else:
            dataset_cls = CIFAR10_Dataset
        return dataset_cls(
            root=data_path,
            normal_class=normal_class,
            known_outlier_class=known_outlier_class,
            n_known_outlier_classes=n_known_outlier_classes,
            **ratio_kwargs,
        )

    # Only the ODDS benchmark names remain after the validation above.
    return ODDSADDataset(
        root=data_path,
        dataset_name=dataset_name,
        n_known_outlier_classes=n_known_outlier_classes,
        random_state=random_state,
        **ratio_kwargs,
    )