full upload so as not to lose anything important

Jan Kowalczyk
2025-03-14 18:02:23 +01:00
parent 35fcfb7d5a
commit b824ff7482
33 changed files with 3539 additions and 353 deletions

View File

@@ -0,0 +1,264 @@
import logging
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
class EsmeraSplit_Dataset(TorchvisionDataset):
def __init__(
self,
root: str,
ratio_known_normal: float = 0.0,
ratio_known_outlier: float = 0.0,
ratio_pollution: float = 0.0,
inference: bool = False,
):
super().__init__(root)
# Define normal and outlier classes
self.n_classes = 2 # 0: normal, 1: outlier
self.normal_classes = tuple([0])
self.outlier_classes = tuple([1])
self.inference_set = None
# MNIST preprocessing: feature scaling to [0, 1]
# FIXME: understand MNIST feature scaling and check whether it or other preprocessing is necessary for elpv
transform = transforms.ToTensor()
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
if inference:
self.inference_set = EsmeraSplitInference(
root=self.root,
transform=transform,
)
else:
# Get train set
train_set = EsmeraSplitTraining(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
)
# Create semi-supervised setting
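# Label convention (mirrored by the explicit assignments in the SubTer datasets of this commit):
# semi_target == 0 -> unlabeled, +1 -> known normal, -1 -> known anomalous.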
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
np.array(semi_targets, dtype=np.int8)
) # set respective semi-supervised labels
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Get test set
self.test_set = EsmeraSplitTraining(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
)
def split_array_into_subarrays(array, split_height, split_width):
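"""Tile a batch of 2-D frames into non-overlapping (split_height, split_width) patches.
The trailing two axes are reshaped into
(num_splits_height, split_height, num_splits_width, split_width), the inner axes are
swapped so each patch becomes contiguous, and the result is flattened to shape
(-1, split_height, split_width). For example, an input of shape (N, 16, 2048) with
split_height=16 and split_width=256 yields an output of shape (N * 8, 16, 256).
"""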
original_shape = array.shape
height, width = original_shape[-2], original_shape[-1]
assert height % split_height == 0, "The height is not divisible by the split_height"
assert width % split_width == 0, "The width is not divisible by the split_width"
num_splits_height = height // split_height
num_splits_width = width // split_width
reshaped_array = array.reshape(
-1, num_splits_height, split_height, num_splits_width, split_width
)
transposed_array = reshaped_array.transpose(0, 1, 3, 2, 4)
final_array = transposed_array.reshape(-1, split_height, split_width)
return final_array
class EsmeraSplitTraining(VisionDataset):
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
train=False,
split=0.7,
seed=0,
height=16,
width=256,
):
super(EsmeraSplitTraining, self).__init__(
root, transforms, transform, target_transform
)
experiments_data = []
experiments_targets = []
validation_files = []
experiment_files = []
logger = logging.getLogger()
for experiment_file in Path(root).iterdir():
if experiment_file.is_dir() and experiment_file.name == "validation":
for validation_file in experiment_file.iterdir():
if validation_file.suffix != ".npy":
continue
validation_files.append(validation_file)
if experiment_file.suffix != ".npy":
continue
experiment_files.append(experiment_file)
experiment_data = np.load(experiment_file)
if (
experiment_data.shape[1] % height != 0
or experiment_data.shape[2] % width != 0
):
logger.error(
f"Experiment {experiment_file.name} has shape {experiment_data.shape} which is not divisible by {height}x{width}"
)
experiment_data = split_array_into_subarrays(experiment_data, height, width)
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
experiment_targets = (
np.ones(experiment_data.shape[0], dtype=np.int8)
if "smoke" in experiment_file.name
else np.zeros(experiment_data.shape[0], dtype=np.int8)
)
experiments_data.append(experiment_data)
experiments_targets.append(experiment_targets)
filtered_validation_files = []
for validation_file in validation_files:
validation_file_name = validation_file.name
file_exists_in_experiments = any(
experiment_file.name == validation_file_name
for experiment_file in experiment_files
)
if not file_exists_in_experiments:
filtered_validation_files.append(validation_file)
validation_files = filtered_validation_files
logger.info(
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
)
logger.info(
f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
)
lidar_projections = np.concatenate(experiments_data)
smoke_presence = np.concatenate(experiments_targets)
np.random.seed(seed)
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
shuffled_lidar_projections = lidar_projections[shuffled_indices]
shuffled_smoke_presence = smoke_presence[shuffled_indices]
split_idx = int(split * shuffled_lidar_projections.shape[0])
if train:
self.data = shuffled_lidar_projections[:split_idx]
self.targets = shuffled_smoke_presence[:split_idx]
else:
self.data = shuffled_lidar_projections[split_idx:]
self.targets = shuffled_smoke_presence[split_idx:]
self.data = np.nan_to_num(self.data)
self.data = torch.tensor(self.data)
self.targets = torch.tensor(self.targets, dtype=torch.int8)
self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
"""Override the original method of the MNIST class.
Args:
index (int): Index
Returns:
tuple: (image, target, semi_target, index)
"""
img, target, semi_target = (
self.data[index],
int(self.targets[index]),
int(self.semi_targets[index]),
)
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="F")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, semi_target, index
class EsmeraSplitInference(VisionDataset):
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
):
super(EsmeraSplitInference, self).__init__(root, transforms, transform)
logger = logging.getLogger()
self.experiment_file_path = Path(root)
if not self.experiment_file_path.is_file():
logger.error(
"For inference the data path has to be a single experiment file!"
)
raise Exception("Inference data is not a loadable file!")
self.data = np.load(self.experiment_file_path)
self.data = split_array_into_subarrays(self.data, 16, 256)
self.data = np.nan_to_num(self.data)
self.data = torch.tensor(self.data)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
"""Override the original method of the MNIST class.
Args:
index (int): Index
Returns:
tuple: (image, index)
"""
img = self.data[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="F")
if self.transform is not None:
img = self.transform(img)
return img, index
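A minimal usage sketch, assuming a Deep SAD-style driver script: construct EsmeraSplit_Dataset and wrap its splits in DataLoaders. The root path, label ratios, and batch size below are placeholders.

from torch.utils.data import DataLoader

# Hypothetical configuration; "data/esmera" is a placeholder path.
dataset = EsmeraSplit_Dataset(
    root="data/esmera",
    ratio_known_normal=0.01,
    ratio_known_outlier=0.01,
    ratio_pollution=0.0,
)
train_loader = DataLoader(dataset.train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset.test_set, batch_size=128, shuffle=False)
for img, target, semi_target, index in train_loader:
    pass  # training step goes here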

View File

@@ -18,6 +18,9 @@ def load_dataset(
ratio_pollution: float = 0.0,
random_state=None,
inference: bool = False,
k_fold: bool = False,
num_known_normal: int = 0,
num_known_outlier: int = 0,
):
"""Loads the dataset."""
@@ -46,6 +49,9 @@ def load_dataset(
ratio_known_outlier=ratio_known_outlier,
ratio_pollution=ratio_pollution,
inference=inference,
k_fold=k_fold,
num_known_normal=num_known_normal,
num_known_outlier=num_known_outlier,
)
if dataset_name == "subtersplit":
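A hedged sketch of how the forwarded keyword arguments might look at the dataset level, assuming they are passed through to SubTer_Dataset (defined in the next file); the root path and label counts are placeholders.

# Hypothetical configuration; "data/subter" is a placeholder path.
dataset = SubTer_Dataset(
    root="data/subter",
    ratio_known_normal=0.0,
    ratio_known_outlier=0.0,
    ratio_pollution=0.0,
    inference=False,
    k_fold=False,
    num_known_normal=50,
    num_known_outlier=10,
)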

View File

@@ -1,3 +1,4 @@
import json
import logging
import random
from pathlib import Path
@@ -6,12 +7,13 @@ from typing import Callable, Optional
import numpy as np
import torch
import torchvision.transforms as transforms
from base.torchvision_dataset import TorchvisionDataset
from PIL import Image
from torch.utils.data import Subset
from torch.utils.data.dataset import ConcatDataset
from torchvision.datasets import VisionDataset
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import create_semisupervised_setting
@@ -23,8 +25,22 @@ class SubTer_Dataset(TorchvisionDataset):
ratio_known_outlier: float = 0.0,
ratio_pollution: float = 0.0,
inference: bool = False,
k_fold: bool = False,
num_known_normal: int = 0,
num_known_outlier: int = 0,
only_use_given_semi_targets_for_evaluation: bool = True,
):
super().__init__(root)
if Path(root).is_dir():
with open(Path(root) / "semi_targets.json", "r") as f:
data = json.load(f)
semi_targets_given = {
item["filename"]: (
item["semi_target_begin_frame"],
item["semi_target_end_frame"],
)
for item in data["files"]
}
# Define normal and outlier classes
self.n_classes = 2 # 0: normal, 1: outlier
@@ -43,38 +59,146 @@ class SubTer_Dataset(TorchvisionDataset):
transform=transform,
)
else:
# Get train set
train_set = SubTerTraining(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
)
if k_fold:
# Get train set
data_set = SubTerTraining(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
split=1,
semi_targets_given=semi_targets_given,
)
# Create semi-supervised setting
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
np.array(semi_targets, dtype=np.int8)
) # set respective semi-supervised labels
np.random.seed(0)
semi_targets = data_set.semi_targets.numpy()
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Find indices where semi_targets is -1 (abnormal) or 1 (normal)
normal_indices = np.where(semi_targets == 1)[0]
abnormal_indices = np.where(semi_targets == -1)[0]
# Get test set
self.test_set = SubTerTraining(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
)
# Randomly select the specified number of indices to keep for each category
if len(normal_indices) > num_known_normal:
keep_normal_indices = np.random.choice(
normal_indices, size=num_known_normal, replace=False
)
else:
keep_normal_indices = (
normal_indices # Keep all if there are fewer than required
)
if len(abnormal_indices) > num_known_outlier:
keep_abnormal_indices = np.random.choice(
abnormal_indices, size=num_known_outlier, replace=False
)
else:
keep_abnormal_indices = (
abnormal_indices # Keep all if there are fewer than required
)
# Set all values to 0, then restore only the selected -1 and 1 values
semi_targets[(semi_targets == 1) | (semi_targets == -1)] = 0
semi_targets[keep_normal_indices] = 1
semi_targets[keep_abnormal_indices] = -1
data_set.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
self.data_set = data_set
# # Create semi-supervised setting
# idx, _, semi_targets = create_semisupervised_setting(
# data_set.targets.cpu().data.numpy(),
# self.normal_classes,
# self.outlier_classes,
# self.outlier_classes,
# ratio_known_normal,
# ratio_known_outlier,
# ratio_pollution,
# )
# data_set.semi_targets[idx] = torch.tensor(
# np.array(semi_targets, dtype=np.int8)
# ) # set respective semi-supervised labels
# # Subset data_set to semi-supervised setup
# self.data_set = Subset(data_set, idx)
else:
# Get train set
if only_use_given_semi_targets_for_evaluation:
pass
train_set = SubTerTrainingSelective(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
num_known_outlier=num_known_outlier,
semi_targets_given=semi_targets_given,
)
np.random.seed(0)
semi_targets = train_set.semi_targets.numpy()
# Find indices where semi_targets is -1 (abnormal) or 1 (normal)
normal_indices = np.where(semi_targets == 1)[0]
# Randomly select the specified number of indices to keep for each category
if len(normal_indices) > num_known_normal:
keep_normal_indices = np.random.choice(
normal_indices, size=num_known_normal, replace=False
)
else:
keep_normal_indices = (
normal_indices # Keep all if there are fewer than required
)
# Set all values to 0, then restore only the selected -1 and 1 values
semi_targets[semi_targets == 1] = 0
semi_targets[keep_normal_indices] = 1
train_set.semi_targets = torch.tensor(
semi_targets, dtype=torch.int8
)
self.train_set = train_set
self.test_set = SubTerTrainingSelective(
root=self.root,
transform=transform,
target_transform=target_transform,
num_known_outlier=num_known_outlier,
train=False,
semi_targets_given=semi_targets_given,
)
else:
train_set = SubTerTraining(
root=self.root,
transform=transform,
target_transform=target_transform,
train=True,
semi_targets_given=semi_targets_given,
)
# Create semi-supervised setting
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
np.array(semi_targets, dtype=np.int8)
) # set respective semi-supervised labels
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Get test set
self.test_set = SubTerTraining(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
semi_targets_given=semi_targets_given,
)
class SubTerTraining(VisionDataset):
@@ -87,6 +211,8 @@ class SubTerTraining(VisionDataset):
train=False,
split=0.7,
seed=0,
semi_targets_given=None,
only_use_given_semi_targets_for_evaluation=False,
):
super(SubTerTraining, self).__init__(
root, transforms, transform, target_transform
@@ -94,73 +220,120 @@ class SubTerTraining(VisionDataset):
experiments_data = []
experiments_targets = []
validation_files = []
experiments_semi_targets = []
# validation_files = []
experiment_files = []
experiment_frame_ids = []
experiment_file_ids = []
file_names = {}
for experiment_file in Path(root).iterdir():
if experiment_file.is_dir() and experiment_file.name == "validation":
for validation_file in experiment_file.iterdir():
if validation_file.suffix != ".npy":
continue
validation_files.append(experiment_file)
for file_idx, experiment_file in enumerate(sorted(Path(root).iterdir())):
# if experiment_file.is_dir() and experiment_file.name == "validation":
# for validation_file in experiment_file.iterdir():
# if validation_file.suffix != ".npy":
# continue
# validation_files.append(experiment_file)
if experiment_file.suffix != ".npy":
continue
file_names[file_idx] = experiment_file.name
experiment_files.append(experiment_file)
experiment_data = np.load(experiment_file)
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
experiment_targets = (
np.ones(experiment_data.shape[0], dtype=np.int8)
if "smoke" in experiment_file.name
else np.zeros(experiment_data.shape[0], dtype=np.int8)
)
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
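# Per-frame semi-supervised labels: smoke-free recordings are marked +1 (known normal).
# For smoke recordings, only the annotated frame window from semi_targets_given is
# marked -1 (known anomalous); smoke files absent from the mapping stay 0 (unlabeled),
# and if no mapping is given at all, the whole smoke recording is marked -1.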
experiment_semi_targets = np.zeros(experiment_data.shape[0], dtype=np.int8)
if "smoke" not in experiment_file.name:
experiment_semi_targets = np.ones(
experiment_data.shape[0], dtype=np.int8
)
else:
if semi_targets_given:
if experiment_file.name in semi_targets_given:
semi_target_begin_frame, semi_target_end_frame = (
semi_targets_given[experiment_file.name]
)
experiment_semi_targets[
semi_target_begin_frame:semi_target_end_frame
] = -1
else:
experiment_semi_targets = (
np.ones(experiment_data.shape[0], dtype=np.int8) * -1
)
experiment_file_ids.append(
np.full(experiment_data.shape[0], file_idx, dtype=np.int8)
)
experiment_frame_ids.append(
np.arange(experiment_data.shape[0], dtype=np.int32)
)
experiments_data.append(experiment_data)
experiments_targets.append(experiment_targets)
experiments_semi_targets.append(experiment_semi_targets)
filtered_validation_files = []
for validation_file in validation_files:
validation_file_name = validation_file.name
file_exists_in_experiments = any(
experiment_file.name == validation_file_name
for experiment_file in experiment_files
)
if not file_exists_in_experiments:
filtered_validation_files.append(validation_file)
validation_files = filtered_validation_files
# filtered_validation_files = []
# for validation_file in validation_files:
# validation_file_name = validation_file.name
# file_exists_in_experiments = any(
# experiment_file.name == validation_file_name
# for experiment_file in experiment_files
# )
# if not file_exists_in_experiments:
# filtered_validation_files.append(validation_file)
# validation_files = filtered_validation_files
logger = logging.getLogger()
logger.info(
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
)
logger.info(
f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
)
# logger.info(
# f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
# )
lidar_projections = np.concatenate(experiments_data)
smoke_presence = np.concatenate(experiments_targets)
semi_targets = np.concatenate(experiments_semi_targets)
file_ids = np.concatenate(experiment_file_ids)
frame_ids = np.concatenate(experiment_frame_ids)
self.file_names = file_names
np.random.seed(seed)
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
shuffled_lidar_projections = lidar_projections[shuffled_indices]
shuffled_smoke_presence = smoke_presence[shuffled_indices]
shuffled_file_ids = file_ids[shuffled_indices]
shuffled_frame_ids = frame_ids[shuffled_indices]
shuffled_semis = semi_targets[shuffled_indices]
split_idx = int(split * shuffled_lidar_projections.shape[0])
if train:
self.data = shuffled_lidar_projections[:split_idx]
self.targets = shuffled_smoke_presence[:split_idx]
semi_targets = shuffled_semis[:split_idx]
self.shuffled_file_ids = shuffled_file_ids[:split_idx]
self.shuffled_frame_ids = shuffled_frame_ids[:split_idx]
else:
self.data = shuffled_lidar_projections[split_idx:]
self.targets = shuffled_smoke_presence[split_idx:]
semi_targets = shuffled_semis[split_idx:]
self.shuffled_file_ids = shuffled_file_ids[split_idx:]
self.shuffled_frame_ids = shuffled_frame_ids[split_idx:]
self.data = np.nan_to_num(self.data)
self.data = torch.tensor(self.data)
self.targets = torch.tensor(self.targets, dtype=torch.int8)
self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
if semi_targets_given is not None:
self.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
else:
self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
def __len__(self):
return len(self.data)
@@ -173,10 +346,12 @@ class SubTerTraining(VisionDataset):
Returns:
tuple: (image, target, semi_target, index)
"""
img, target, semi_target = (
img, target, semi_target, file_id, frame_id = (
self.data[index],
int(self.targets[index]),
int(self.semi_targets[index]),
int(self.shuffled_file_ids[index]),
int(self.shuffled_frame_ids[index]),
)
# doing this so that it is consistent with all other datasets
@@ -189,7 +364,10 @@ class SubTerTraining(VisionDataset):
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, semi_target, index
return img, target, semi_target, index, (file_id, frame_id)
def get_file_name_from_idx(self, idx: int):
return self.file_names[idx]
class SubTerInference(VisionDataset):
@@ -235,3 +413,191 @@ class SubTerInference(VisionDataset):
img = self.transform(img)
return img, index
class SubTerTrainingSelective(VisionDataset):
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
train=False,
num_known_outlier=0,
seed=0,
semi_targets_given=None,
ratio_test_normal_to_anomalous=3,
):
super(SubTerTrainingSelective, self).__init__(
root, transforms, transform, target_transform
)
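# Selective split strategy: keep num_known_outlier labeled anomalous frames for
# training, move all remaining labeled anomalous frames into the test set together
# with ratio_test_normal_to_anomalous times as many normal frames, and train on
# everything else.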
logger = logging.getLogger()
if semi_targets_given is None:
raise ValueError(
"semi_targets_given must be provided for selective training"
)
experiments_data = []
experiments_targets = []
experiments_semi_targets = []
# validation_files = []
experiment_files = []
experiment_frame_ids = []
experiment_file_ids = []
file_names = {}
for file_idx, experiment_file in enumerate(sorted(Path(root).iterdir())):
if experiment_file.suffix != ".npy":
continue
file_names[file_idx] = experiment_file.name
experiment_files.append(experiment_file)
experiment_data = np.load(experiment_file)
experiment_targets = (
np.ones(experiment_data.shape[0], dtype=np.int8)
if "smoke" in experiment_file.name
else np.zeros(experiment_data.shape[0], dtype=np.int8)
)
experiment_semi_targets = np.zeros(experiment_data.shape[0], dtype=np.int8)
if "smoke" not in experiment_file.name:
experiment_semi_targets = np.ones(
experiment_data.shape[0], dtype=np.int8
)
elif experiment_file.name in semi_targets_given:
semi_target_begin_frame, semi_target_end_frame = semi_targets_given[
experiment_file.name
]
experiment_semi_targets[
semi_target_begin_frame:semi_target_end_frame
] = -1
else:
raise ValueError(
"smoke experiment not in given semi_targets. required for selective training"
)
experiment_file_ids.append(
np.full(experiment_data.shape[0], file_idx, dtype=np.int8)
)
experiment_frame_ids.append(
np.arange(experiment_data.shape[0], dtype=np.int32)
)
experiments_data.append(experiment_data)
experiments_targets.append(experiment_targets)
experiments_semi_targets.append(experiment_semi_targets)
logger.info(
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
)
lidar_projections = np.concatenate(experiments_data)
smoke_presence = np.concatenate(experiments_targets)
semi_targets = np.concatenate(experiments_semi_targets)
file_ids = np.concatenate(experiment_file_ids)
frame_ids = np.concatenate(experiment_frame_ids)
self.file_names = file_names
np.random.seed(seed)
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
shuffled_lidar_projections = lidar_projections[shuffled_indices]
shuffled_smoke_presence = smoke_presence[shuffled_indices]
shuffled_file_ids = file_ids[shuffled_indices]
shuffled_frame_ids = frame_ids[shuffled_indices]
shuffled_semis = semi_targets[shuffled_indices]
# check that there are enough known outlier samples to keep for training
outlier_indices = np.where(shuffled_semis == -1)[0]
normal_indices = np.where(shuffled_semis == 1)[0]
if len(outlier_indices) < num_known_outlier:
raise ValueError(
f"Not enough known outliers in dataset. Required: {num_known_outlier}, Found: {len(outlier_indices)}"
)
# randomly select known normal and outlier samples
keep_outlier_indices = np.random.choice(
outlier_indices, size=num_known_outlier, replace=False
)
# put the outliers that are not kept into the test set, together with ratio_test_normal_to_anomalous times as many normal samples
test_outlier_indices = np.setdiff1d(outlier_indices, keep_outlier_indices)
num_test_outliers = len(test_outlier_indices)
test_normal_indices = np.random.choice(
normal_indices,
size=num_test_outliers * ratio_test_normal_to_anomalous,
replace=False,
)
# combine test indices
test_indices = np.concatenate([test_outlier_indices, test_normal_indices])
# training indices are the rest
train_indices = np.setdiff1d(np.arange(len(shuffled_semis)), test_indices)
if train:
self.data = shuffled_lidar_projections[train_indices]
self.targets = shuffled_smoke_presence[train_indices]
semi_targets = shuffled_semis[train_indices]
self.shuffled_file_ids = shuffled_file_ids[train_indices]
self.shuffled_frame_ids = shuffled_frame_ids[train_indices]
else:
self.data = shuffled_lidar_projections[test_indices]
self.targets = shuffled_smoke_presence[test_indices]
semi_targets = shuffled_semis[test_indices]
self.shuffled_file_ids = shuffled_file_ids[test_indices]
self.shuffled_frame_ids = shuffled_frame_ids[test_indices]
self.data = np.nan_to_num(self.data)
self.data = torch.tensor(self.data)
self.targets = torch.tensor(self.targets, dtype=torch.int8)
self.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
# log some stats to ensure the data is loaded correctly
if train:
logger.info(
f"Training set: {len(self.data)} samples, {sum(self.semi_targets == -1)} semi-supervised samples"
)
else:
logger.info(
f"Test set: {len(self.data)} samples, {sum(self.semi_targets == -1)} semi-supervised samples"
)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
"""Override the original method of the MNIST class.
Args:
index (int): Index
Returns:
tuple: (image, target, semi_target, index)
"""
img, target, semi_target, file_id, frame_id = (
self.data[index],
int(self.targets[index]),
int(self.semi_targets[index]),
int(self.shuffled_file_ids[index]),
int(self.shuffled_frame_ids[index]),
)
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="F")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, semi_target, index, (file_id, frame_id)
def get_file_name_from_idx(self, idx: int):
return self.file_names[idx]
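A minimal sketch of how the extended return tuple can be used to trace a test sample back to its source recording and frame; the root path, batch size, and label counts are placeholders, and the data directory is assumed to contain the .npy recordings plus semi_targets.json.

from torch.utils.data import DataLoader

# Hypothetical configuration; "data/subter" is a placeholder path.
dataset = SubTer_Dataset(root="data/subter", num_known_normal=50, num_known_outlier=10)
loader = DataLoader(dataset.test_set, batch_size=32, shuffle=False)
for img, target, semi_target, index, (file_ids, frame_ids) in loader:
    for f_id, fr_id in zip(file_ids.tolist(), frame_ids.tolist()):
        # Resolve the originating recording for this frame.
        name = dataset.test_set.get_file_name_from_idx(f_id)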