implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -16,6 +16,7 @@ def load_dataset(
ratio_known_outlier: float = 0.0,
ratio_pollution: float = 0.0,
random_state=None,
inference: bool = False,
):
"""Loads the dataset."""
@@ -42,6 +43,7 @@ def load_dataset(
ratio_known_normal=ratio_known_normal,
ratio_known_outlier=ratio_known_outlier,
ratio_pollution=ratio_pollution,
inference=inference,
)
if dataset_name == "elpv":

View File

@@ -6,6 +6,7 @@ from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
from typing import Callable, Optional
import logging
import torch
import torchvision.transforms as transforms
import random
@@ -22,6 +23,7 @@ class SubTer_Dataset(TorchvisionDataset):
ratio_known_normal: float = 0.0,
ratio_known_outlier: float = 0.0,
ratio_pollution: float = 0.0,
inference: bool = False,
):
super().__init__(root)
@@ -35,41 +37,47 @@ class SubTer_Dataset(TorchvisionDataset):
transform = transforms.ToTensor()
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
# Get train set
train_set = MySubTer(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
)
if inference:
self.inference_set = SubTerInference(
root=self.root,
transform=transform,
)
else:
# Get train set
train_set = SubTerTraining(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
)
# Create semi-supervised setting
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
np.array(semi_targets, dtype=np.int8)
) # set respective semi-supervised labels
# Create semi-supervised setting
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
np.array(semi_targets, dtype=np.int8)
) # set respective semi-supervised labels
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Get test set
self.test_set = MySubTer(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
)
# Get test set
self.test_set = SubTerTraining(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
)
class MySubTer(VisionDataset):
class SubTerTraining(VisionDataset):
def __init__(
self,
@@ -81,7 +89,9 @@ class MySubTer(VisionDataset):
split=0.7,
seed=0,
):
super(MySubTer, self).__init__(root, transforms, transform, target_transform)
super(SubTerTraining, self).__init__(
root, transforms, transform, target_transform
)
experiments_data = []
experiments_targets = []
@@ -153,3 +163,49 @@ class MySubTer(VisionDataset):
target = self.target_transform(target)
return img, target, semi_target, index
class SubTerInference(VisionDataset):
    """Unlabelled dataset over a single experiment file, used for inference.

    The file at ``root`` is read with ``np.load``, cleaned with
    ``np.nan_to_num`` and kept in memory as a float32 tensor. Items are taken
    along the first axis of that tensor and returned without any target.
    """

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
    ):
        """Load the experiment file at ``root`` into memory.

        Args:
            root: Path to a single experiment file readable by ``np.load``.
            transforms: Forwarded unchanged to the ``VisionDataset`` base.
            transform: Forwarded to the ``VisionDataset`` base; ``__getitem__``
                applies ``self.transform`` to every image when it is not None.

        Raises:
            FileNotFoundError: If ``root`` is not an existing regular file.
        """
        super().__init__(root, transforms, transform)

        logger = logging.getLogger()

        self.experiment_file_path = Path(root)
        if not self.experiment_file_path.is_file():
            logger.error(
                "For inference the data path has to be a single experiment file!"
            )
            # Subclass of Exception, so existing `except Exception` handlers
            # keep working while callers can now catch something specific.
            raise FileNotFoundError("Inference data is not a loadable file!")

        raw = np.load(self.experiment_file_path)
        # __getitem__ builds mode "F" images, which Pillow interprets as raw
        # 32-bit floats. Cast before the NaN/inf cleanup so that any other
        # stored dtype (e.g. float64) is not misread, and so that values
        # overflowing float32 are clamped by nan_to_num instead of staying inf.
        frames = np.nan_to_num(np.asarray(raw, dtype=np.float32))
        self.data = torch.tensor(frames)

    def __len__(self):
        """Return the number of entries along the first axis of the data."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the entry at ``index`` as an image, together with the index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, index)
        """
        img = self.data[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="F")

        if self.transform is not None:
            img = self.transform(img)

        return img, index