# mt/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py
import logging
import time
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data.dataloader import DataLoader
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from base.base_trainer import BaseTrainer


class DeepSADTrainer(BaseTrainer):

    def __init__(
        self,
        c,
        eta: float,
        optimizer_name: str = "adam",
        lr: float = 0.001,
        n_epochs: int = 150,
        lr_milestones: tuple = (),
        batch_size: int = 128,
        weight_decay: float = 1e-6,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
    ):
        super().__init__(
            optimizer_name,
            lr,
            n_epochs,
            lr_milestones,
            batch_size,
            weight_decay,
            device,
            n_jobs_dataloader,
        )

        # Deep SAD parameters
        self.c = torch.tensor(c, device=self.device) if c is not None else None
        self.eta = eta
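        # c is the hypersphere center (either passed in, or estimated by
        # init_center_c() on the first call to train()); eta weights labeled
        # samples against unlabeled ones in the Deep SAD loss, with eta > 1
        # putting more emphasis on the labeled data.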

        # Optimization parameters
        self.eps = 1e-6

        # Results
        self.train_time = None
        self.test_auc = None
        self.test_time = None
        self.test_scores = None

        # Attributes for storing sample indices
        self.train_indices = None
        self.train_file_ids = None
        self.train_frame_ids = None
        self.test_indices = None
        self.test_file_ids = None
        self.test_frame_ids = None

    def train(
        self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
    ) -> BaseNet:
        logger = logging.getLogger()

        # Get train data loader (k-fold split if a fold index is given)
        if k_fold_idx is not None:
            train_loader, _ = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            train_loader, _, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(
            net.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1
        )

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info("Initializing center c...")
            self.c = self.init_center_c(train_loader, net)
            logger.info("Center c initialized.")

        # Training
        logger.info("Starting training...")
        start_time = time.time()
        net.train()

        # Lists to collect all indices during training
        all_indices = []
        all_file_ids = []
        all_frame_ids = []

        for epoch in range(self.n_epochs):

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _, semi_targets, idx, (file_id, frame_id) = data
                inputs, semi_targets = (
                    inputs.to(self.device),
                    semi_targets.to(self.device),
                )

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
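                # Deep SAD loss: unlabeled samples (semi_target == 0) are
                # pulled towards c by their squared distance, while labeled
                # samples contribute eta * (dist + eps)^y, i.e. labeled
                # normals (y = +1) are pulled in and labeled anomalies
                # (y = -1) are pushed away via the inverse distance
                # 1 / (dist + eps).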
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
                loss = torch.mean(losses)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info(
                    " LR scheduler: new learning rate is %g"
                    % float(scheduler.get_last_lr()[0])
                )

            # Log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f"| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s "
                f"| Train Loss: {epoch_loss / n_batches:.6f} |"
            )

        self.train_time = time.time() - start_time
        logger.info("Training Time: {:.3f}s".format(self.train_time))
        logger.info("Finished training.")

        # Store all training indices
        self.train_indices = np.array(all_indices)
        self.train_file_ids = np.array(all_file_ids)
        self.train_frame_ids = np.array(all_frame_ids)

        return net

    def infer(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Get inference data loader
        _, _, inference_loader = dataset.loaders(
            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
        )

        # Set device for network
        net = net.to(self.device)

        # Inference
        logger.info("Starting inference...")
        n_batches = 0
        start_time = time.time()
        # Pre-allocate the embedding buffer: one rep_dim-dimensional row per sample
        all_outputs = np.zeros((len(inference_loader.dataset), net.rep_dim), dtype=np.float32)
2024-07-04 15:36:01 +02:00
scores = []
net.eval()
2025-06-10 09:31:28 +02:00
all_indices = []
all_file_ids = []
all_frame_ids = []
2024-07-04 15:36:01 +02:00
        with torch.no_grad():
            for data in inference_loader:
                inputs, idx, (file_id, frame_id) = data
                inputs = inputs.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)

                # Write this batch's embeddings into the pre-allocated buffer
                offset = n_batches * self.batch_size
                all_outputs[offset : offset + len(inputs)] = (
                    outputs.cpu().data.numpy()
                )

                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                scores += dist.cpu().data.numpy().tolist()

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

                n_batches += 1

        self.inference_time = time.time() - start_time

        # Log results
        logger.info("Inference Time: {:.3f}s".format(self.inference_time))
        logger.info("Finished inference.")

        # Store all inference indices
        self.inference_indices = np.array(all_indices)
        self.inference_file_ids = np.array(all_file_ids)
        self.inference_frame_ids = np.array(all_frame_ids)
        self.inference_index_mapping = {
            "indices": self.inference_indices,
            "file_ids": self.inference_file_ids,
            "frame_ids": self.inference_frame_ids,
        }

        return np.array(scores), all_outputs

    def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
        logger = logging.getLogger()

        # Get test data loader (k-fold split if a fold index is given)
        if k_fold_idx is not None:
            _, test_loader = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            _, test_loader, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Testing
        logger.info("Starting testing...")
        epoch_loss = 0.0
        n_batches = 0
        start_time = time.time()

        idx_label_score_exp = []
        idx_label_score_manual = []
        all_labels_exp = []
        all_labels_manual = []
        all_scores = []
        all_idx = []

        # Lists to collect all indices during testing
        all_indices = []
        all_file_ids = []
        all_frame_ids = []

        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in test_loader:
                (
                    inputs,
                    labels_exp_based,
                    labels_manual_based,
                    semi_targets,
                    idx,
                    (file_id, frame_id),
                ) = data

                inputs = inputs.to(self.device)
                labels_exp_based = labels_exp_based.to(self.device)
                labels_manual_based = labels_manual_based.to(self.device)
                semi_targets = semi_targets.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
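                # Same objective as in training; under torch.no_grad() it is
                # computed only so the mean test loss can be reported below.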
                loss = torch.mean(losses)
                scores = dist

                # Save for evaluation
                idx_label_score_exp += list(
                    zip(
                        idx.cpu().data.numpy().tolist(),
                        labels_exp_based.cpu().data.numpy().tolist(),
                        scores.cpu().data.numpy().tolist(),
                    )
                )
                idx_label_score_manual += list(
                    zip(
                        idx.cpu().data.numpy().tolist(),
                        labels_manual_based.cpu().data.numpy().tolist(),
                        scores.cpu().data.numpy().tolist(),
                    )
                )
                all_labels_exp.append(labels_exp_based.cpu().numpy())
                all_labels_manual.append(labels_manual_based.cpu().numpy())
                all_scores.append(scores.cpu().numpy())
                all_idx.append(idx.cpu().numpy())

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

                epoch_loss += loss.item()
                n_batches += 1

        self.test_time = time.time() - start_time

        self.test_scores_exp_based = idx_label_score_exp
        self.test_scores_manual_based = idx_label_score_manual

        # Flatten arrays for counting and evaluation
        all_labels_exp = np.concatenate(all_labels_exp)
        all_labels_manual = np.concatenate(all_labels_manual)
        all_scores = np.concatenate(all_scores)
        all_idx = np.concatenate(all_idx)

        # Count and log label stats for exp_based
        n_exp_normal = np.sum(all_labels_exp == 1)
        n_exp_anomaly = np.sum(all_labels_exp == -1)
        n_exp_unknown = np.sum(all_labels_exp == 0)
        logger.info(
            f"Exp-based labels: normal(1)={n_exp_normal}, "
            f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
        )

        # Count and log label stats for manual_based
        n_manual_normal = np.sum(all_labels_manual == 1)
        n_manual_anomaly = np.sum(all_labels_manual == -1)
        n_manual_unknown = np.sum(all_labels_manual == 0)
        logger.info(
            f"Manual-based labels: normal(1)={n_manual_normal}, "
            f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
        )

        # --- Evaluation for exp_based (only labeled samples) ---
        idxs_exp, labels_exp, scores_exp = zip(*idx_label_score_exp)
        labels_exp = np.array(labels_exp)
        scores_exp = np.array(scores_exp)

        # Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
        valid_mask_exp = labels_exp != 0
        if np.any(valid_mask_exp):
            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
            labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
            scores_exp_valid = scores_exp[valid_mask_exp]
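            # Scores are squared distances to c, so larger means more
            # anomalous; with anomalies mapped to 1 they can be passed to the
            # sklearn ranking metrics directly.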
            self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
            self.test_roc_exp_based = roc_curve(
                labels_exp_binary, scores_exp_valid, drop_intermediate=False
            )
            self.test_prc_exp_based = precision_recall_curve(
                labels_exp_binary, scores_exp_valid
            )
            self.test_ap_exp_based = average_precision_score(
                labels_exp_binary, scores_exp_valid
            )
            logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
            logger.info(
                "Test AUC (exp_based): {:.2f}%".format(100.0 * self.test_auc_exp_based)
            )
        else:
            logger.info("Test AUC (exp_based): N/A (no labeled samples)")
            self.test_auc_exp_based = None
            self.test_roc_exp_based = None
            self.test_prc_exp_based = None
            self.test_ap_exp_based = None

        logger.info("Test Time: {:.3f}s".format(self.test_time))

        # --- Evaluation for manual_based (only labeled samples) ---
        idxs_manual, labels_manual, scores_manual = zip(*idx_label_score_manual)
        labels_manual = np.array(labels_manual)
        scores_manual = np.array(scores_manual)

        # Filter out unknown labels and convert to binary for ROC
        valid_mask_manual = labels_manual != 0
        if np.any(valid_mask_manual):
            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
            labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
            scores_manual_valid = scores_manual[valid_mask_manual]
            self.test_auc_manual_based = roc_auc_score(
                labels_manual_binary, scores_manual_valid
            )
            self.test_roc_manual_based = roc_curve(
                labels_manual_binary, scores_manual_valid, drop_intermediate=False
            )
            self.test_prc_manual_based = precision_recall_curve(
                labels_manual_binary, scores_manual_valid
            )
            self.test_ap_manual_based = average_precision_score(
                labels_manual_binary, scores_manual_valid
            )
            logger.info(
                "Test AUC (manual_based): {:.2f}%".format(
                    100.0 * self.test_auc_manual_based
                )
            )
        else:
            self.test_auc_manual_based = None
            self.test_roc_manual_based = None
            self.test_prc_manual_based = None
            self.test_ap_manual_based = None
            logger.info("Test AUC (manual_based): N/A (no labeled samples)")

        # Store all test indices
        self.test_indices = np.array(all_indices)
        self.test_file_ids = np.array(all_file_ids)
        self.test_frame_ids = np.array(all_frame_ids)

        # Log index statistics
        logger.info(f"Number of test samples: {len(self.test_indices)}")
        logger.info(f"Number of unique files: {len(np.unique(self.test_file_ids))}")

        # Create a mapping of indices to their file/frame information
        self.test_index_mapping = {
            "indices": self.test_indices,
            "file_ids": self.test_file_ids,
            "frame_ids": self.test_frame_ids,
            "exp_based": {
                "indices": np.array(idxs_exp),
                "labels": np.array(labels_exp),
                "scores": np.array(scores_exp),
                "valid_mask": valid_mask_exp,
            },
            "manual_based": {
                "indices": np.array(idxs_manual),
                "labels": np.array(labels_manual),
                "scores": np.array(scores_manual),
                "valid_mask": valid_mask_manual,
            },
        }

        logger.info("Finished testing.")

    def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in train_loader:
                # Get the inputs of the batch
                inputs, _, _, _, _, _ = data
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps

        return c
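

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (added, not part of the original trainer). It
# reproduces the Deep SAD loss used in train()/test() on synthetic tensors to
# show the intended behavior: a labeled anomaly (semi_target == -1) sitting
# close to the center c incurs a much larger loss than a labeled normal
# (semi_target == +1) at the same distance. The rep_dim, eta, and eps values
# are illustrative assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)
    rep_dim, eta, eps = 32, 1.0, 1e-6
    c = torch.zeros(rep_dim)
    outputs = 0.1 * torch.randn(4, rep_dim)       # embeddings near the center
    semi_targets = torch.tensor([0, 1, -1, -1])   # unlabeled, normal, two anomalies
    dist = torch.sum((outputs - c) ** 2, dim=1)
    losses = torch.where(
        semi_targets == 0, dist, eta * ((dist + eps) ** semi_targets.float())
    )
    # Anomalies near c get the inverse-distance penalty, so their loss dominates
    assert losses[2] > losses[1] and losses[3] > losses[1]
    print("per-sample losses:", losses.tolist())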