black formatted files before changes

This commit is contained in:
Jan Kowalczyk
2024-06-28 11:36:46 +02:00
parent d33c6b1e16
commit 71f9662022
40 changed files with 2938 additions and 1260 deletions

View File

@@ -11,8 +11,16 @@ import random
class FashionMNIST_Dataset(TorchvisionDataset):
def __init__(self, root: str, normal_class: int = 0, known_outlier_class: int = 1, n_known_outlier_classes: int = 0,
ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0):
def __init__(
self,
root: str,
normal_class: int = 0,
known_outlier_class: int = 1,
n_known_outlier_classes: int = 0,
ratio_known_normal: float = 0.0,
ratio_known_outlier: float = 0.0,
ratio_pollution: float = 0.0,
):
super().__init__(root)
# Define normal and outlier classes
@@ -27,28 +35,48 @@ class FashionMNIST_Dataset(TorchvisionDataset):
elif n_known_outlier_classes == 1:
self.known_outlier_classes = tuple([known_outlier_class])
else:
self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))
self.known_outlier_classes = tuple(
random.sample(self.outlier_classes, n_known_outlier_classes)
)
# FashionMNIST preprocessing: feature scaling to [0, 1]
transform = transforms.ToTensor()
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
# Get train set
train_set = MyFashionMNIST(root=self.root, train=True, transform=transform, target_transform=target_transform,
download=True)
train_set = MyFashionMNIST(
root=self.root,
train=True,
transform=transform,
target_transform=target_transform,
download=True,
)
# Create semi-supervised setting
idx, _, semi_targets = create_semisupervised_setting(train_set.targets.cpu().data.numpy(), self.normal_classes,
self.outlier_classes, self.known_outlier_classes,
ratio_known_normal, ratio_known_outlier, ratio_pollution)
train_set.semi_targets[idx] = torch.tensor(semi_targets) # set respective semi-supervised labels
idx, _, semi_targets = create_semisupervised_setting(
train_set.targets.cpu().data.numpy(),
self.normal_classes,
self.outlier_classes,
self.known_outlier_classes,
ratio_known_normal,
ratio_known_outlier,
ratio_pollution,
)
train_set.semi_targets[idx] = torch.tensor(
semi_targets
) # set respective semi-supervised labels
# Subset train_set to semi-supervised setup
self.train_set = Subset(train_set, idx)
# Get test set
self.test_set = MyFashionMNIST(root=self.root, train=False, transform=transform,
target_transform=target_transform, download=True)
self.test_set = MyFashionMNIST(
root=self.root,
train=False,
transform=transform,
target_transform=target_transform,
download=True,
)
class MyFashionMNIST(FashionMNIST):
@@ -70,11 +98,15 @@ class MyFashionMNIST(FashionMNIST):
Returns:
tuple: (image, target, semi_target, index)
"""
img, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index])
img, target, semi_target = (
self.data[index],
int(self.targets[index]),
int(self.semi_targets[index]),
)
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)