import logging
import time
from typing import Optional

import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data.dataloader import DataLoader

from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from base.base_trainer import BaseTrainer


class DeepSADTrainer(BaseTrainer):
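    """Trainer for Deep SAD (deep semi-supervised anomaly detection).

    Learns a representation that pulls unlabeled and labeled-normal samples
    toward a hypersphere center ``c`` while pushing labeled anomalies away;
    ``eta`` weights the labeled term (see the loss in ``train``).
    """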

    def __init__(
        self,
        c,
        eta: float,
        optimizer_name: str = "adam",
        lr: float = 0.001,
        n_epochs: int = 150,
        lr_milestones: tuple = (),
        batch_size: int = 128,
        weight_decay: float = 1e-6,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
    ):
        super().__init__(
            optimizer_name,
            lr,
            n_epochs,
            lr_milestones,
            batch_size,
            weight_decay,
            device,
            n_jobs_dataloader,
        )

        # Deep SAD parameters
        self.c = torch.tensor(c, device=self.device) if c is not None else None
        self.eta = eta  # weight on the labeled-sample term of the loss

        # Optimization parameters
        self.eps = 1e-6

        # Results
        self.train_time = None
        self.test_auc = None
        self.test_time = None
        self.test_scores = None
        self.test_roc = None
        self.test_prc = None
        self.test_ap = None
        self.inference_time = None

    def train(
        self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: Optional[int] = None
    ) -> BaseNet:
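        """Train `net` with the Deep SAD loss; uses fold `k_fold_idx` when given."""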
        logger = logging.getLogger()

        # Get train data loader
        if k_fold_idx is not None:
            train_loader, _ = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            train_loader, _, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(
            net.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1
        )

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info("Initializing center c...")
            self.c = self.init_center_c(train_loader, net)
            logger.info("Center c initialized.")

        # Training
        logger.info("Starting training...")
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):
            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, semi_targets, _, _ = data
                inputs, semi_targets = (
                    inputs.to(self.device),
                    semi_targets.to(self.device),
                )

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
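                # Deep SAD loss: unlabeled samples (semi_target == 0) contribute
                # dist directly; labeled samples are weighted by eta, and the
                # exponent (+1 for labeled normal, -1 for labeled anomalous)
                # turns the anomaly term into ~1 / dist, pushing anomalies
                # away from the center c.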
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
                loss = torch.mean(losses)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info(
                    "  LR scheduler: new learning rate is %g"
                    % float(scheduler.get_last_lr()[0])
                )

            # Log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f"| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s "
                f"| Train Loss: {epoch_loss / n_batches:.6f} |"
            )

        self.train_time = time.time() - start_time
        logger.info("Training Time: {:.3f}s".format(self.train_time))
        logger.info("Finished training.")

        return net

    def infer(self, dataset: BaseADDataset, net: BaseNet):
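        """Run a forward pass over the inference split; returns (scores, embeddings)."""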
        logger = logging.getLogger()

        # Get inference data loader
        _, _, inference_loader = dataset.loaders(
            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
        )

        # Set device for network
        net = net.to(self.device)

        # Inference
        logger.info("Starting inference...")
        n_batches = 0
        start_time = time.time()
        # Preallocate one row per sample. The source hard-coded a width of 1024;
        # net.rep_dim (the representation size, as used in init_center_c) is the
        # same value in that setup and avoids the magic number.
        all_outputs = np.zeros(
            (len(inference_loader.dataset), net.rep_dim), dtype=np.float32
        )
        scores = []
        net.eval()
        with torch.no_grad():
            for data in inference_loader:
                inputs, idx = data

                inputs = inputs.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)
                # Offset assumes a non-shuffled loader, so batch i starts at
                # i * batch_size; only the last batch may be short.
                all_idx = n_batches * self.batch_size
                all_outputs[all_idx : all_idx + len(inputs)] = (
                    outputs.cpu().data.numpy()
                )
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                scores += dist.cpu().data.numpy().tolist()

                n_batches += 1

        self.inference_time = time.time() - start_time

        # Log results
        logger.info("Inference Time: {:.3f}s".format(self.inference_time))
        logger.info("Finished inference.")

        return np.array(scores), all_outputs

    def test(
        self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: Optional[int] = None
    ):
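        """Evaluate `net` on the test split (or fold `k_fold_idx`); records loss, AUC, ROC, PRC, and AP."""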
        logger = logging.getLogger()

        # Get test data loader
        if k_fold_idx is not None:
            _, test_loader = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            _, test_loader, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Testing
        logger.info("Starting testing...")
        epoch_loss = 0.0
        n_batches = 0
        start_time = time.time()
        idx_label_score = []
        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in test_loader:
                inputs, labels, semi_targets, idx, _ = data

                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                semi_targets = semi_targets.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                # Same Deep SAD loss as in train(), computed here for monitoring.
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
                loss = torch.mean(losses)
                scores = dist

                # Save triples of (idx, label, score) in a list
                idx_label_score += list(
                    zip(
                        idx.cpu().data.numpy().tolist(),
                        labels.cpu().data.numpy().tolist(),
                        scores.cpu().data.numpy().tolist(),
                    )
                )

                epoch_loss += loss.item()
                n_batches += 1

        self.test_time = time.time() - start_time
        self.test_scores = idx_label_score

        # Compute AUC, ROC/PR curves, and average precision
        _, labels, scores = zip(*idx_label_score)
        labels = np.array(labels)
        scores = np.array(scores)
        self.test_auc = roc_auc_score(labels, scores)
        self.test_roc = roc_curve(labels, scores)
        self.test_prc = precision_recall_curve(labels, scores)
        self.test_ap = average_precision_score(labels, scores)

        # Log results
        logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
        logger.info("Test AUC: {:.2f}%".format(100.0 * self.test_auc))
        logger.info("Test Time: {:.3f}s".format(self.test_time))
        logger.info("Finished testing.")

    def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps: float = 0.1):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in train_loader:
                # Get the inputs of the batch
                inputs, _, _, _, _ = data
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps

        return c
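

# Example usage: a minimal sketch, assuming concrete BaseADDataset and BaseNet
# implementations (`dataset` and `net` below are hypothetical stand-ins, and
# the hyperparameter values are illustrative only):
#
#   trainer = DeepSADTrainer(c=None, eta=1.0, lr=1e-4, n_epochs=50, device="cuda")
#   net = trainer.train(dataset, net)   # initializes center c if c is None
#   trainer.test(dataset, net)          # sets test_auc, test_roc, test_prc, test_ap
#   scores, embeddings = trainer.infer(dataset, net)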