Files
mt/Deep-SAD-PyTorch/src/datasets/main.py

126 lines
3.6 KiB
Python
Raw Normal View History

from .cifar10 import CIFAR10_Dataset
from .elpv import ELPV_Dataset
2024-06-28 07:42:12 +02:00
from .fmnist import FashionMNIST_Dataset
from .mnist import MNIST_Dataset
2024-06-28 07:42:12 +02:00
from .odds import ODDSADDataset
from .subter import SubTer_Dataset
from .subtersplit import SubTerSplit_Dataset
2024-06-28 07:42:12 +02:00
2024-06-28 11:36:46 +02:00
def load_dataset(
    dataset_name,
    data_path,
    normal_class,
    known_outlier_class,
    n_known_outlier_classes: int = 0,
    ratio_known_normal: float = 0.0,
    ratio_known_outlier: float = 0.0,
    ratio_pollution: float = 0.0,
    random_state=None,
    inference: bool = False,
    k_fold: bool = False,
    num_known_normal: int = 0,
    num_known_outlier: int = 0,
):
    """Construct and return the dataset object selected by ``dataset_name``.

    Every dataset receives ``root=data_path`` plus the three ratio arguments
    (``ratio_known_normal``, ``ratio_known_outlier``, ``ratio_pollution``).
    The remaining arguments are forwarded only to the datasets listed below;
    for all other datasets they are accepted but ignored by this function.

    Args:
        dataset_name: One of "mnist", "fmnist", "cifar10", "elpv", "subter",
            "subtersplit", "arrhythmia", "cardio", "satellite", "satimage-2",
            "shuttle" or "thyroid".
        data_path: Passed to the dataset constructor as ``root``.
        normal_class: Forwarded to "mnist", "fmnist" and "cifar10" only.
        known_outlier_class: Forwarded to "mnist", "fmnist" and "cifar10"
            only.
        n_known_outlier_classes: Forwarded to "mnist", "fmnist", "cifar10"
            and the ODDS datasets ("arrhythmia" ... "thyroid").
        ratio_known_normal: Forwarded to every dataset.
        ratio_known_outlier: Forwarded to every dataset.
        ratio_pollution: Forwarded to every dataset.
        random_state: Forwarded to the ODDS datasets only.
        inference: Forwarded to "subter" and "subtersplit" only.
        k_fold: Forwarded to "subter" only.
        num_known_normal: Forwarded to "subter" only.
        num_known_outlier: Forwarded to "subter" only.

    Returns:
        The constructed dataset instance.

    Raises:
        AssertionError: If ``dataset_name`` is not one of the supported names.
    """
    odds_datasets = (
        "arrhythmia",
        "cardio",
        "satellite",
        "satimage-2",
        "shuttle",
        "thyroid",
    )
    implemented_datasets = (
        "mnist",
        "elpv",
        "subter",
        "subtersplit",
        "fmnist",
        "cifar10",
    ) + odds_datasets

    # Raised explicitly instead of using a bare ``assert``: assert statements
    # are removed under ``python -O``, which would let an unknown name fall
    # through and silently return None. AssertionError is kept so existing
    # callers that catch it keep working.
    if dataset_name not in implemented_datasets:
        raise AssertionError(
            f"Unknown dataset {dataset_name!r}; "
            f"expected one of {implemented_datasets}"
        )

    # Keyword arguments shared by every dataset constructor.
    ratios = {
        "ratio_known_normal": ratio_known_normal,
        "ratio_known_outlier": ratio_known_outlier,
        "ratio_pollution": ratio_pollution,
    }

    if dataset_name == "subter":
        return SubTer_Dataset(
            root=data_path,
            inference=inference,
            k_fold=k_fold,
            num_known_normal=num_known_normal,
            num_known_outlier=num_known_outlier,
            **ratios,
        )

    if dataset_name == "subtersplit":
        return SubTerSplit_Dataset(
            root=data_path,
            inference=inference,
            **ratios,
        )

    if dataset_name == "elpv":
        return ELPV_Dataset(root=data_path, **ratios)

    if dataset_name in ("mnist", "fmnist", "cifar10"):
        # These three constructors are called with an identical argument set.
        image_dataset_classes = {
            "mnist": MNIST_Dataset,
            "fmnist": FashionMNIST_Dataset,
            "cifar10": CIFAR10_Dataset,
        }
        return image_dataset_classes[dataset_name](
            root=data_path,
            normal_class=normal_class,
            known_outlier_class=known_outlier_class,
            n_known_outlier_classes=n_known_outlier_classes,
            **ratios,
        )

    # Only the ODDS names remain after the validation and branches above.
    return ODDSADDataset(
        root=data_path,
        dataset_name=dataset_name,
        n_known_outlier_classes=n_known_outlier_classes,
        random_state=random_state,
        **ratios,
    )