black formatted files before changes

This commit is contained in:
Jan Kowalczyk
2024-06-28 11:36:46 +02:00
parent d33c6b1e16
commit 71f9662022
40 changed files with 2938 additions and 1260 deletions

View File

@@ -10,15 +10,24 @@ class BaseADDataset(ABC):
self.root = root # root path to data
self.n_classes = 2 # 0: normal, 1: outlier
self.normal_classes = None # tuple with original class labels that define the normal class
self.outlier_classes = None # tuple with original class labels that define the outlier class
self.normal_classes = (
None # tuple with original class labels that define the normal class
)
self.outlier_classes = (
None # tuple with original class labels that define the outlier class
)
self.train_set = None # must be of type torch.utils.data.Dataset
self.test_set = None # must be of type torch.utils.data.Dataset
@abstractmethod
def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
DataLoader, DataLoader):
# Abstract hook (see the @abstractmethod decorator above): concrete dataset
# classes supply the train/test loaders for their train_set and test_set.
# NOTE(review): the return annotation below is a plain tuple of classes,
# not a typing construct such as Tuple[DataLoader, DataLoader]; it is
# evaluated at definition time but is not a standard tuple type hint —
# confirm whether static type checking is expected here.
def loaders(
self,
batch_size: int,
shuffle_train=True,
shuffle_test=False,
num_workers: int = 0,
) -> (DataLoader, DataLoader):
"""Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set."""
# No base implementation; the body is intentionally empty.
pass

View File

@@ -22,5 +22,5 @@ class BaseNet(nn.Module):
"""Network summary."""
net_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in net_parameters])
self.logger.info('Trainable parameters: {}'.format(params))
self.logger.info("Trainable parameters: {}".format(params))
self.logger.info(self)

View File

@@ -6,8 +6,17 @@ from .base_net import BaseNet
class BaseTrainer(ABC):
"""Trainer base class."""
def __init__(self, optimizer_name: str, lr: float, n_epochs: int, lr_milestones: tuple, batch_size: int,
weight_decay: float, device: str, n_jobs_dataloader: int):
def __init__(
self,
optimizer_name: str,
lr: float,
n_epochs: int,
lr_milestones: tuple,
batch_size: int,
weight_decay: float,
device: str,
n_jobs_dataloader: int,
):
super().__init__()
self.optimizer_name = optimizer_name
self.lr = lr

View File

@@ -19,15 +19,22 @@ class ODDSDataset(Dataset):
"""
urls = {
'arrhythmia': 'https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1',
'cardio': 'https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1',
'satellite': 'https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1',
'satimage-2': 'https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1',
'shuttle': 'https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1',
'thyroid': 'https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1'
"arrhythmia": "https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1",
"cardio": "https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1",
"satellite": "https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1",
"satimage-2": "https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1",
"shuttle": "https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1",
"thyroid": "https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1",
}
def __init__(self, root: str, dataset_name: str, train=True, random_state=None, download=False):
def __init__(
self,
root: str,
dataset_name: str,
train=True,
random_state=None,
download=False,
):
super(Dataset, self).__init__()
self.classes = [0, 1]
@@ -37,25 +44,25 @@ class ODDSDataset(Dataset):
self.root = Path(root)
self.dataset_name = dataset_name
self.train = train # training set or test set
self.file_name = self.dataset_name + '.mat'
self.file_name = self.dataset_name + ".mat"
self.data_file = self.root / self.file_name
if download:
self.download()
mat = loadmat(self.data_file)
X = mat['X']
y = mat['y'].ravel()
X = mat["X"]
y = mat["y"].ravel()
idx_norm = y == 0
idx_out = y == 1
# 60% data for training and 40% for testing; keep outlier ratio
X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(X[idx_norm], y[idx_norm],
test_size=0.4,
random_state=random_state)
X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(X[idx_out], y[idx_out],
test_size=0.4,
random_state=random_state)
X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(
X[idx_norm], y[idx_norm], test_size=0.4, random_state=random_state
)
X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(
X[idx_out], y[idx_out], test_size=0.4, random_state=random_state
)
X_train = np.concatenate((X_train_norm, X_train_out))
X_test = np.concatenate((X_test_norm, X_test_out))
y_train = np.concatenate((y_train_norm, y_train_out))
@@ -88,7 +95,11 @@ class ODDSDataset(Dataset):
Returns:
tuple: (sample, target, semi_target, index)
"""
sample, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index])
sample, target, semi_target = (
self.data[index],
int(self.targets[index]),
int(self.semi_targets[index]),
)
return sample, target, semi_target, index
@@ -107,4 +118,4 @@ class ODDSDataset(Dataset):
# download file
download_url(self.urls[self.dataset_name], self.root, self.file_name)
print('Done!')
print("Done!")

View File

@@ -8,10 +8,25 @@ class TorchvisionDataset(BaseADDataset):
def __init__(self, root: str):
super().__init__(root)
def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
DataLoader, DataLoader):
train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train,
num_workers=num_workers, drop_last=True)
test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test,
num_workers=num_workers, drop_last=False)
# Build one DataLoader over self.train_set and one over self.test_set and
# return them as a (train_loader, test_loader) pair. batch_size and
# num_workers are shared by both loaders; shuffling is controlled
# separately through shuffle_train and shuffle_test.
def loaders(
self,
batch_size: int,
shuffle_train=True,
shuffle_test=False,
num_workers: int = 0,
) -> (DataLoader, DataLoader):
# Training loader: drop_last=True is passed, so a final batch smaller than
# batch_size is discarded.
train_loader = DataLoader(
dataset=self.train_set,
batch_size=batch_size,
shuffle=shuffle_train,
num_workers=num_workers,
drop_last=True,
)
# Test loader: drop_last=False keeps the final partial batch, so every
# test sample is yielded.
test_loader = DataLoader(
dataset=self.test_set,
batch_size=batch_size,
shuffle=shuffle_test,
num_workers=num_workers,
drop_last=False,
)
return train_loader, test_loader