implemented inference
This commit is contained in:
@@ -16,6 +16,7 @@ def load_dataset(
|
||||
ratio_known_outlier: float = 0.0,
|
||||
ratio_pollution: float = 0.0,
|
||||
random_state=None,
|
||||
inference: bool = False,
|
||||
):
|
||||
"""Loads the dataset."""
|
||||
|
||||
@@ -42,6 +43,7 @@ def load_dataset(
|
||||
ratio_known_normal=ratio_known_normal,
|
||||
ratio_known_outlier=ratio_known_outlier,
|
||||
ratio_pollution=ratio_pollution,
|
||||
inference=inference,
|
||||
)
|
||||
|
||||
if dataset_name == "elpv":
|
||||
|
||||
@@ -6,6 +6,7 @@ from base.torchvision_dataset import TorchvisionDataset
|
||||
from .preprocessing import create_semisupervised_setting
|
||||
from typing import Callable, Optional
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import random
|
||||
@@ -22,6 +23,7 @@ class SubTer_Dataset(TorchvisionDataset):
|
||||
ratio_known_normal: float = 0.0,
|
||||
ratio_known_outlier: float = 0.0,
|
||||
ratio_pollution: float = 0.0,
|
||||
inference: bool = False,
|
||||
):
|
||||
super().__init__(root)
|
||||
|
||||
@@ -35,41 +37,47 @@ class SubTer_Dataset(TorchvisionDataset):
|
||||
transform = transforms.ToTensor()
|
||||
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
|
||||
|
||||
# Get train set
|
||||
train_set = MySubTer(
|
||||
root=self.root,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
train=True,
|
||||
)
|
||||
if inference:
|
||||
self.inference_set = SubTerInference(
|
||||
root=self.root,
|
||||
transform=transform,
|
||||
)
|
||||
else:
|
||||
# Get train set
|
||||
train_set = SubTerTraining(
|
||||
root=self.root,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# Create semi-supervised setting
|
||||
idx, _, semi_targets = create_semisupervised_setting(
|
||||
train_set.targets.cpu().data.numpy(),
|
||||
self.normal_classes,
|
||||
self.outlier_classes,
|
||||
self.outlier_classes,
|
||||
ratio_known_normal,
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
)
|
||||
train_set.semi_targets[idx] = torch.tensor(
|
||||
np.array(semi_targets, dtype=np.int8)
|
||||
) # set respective semi-supervised labels
|
||||
# Create semi-supervised setting
|
||||
idx, _, semi_targets = create_semisupervised_setting(
|
||||
train_set.targets.cpu().data.numpy(),
|
||||
self.normal_classes,
|
||||
self.outlier_classes,
|
||||
self.outlier_classes,
|
||||
ratio_known_normal,
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
)
|
||||
train_set.semi_targets[idx] = torch.tensor(
|
||||
np.array(semi_targets, dtype=np.int8)
|
||||
) # set respective semi-supervised labels
|
||||
|
||||
# Subset train_set to semi-supervised setup
|
||||
self.train_set = Subset(train_set, idx)
|
||||
# Subset train_set to semi-supervised setup
|
||||
self.train_set = Subset(train_set, idx)
|
||||
|
||||
# Get test set
|
||||
self.test_set = MySubTer(
|
||||
root=self.root,
|
||||
train=False,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
)
|
||||
# Get test set
|
||||
self.test_set = SubTerTraining(
|
||||
root=self.root,
|
||||
train=False,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
)
|
||||
|
||||
|
||||
class MySubTer(VisionDataset):
|
||||
class SubTerTraining(VisionDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -81,7 +89,9 @@ class MySubTer(VisionDataset):
|
||||
split=0.7,
|
||||
seed=0,
|
||||
):
|
||||
super(MySubTer, self).__init__(root, transforms, transform, target_transform)
|
||||
super(SubTerTraining, self).__init__(
|
||||
root, transforms, transform, target_transform
|
||||
)
|
||||
|
||||
experiments_data = []
|
||||
experiments_targets = []
|
||||
@@ -153,3 +163,49 @@ class MySubTer(VisionDataset):
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target, semi_target, index
|
||||
|
||||
|
||||
class SubTerInference(VisionDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
transforms: Optional[Callable] = None,
|
||||
transform: Optional[Callable] = None,
|
||||
):
|
||||
super(SubTerInference, self).__init__(root, transforms, transform)
|
||||
logger = logging.getLogger()
|
||||
|
||||
self.experiment_file_path = Path(root)
|
||||
|
||||
if not self.experiment_file_path.is_file():
|
||||
logger.error(
|
||||
"For inference the data path has to be a single experiment file!"
|
||||
)
|
||||
raise Exception("Inference data is not a loadable file!")
|
||||
|
||||
self.data = np.load(self.experiment_file_path)
|
||||
self.data = np.nan_to_num(self.data)
|
||||
self.data = torch.tensor(self.data)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Override the original method of the MNIST class.
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, index)
|
||||
"""
|
||||
img = self.data[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
img = Image.fromarray(img.numpy(), mode="F")
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
return img, index
|
||||
|
||||
Reference in New Issue
Block a user