initial work for elpv and subter datasets
elpv as example dataset/implementation subter with final dataset
This commit is contained in:
@@ -19,6 +19,8 @@ from datasets.main import load_dataset
|
||||
type=click.Choice(
|
||||
[
|
||||
"mnist",
|
||||
"elpv",
|
||||
"subter",
|
||||
"fmnist",
|
||||
"cifar10",
|
||||
"arrhythmia",
|
||||
@@ -35,6 +37,8 @@ from datasets.main import load_dataset
|
||||
type=click.Choice(
|
||||
[
|
||||
"mnist_LeNet",
|
||||
"elpv_LeNet",
|
||||
"subter_LeNet",
|
||||
"fmnist_LeNet",
|
||||
"cifar10_LeNet",
|
||||
"arrhythmia_mlp",
|
||||
@@ -109,7 +113,7 @@ from datasets.main import load_dataset
|
||||
@click.option(
|
||||
"--lr_milestone",
|
||||
type=int,
|
||||
default=0,
|
||||
default=[0],
|
||||
multiple=True,
|
||||
help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
|
||||
)
|
||||
@@ -149,7 +153,7 @@ from datasets.main import load_dataset
|
||||
@click.option(
|
||||
"--ae_lr_milestone",
|
||||
type=int,
|
||||
default=0,
|
||||
default=[0],
|
||||
multiple=True,
|
||||
help="Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.",
|
||||
)
|
||||
@@ -393,9 +397,9 @@ def main(
|
||||
np.argsort(scores[labels == 0])
|
||||
] # from lowest to highest score
|
||||
|
||||
if dataset_name in ("mnist", "fmnist", "cifar10"):
|
||||
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
|
||||
|
||||
if dataset_name in ("mnist", "fmnist"):
|
||||
if dataset_name in ("mnist", "fmnist", "elpv"):
|
||||
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
|
||||
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(1)
|
||||
X_normal_low = dataset.test_set.data[idx_normal_sorted[:32], ...].unsqueeze(
|
||||
|
||||
Reference in New Issue
Block a user