mt/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py

from base.base_trainer import BaseTrainer
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from torch.utils.data.dataloader import DataLoader
from sklearn.metrics import roc_auc_score
import logging
import time
import torch
import torch.optim as optim
import numpy as np


class DeepSADTrainer(BaseTrainer):
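    """Trainer for Deep SAD (Deep Semi-supervised Anomaly Detection).

    Minimizes the squared distance of network outputs to the hypersphere
    center c for unlabeled samples (semi_targets == 0) and labeled normal
    samples (semi_targets == +1), and maximizes it for labeled anomalies
    (semi_targets == -1) via an inverse distance term, with eta weighting
    the labeled terms.
    """
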
def __init__(
self,
c,
eta: float,
optimizer_name: str = "adam",
lr: float = 0.001,
n_epochs: int = 150,
lr_milestones: tuple = (),
batch_size: int = 128,
weight_decay: float = 1e-6,
device: str = "cuda",
n_jobs_dataloader: int = 0,
):
super().__init__(
optimizer_name,
lr,
n_epochs,
lr_milestones,
batch_size,
weight_decay,
device,
n_jobs_dataloader,
)

        # Deep SAD parameters
self.c = torch.tensor(c, device=self.device) if c is not None else None
        self.eta = eta  # weight of the loss term for labeled samples

        # Optimization parameters
        self.eps = 1e-6  # guards against division by zero in the labeled-anomaly term

        # Results
self.train_time = None
self.test_auc = None
self.test_time = None
        self.test_scores = None

    def train(self, dataset: BaseADDataset, net: BaseNet):
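        """Train net on the train split of dataset and return the trained network."""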
logger = logging.getLogger()
# Get train data loader
train_loader, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
optimizer = optim.Adam(
net.parameters(), lr=self.lr, weight_decay=self.weight_decay
)

        # Set learning rate scheduler
scheduler = optim.lr_scheduler.MultiStepLR(
optimizer, milestones=self.lr_milestones, gamma=0.1
)

        # Initialize hypersphere center c (if c not loaded)
if self.c is None:
logger.info("Initializing center c...")
self.c = self.init_center_c(train_loader, net)
logger.info("Center c initialized.")

        # Training
logger.info("Starting training...")
start_time = time.time()
net.train()
        for epoch in range(self.n_epochs):

            if epoch in self.lr_milestones:
                logger.info(
                    "  LR scheduler: new learning rate is %g"
                    % float(scheduler.get_last_lr()[0])
                )

epoch_loss = 0.0
n_batches = 0
epoch_start_time = time.time()
for data in train_loader:
inputs, _, semi_targets, _ = data
inputs, semi_targets = inputs.to(self.device), semi_targets.to(
self.device
)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
outputs = net(inputs)
dist = torch.sum((outputs - self.c) ** 2, dim=1)
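                # Deep SAD loss: plain squared distance for unlabeled samples;
                # for labeled samples, dist ** semi_targets weighted by eta
                # (+1 pulls outputs toward c, -1 pushes them away)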
losses = torch.where(
semi_targets == 0,
dist,
self.eta * ((dist + self.eps) ** semi_targets.float()),
)
loss = torch.mean(losses)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
                n_batches += 1

            # Step the LR scheduler once per epoch, after the optimizer updates
            scheduler.step()

# log epoch statistics
epoch_train_time = time.time() - epoch_start_time
logger.info(
f"| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s "
f"| Train Loss: {epoch_loss / n_batches:.6f} |"
)

        self.train_time = time.time() - start_time
logger.info("Training Time: {:.3f}s".format(self.train_time))
logger.info("Finished training.")

        return net

    def test(self, dataset: BaseADDataset, net: BaseNet):
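        """Evaluate net on the test split of dataset; stores test loss, AUC, time, and per-sample scores."""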
logger = logging.getLogger()
# Get test data loader
_, test_loader = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)

        # Set device for network
        net = net.to(self.device)

        # Testing
logger.info("Starting testing...")
epoch_loss = 0.0
n_batches = 0
start_time = time.time()
idx_label_score = []
net.eval()
with torch.no_grad():
for data in test_loader:
inputs, labels, semi_targets, idx = data
inputs = inputs.to(self.device)
labels = labels.to(self.device)
semi_targets = semi_targets.to(self.device)
idx = idx.to(self.device)
outputs = net(inputs)
dist = torch.sum((outputs - self.c) ** 2, dim=1)
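                # Same Deep SAD loss as in training; the squared distance to c
                # serves as the anomaly score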
losses = torch.where(
semi_targets == 0,
dist,
self.eta * ((dist + self.eps) ** semi_targets.float()),
)
loss = torch.mean(losses)
scores = dist
# Save triples of (idx, label, score) in a list
                idx_label_score += list(
                    zip(
                        idx.cpu().numpy().tolist(),
                        labels.cpu().numpy().tolist(),
                        scores.cpu().numpy().tolist(),
                    )
                )

                epoch_loss += loss.item()
                n_batches += 1

        self.test_time = time.time() - start_time
        self.test_scores = idx_label_score

        # Compute AUC
_, labels, scores = zip(*idx_label_score)
labels = np.array(labels)
scores = np.array(scores)
        self.test_auc = roc_auc_score(labels, scores)

        # Log results
logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
logger.info("Test AUC: {:.2f}%".format(100.0 * self.test_auc))
logger.info("Test Time: {:.3f}s".format(self.test_time))
        logger.info("Finished testing.")

    def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
"""Initialize hypersphere center c as the mean from an initial forward pass on the data."""
n_samples = 0
c = torch.zeros(net.rep_dim, device=self.device)
net.eval()
with torch.no_grad():
for data in train_loader:
# get the inputs of the batch
inputs, _, _, _ = data
inputs = inputs.to(self.device)
outputs = net(inputs)
n_samples += outputs.shape[0]
c += torch.sum(outputs, dim=0)
c /= n_samples
# If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
c[(abs(c) < eps) & (c < 0)] = -eps
c[(abs(c) < eps) & (c > 0)] = eps
return c
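

# A hedged usage sketch (not part of the original file): assumes `dataset` is a
# concrete BaseADDataset and `net` a matching BaseNet with the correct rep_dim,
# constructed elsewhere in this repository; hyperparameter values are illustrative.
#
#   trainer = DeepSADTrainer(c=None, eta=1.0, lr=0.001, n_epochs=50, device="cuda")
#   net = trainer.train(dataset, net)  # c is initialized on the first call
#   trainer.test(dataset, net)         # fills trainer.test_auc and trainer.test_scores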