import logging
import time

import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data.dataloader import DataLoader

from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from base.base_trainer import BaseTrainer


class DeepSADTrainer(BaseTrainer):
    """Trainer for Deep SAD (Deep Semi-supervised Anomaly Detection)."""

    def __init__(
        self,
        c,
        eta: float,
        optimizer_name: str = "adam",
        lr: float = 0.001,
        n_epochs: int = 150,
        lr_milestones: tuple = (),
        batch_size: int = 128,
        weight_decay: float = 1e-6,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
    ):
        super().__init__(
            optimizer_name,
            lr,
            n_epochs,
            lr_milestones,
            batch_size,
            weight_decay,
            device,
            n_jobs_dataloader,
        )

        # Deep SAD parameters
        self.c = torch.tensor(c, device=self.device) if c is not None else None
        self.eta = eta

        # Optimization parameters
        self.eps = 1e-6

        # Results
        self.train_time = None
        self.test_auc = None
        self.test_time = None
        self.test_scores = None
        self.inference_time = None

        # Attributes for storing sample indices
        self.train_indices = None
        self.train_file_ids = None
        self.train_frame_ids = None
        self.test_indices = None
        self.test_file_ids = None
        self.test_frame_ids = None

    def train(
        self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
    ) -> BaseNet:
        """Train the network; if k_fold_idx is given, train on that fold only."""
        logger = logging.getLogger()

        # Get train data loader
        if k_fold_idx is not None:
            train_loader, _ = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            train_loader, _, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(
            net.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1
        )

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info("Initializing center c...")
            self.c = self.init_center_c(train_loader, net)
            logger.info("Center c initialized.")

        # Training
        logger.info("Starting training...")
        start_time = time.time()
        net.train()

        # Lists to collect all indices during training
        # (note: these accumulate across epochs, so each sample appears once
        # per epoch in the stored arrays)
        all_indices = []
        all_file_ids = []
        all_frame_ids = []

        for epoch in range(self.n_epochs):
            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _, semi_targets, idx, (file_id, frame_id) = data
                inputs, semi_targets = (
                    inputs.to(self.device),
                    semi_targets.to(self.device),
                )

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                # Deep SAD loss: unlabeled samples (semi_target == 0) are pulled
                # toward c; labeled samples are weighted by eta, with
                # semi_target == +1 (normal) penalizing large distances and
                # semi_target == -1 (anomalous) penalizing small ones via the
                # inverse distance (dist ** -1).
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
                loss = torch.mean(losses)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info(
                    " LR scheduler: new learning rate is %g"
                    % float(scheduler.get_last_lr()[0])
                )

            # Log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f"| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s "
                f"| Train Loss: {epoch_loss / n_batches:.6f} |"
            )

        self.train_time = time.time() - start_time
        logger.info("Training Time: {:.3f}s".format(self.train_time))
        logger.info("Finished training.")

        # Store all training indices
        self.train_indices = np.array(all_indices)
        self.train_file_ids = np.array(all_file_ids)
        self.train_frame_ids = np.array(all_frame_ids)

        return net
    def infer(self, dataset: BaseADDataset, net: BaseNet):
        """Run the trained network over the inference split; return (scores, embeddings)."""
        logger = logging.getLogger()

        # Get inference data loader
        _, _, inference_loader = dataset.loaders(
            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
        )

        # Set device for network
        net = net.to(self.device)

        # Inference
        logger.info("Starting inference...")
        n_batches = 0
        start_time = time.time()
        # Preallocate the embedding buffer (one rep_dim-dimensional output per sample)
        all_outputs = np.zeros(
            (len(inference_loader.dataset), net.rep_dim), dtype=np.float32
        )
        scores = []
        net.eval()

        all_indices = []
        all_file_ids = []
        all_frame_ids = []

        with torch.no_grad():
            for data in inference_loader:
                inputs, idx, (file_id, frame_id) = data
                inputs = inputs.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)
                # Write this batch's embeddings into the preallocated buffer
                batch_offset = n_batches * self.batch_size
                all_outputs[batch_offset : batch_offset + len(inputs)] = (
                    outputs.cpu().data.numpy()
                )
                # Anomaly score: squared distance to the hypersphere center
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                scores += dist.cpu().data.numpy().tolist()

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

                n_batches += 1

        self.inference_time = time.time() - start_time

        # Log results
        logger.info("Inference Time: {:.3f}s".format(self.inference_time))
        logger.info("Finished inference.")

        # Store all inference indices
        self.inference_indices = np.array(all_indices)
        self.inference_file_ids = np.array(all_file_ids)
        self.inference_frame_ids = np.array(all_frame_ids)
        self.inference_index_mapping = {
            "indices": self.inference_indices,
            "file_ids": self.inference_file_ids,
            "frame_ids": self.inference_frame_ids,
        }

        return np.array(scores), all_outputs
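    # The scores returned by infer() are squared distances to the center c, so
    # larger means more anomalous. A minimal sketch of turning them into binary
    # predictions (the 0.95 quantile threshold is an arbitrary illustration,
    # not a value used by this project):
    #
    #     scores, embeddings = trainer.infer(dataset, net)
    #     threshold = np.quantile(scores, 0.95)
    #     is_anomaly = scores > threshold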
    def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
        """Evaluate the network on the test split against both label sets."""
        logger = logging.getLogger()

        # Get test data loader
        if k_fold_idx is not None:
            _, test_loader = dataset.loaders_k_fold(
                fold_idx=k_fold_idx,
                batch_size=self.batch_size,
                num_workers=self.n_jobs_dataloader,
            )
        else:
            _, test_loader, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
            )

        # Set device for network
        net = net.to(self.device)

        # Testing
        logger.info("Starting testing...")
        epoch_loss = 0.0
        n_batches = 0
        start_time = time.time()

        idx_label_score_exp = []
        idx_label_score_manual = []
        all_labels_exp = []
        all_labels_manual = []
        all_scores = []
        all_idx = []

        # Lists to collect all indices during testing
        all_indices = []
        all_file_ids = []
        all_frame_ids = []

        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in test_loader:
                (
                    inputs,
                    labels_exp_based,
                    labels_manual_based,
                    semi_targets,
                    idx,
                    (file_id, frame_id),
                ) = data
                inputs = inputs.to(self.device)
                labels_exp_based = labels_exp_based.to(self.device)
                labels_manual_based = labels_manual_based.to(self.device)
                semi_targets = semi_targets.to(self.device)
                idx = idx.to(self.device)

                outputs = net(inputs)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)
                losses = torch.where(
                    semi_targets == 0,
                    dist,
                    self.eta * ((dist + self.eps) ** semi_targets.float()),
                )
                loss = torch.mean(losses)
                scores = dist

                # Save for evaluation
                idx_label_score_exp += list(
                    zip(
                        idx.cpu().data.numpy().tolist(),
                        labels_exp_based.cpu().data.numpy().tolist(),
                        scores.cpu().data.numpy().tolist(),
                    )
                )
                idx_label_score_manual += list(
                    zip(
                        idx.cpu().data.numpy().tolist(),
                        labels_manual_based.cpu().data.numpy().tolist(),
                        scores.cpu().data.numpy().tolist(),
                    )
                )
                all_labels_exp.append(labels_exp_based.cpu().numpy())
                all_labels_manual.append(labels_manual_based.cpu().numpy())
                all_scores.append(scores.cpu().numpy())
                all_idx.append(idx.cpu().numpy())

                # Store indices
                all_indices.extend(idx.cpu().numpy().tolist())
                all_file_ids.extend(file_id.cpu().numpy().tolist())
                all_frame_ids.extend(frame_id.cpu().numpy().tolist())

                epoch_loss += loss.item()
                n_batches += 1

        self.test_time = time.time() - start_time
        logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
        logger.info("Test Time: {:.3f}s".format(self.test_time))

        self.test_scores_exp_based = idx_label_score_exp
        self.test_scores_manual_based = idx_label_score_manual

        # Flatten arrays for counting and evaluation
        all_labels_exp = np.concatenate(all_labels_exp)
        all_labels_manual = np.concatenate(all_labels_manual)
        all_scores = np.concatenate(all_scores)
        all_idx = np.concatenate(all_idx)

        # Count and log label stats for exp_based
        n_exp_normal = np.sum(all_labels_exp == 1)
        n_exp_anomaly = np.sum(all_labels_exp == -1)
        n_exp_unknown = np.sum(all_labels_exp == 0)
        logger.info(
            f"Exp-based labels: normal(1)={n_exp_normal}, "
            f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
        )

        # Count and log label stats for manual_based
        n_manual_normal = np.sum(all_labels_manual == 1)
        n_manual_anomaly = np.sum(all_labels_manual == -1)
        n_manual_unknown = np.sum(all_labels_manual == 0)
        logger.info(
            f"Manual-based labels: normal(1)={n_manual_normal}, "
            f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
        )

        # --- Evaluation for exp_based (only labeled samples) ---
        idxs_exp, labels_exp, scores_exp = zip(*idx_label_score_exp)
        labels_exp = np.array(labels_exp)
        scores_exp = np.array(scores_exp)

        # Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
        valid_mask_exp = labels_exp != 0
        if np.any(valid_mask_exp):
            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
            labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
            scores_exp_valid = scores_exp[valid_mask_exp]
            self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
            self.test_roc_exp_based = roc_curve(labels_exp_binary, scores_exp_valid)
            self.test_prc_exp_based = precision_recall_curve(
                labels_exp_binary, scores_exp_valid
            )
            self.test_ap_exp_based = average_precision_score(
                labels_exp_binary, scores_exp_valid
            )
            logger.info(
                "Test AUC (exp_based): {:.2f}%".format(100.0 * self.test_auc_exp_based)
            )
        else:
            logger.info("Test AUC (exp_based): N/A (no labeled samples)")
            self.test_auc_exp_based = None
            self.test_roc_exp_based = None
            self.test_prc_exp_based = None
            self.test_ap_exp_based = None
        # --- Evaluation for manual_based (only labeled samples) ---
        idxs_manual, labels_manual, scores_manual = zip(*idx_label_score_manual)
        labels_manual = np.array(labels_manual)
        scores_manual = np.array(scores_manual)

        # Filter out unknown labels and convert to binary for ROC
        valid_mask_manual = labels_manual != 0
        if np.any(valid_mask_manual):
            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
            labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
            scores_manual_valid = scores_manual[valid_mask_manual]
            self.test_auc_manual_based = roc_auc_score(
                labels_manual_binary, scores_manual_valid
            )
            self.test_roc_manual_based = roc_curve(
                labels_manual_binary, scores_manual_valid
            )
            self.test_prc_manual_based = precision_recall_curve(
                labels_manual_binary, scores_manual_valid
            )
            self.test_ap_manual_based = average_precision_score(
                labels_manual_binary, scores_manual_valid
            )
            logger.info(
                "Test AUC (manual_based): {:.2f}%".format(
                    100.0 * self.test_auc_manual_based
                )
            )
        else:
            self.test_auc_manual_based = None
            self.test_roc_manual_based = None
            self.test_prc_manual_based = None
            self.test_ap_manual_based = None
            logger.info("Test AUC (manual_based): N/A (no labeled samples)")

        # Store all test indices
        self.test_indices = np.array(all_indices)
        self.test_file_ids = np.array(all_file_ids)
        self.test_frame_ids = np.array(all_frame_ids)

        # Log index statistics
        logger.info(f"Number of test samples: {len(self.test_indices)}")
        logger.info(f"Number of unique files: {len(np.unique(self.test_file_ids))}")

        # Create a mapping of indices to their file/frame information
        self.test_index_mapping = {
            "indices": self.test_indices,
            "file_ids": self.test_file_ids,
            "frame_ids": self.test_frame_ids,
            "exp_based": {
                "indices": np.array(idxs_exp),
                "labels": np.array(labels_exp),
                "scores": np.array(scores_exp),
                "valid_mask": valid_mask_exp,
            },
            "manual_based": {
                "indices": np.array(idxs_manual),
                "labels": np.array(labels_manual),
                "scores": np.array(scores_manual),
                "valid_mask": valid_mask_manual,
            },
        }

        logger.info("Finished testing.")

    def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in train_loader:
                # Get the inputs of the batch
                inputs, _, _, _, _, _ = data
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps.
        # Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps

        return c
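
# ---------------------------------------------------------------------------
# Minimal usage sketch (kept as a comment; `MyDataset` and `MyEncoder` are
# hypothetical stand-ins for a BaseADDataset subclass whose loaders yield the
# tuples unpacked above and a BaseNet subclass exposing `rep_dim`):
#
#     dataset = MyDataset(root="data/")
#     net = MyEncoder()
#     trainer = DeepSADTrainer(
#         c=None, eta=1.0, lr=1e-3, n_epochs=50, batch_size=128, device="cuda"
#     )
#     net = trainer.train(dataset, net)  # also initializes center c
#     trainer.test(dataset, net)         # fills the test_auc_* attributes
#     scores, embeddings = trainer.infer(dataset, net)
# ---------------------------------------------------------------------------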