split subter implementation (training + inference)
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
import click
|
||||
import torch
|
||||
import logging
|
||||
import random
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets.main import load_dataset
|
||||
from DeepSAD import DeepSAD
|
||||
from utils.config import Config
|
||||
from utils.visualization.plot_images_grid import plot_images_grid
|
||||
from DeepSAD import DeepSAD
|
||||
from datasets.main import load_dataset
|
||||
|
||||
|
||||
################################################################################
|
||||
@@ -31,6 +31,7 @@ from datasets.main import load_dataset
|
||||
"mnist",
|
||||
"elpv",
|
||||
"subter",
|
||||
"subtersplit",
|
||||
"fmnist",
|
||||
"cifar10",
|
||||
"arrhythmia",
|
||||
@@ -49,6 +50,7 @@ from datasets.main import load_dataset
|
||||
"mnist_LeNet",
|
||||
"elpv_LeNet",
|
||||
"subter_LeNet",
|
||||
"subter_LeNet_Split",
|
||||
"fmnist_LeNet",
|
||||
"cifar10_LeNet",
|
||||
"arrhythmia_mlp",
|
||||
@@ -315,7 +317,6 @@ def main(
|
||||
logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
|
||||
|
||||
if action == "train":
|
||||
|
||||
# Load data
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
@@ -413,7 +414,6 @@ def main(
|
||||
] # from lowest to highest score
|
||||
|
||||
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
|
||||
|
||||
if dataset_name in ("mnist", "fmnist", "elpv"):
|
||||
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
|
||||
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
|
||||
|
||||
Reference in New Issue
Block a user