split subter implementation (training + inference)

This commit is contained in:
Jan Kowalczyk
2024-10-17 08:36:18 +02:00
parent 5014c41b24
commit ddf4e4aa36
11 changed files with 699 additions and 145 deletions

View File

@@ -1,14 +1,14 @@
import click
import torch
import logging
import random
import numpy as np
from pathlib import Path
import click
import numpy as np
import torch
from datasets.main import load_dataset
from DeepSAD import DeepSAD
from utils.config import Config
from utils.visualization.plot_images_grid import plot_images_grid
from DeepSAD import DeepSAD
from datasets.main import load_dataset
################################################################################
@@ -31,6 +31,7 @@ from datasets.main import load_dataset
"mnist",
"elpv",
"subter",
"subtersplit",
"fmnist",
"cifar10",
"arrhythmia",
@@ -49,6 +50,7 @@ from datasets.main import load_dataset
"mnist_LeNet",
"elpv_LeNet",
"subter_LeNet",
"subter_LeNet_Split",
"fmnist_LeNet",
"cifar10_LeNet",
"arrhythmia_mlp",
@@ -315,7 +317,6 @@ def main(
logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
if action == "train":
# Load data
dataset = load_dataset(
dataset_name,
@@ -413,7 +414,6 @@ def main(
] # from lowest to highest score
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
if dataset_name in ("mnist", "fmnist", "elpv"):
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(