initial work for elpv and subter datasets

elpv as example dataset/implementation
subter with final dataset
This commit is contained in:
Jan Kowalczyk
2024-06-28 11:40:19 +02:00
parent 71f9662022
commit d6a019a8bb
13 changed files with 1585 additions and 4 deletions

View File

@@ -19,6 +19,8 @@ from datasets.main import load_dataset
type=click.Choice(
[
"mnist",
"elpv",
"subter",
"fmnist",
"cifar10",
"arrhythmia",
@@ -35,6 +37,8 @@ from datasets.main import load_dataset
type=click.Choice(
[
"mnist_LeNet",
"elpv_LeNet",
"subter_LeNet",
"fmnist_LeNet",
"cifar10_LeNet",
"arrhythmia_mlp",
@@ -109,7 +113,7 @@ from datasets.main import load_dataset
@click.option(
"--lr_milestone",
type=int,
default=0,
default=[0],
multiple=True,
help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
)
@@ -149,7 +153,7 @@ from datasets.main import load_dataset
@click.option(
"--ae_lr_milestone",
type=int,
default=0,
default=[0],
multiple=True,
help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
)
@@ -393,9 +397,9 @@ def main(
np.argsort(scores[labels == 0])
] # from lowest to highest score
if dataset_name in ("mnist", "fmnist", "cifar10"):
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
if dataset_name in ("mnist", "fmnist"):
if dataset_name in ("mnist", "fmnist", "elpv"):
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(1)
X_normal_low = dataset.test_set.data[idx_normal_sorted[:32], ...].unsqueeze(