wip
@@ -55,6 +55,15 @@ class DeepSADTrainer(BaseTrainer):
         self.test_time = None
         self.test_scores = None
 
+        # Add new attributes for storing indices
+        self.train_indices = None
+        self.train_file_ids = None
+        self.train_frame_ids = None
+
+        self.test_indices = None
+        self.test_file_ids = None
+        self.test_frame_ids = None
+
     def train(
         self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
     ) -> BaseNet:
@@ -95,12 +104,18 @@ class DeepSADTrainer(BaseTrainer):
         logger.info("Starting training...")
         start_time = time.time()
         net.train()
+
+        # Lists to collect all indices during training
+        all_indices = []
+        all_file_ids = []
+        all_frame_ids = []
+
         for epoch in range(self.n_epochs):
             epoch_loss = 0.0
             n_batches = 0
             epoch_start_time = time.time()
             for data in train_loader:
-                inputs, _, semi_targets, _, _ = data
+                inputs, _, _, semi_targets, idx, (file_id, frame_id) = data
                 inputs, semi_targets = (
                     inputs.to(self.device),
                     semi_targets.to(self.device),
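Note: the new unpacking assumes the train loader now yields a 6-tuple ending in a `(file_id, frame_id)` pair. A minimal sketch of a dataset with that return shape — the `IndexedFrameDataset` name and its fields are illustrative, not the project's actual loader:

```python
from torch.utils.data import Dataset

class IndexedFrameDataset(Dataset):
    """Hypothetical stand-in showing the 6-tuple the trainer now unpacks."""

    def __init__(self, frames, labels_exp, labels_manual, semi_targets, file_ids, frame_ids):
        self.frames = frames                # tensor/array of shape (N, ...)
        self.labels_exp = labels_exp        # experiment-derived labels
        self.labels_manual = labels_manual  # manually annotated labels
        self.semi_targets = semi_targets    # semi-supervised targets for Deep SAD
        self.file_ids = file_ids            # which source file a frame came from
        self.frame_ids = frame_ids          # frame position within that file

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        # Order must match the trainer's unpacking:
        # inputs, labels_exp, labels_manual, semi_targets, idx, (file_id, frame_id)
        return (
            self.frames[idx],
            self.labels_exp[idx],
            self.labels_manual[idx],
            self.semi_targets[idx],
            idx,
            (self.file_ids[idx], self.frame_ids[idx]),
        )
```

With PyTorch's default collate, the trailing pair is batched into two tensors, so the `(file_id, frame_id)` unpacking in the loops above and below works as written.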
@@ -124,6 +139,11 @@ class DeepSADTrainer(BaseTrainer):
                 epoch_loss += loss.item()
                 n_batches += 1
+
+                # Store indices
+                all_indices.extend(idx.cpu().numpy().tolist())
+                all_file_ids.extend(file_id.cpu().numpy().tolist())
+                all_frame_ids.extend(frame_id.cpu().numpy().tolist())
 
             scheduler.step()
             if epoch in self.lr_milestones:
                 logger.info(
@@ -142,6 +162,11 @@ class DeepSADTrainer(BaseTrainer):
         logger.info("Training Time: {:.3f}s".format(self.train_time))
         logger.info("Finished training.")
 
+        # Store all training indices
+        self.train_indices = np.array(all_indices)
+        self.train_file_ids = np.array(all_file_ids)
+        self.train_frame_ids = np.array(all_frame_ids)
+
         return net
 
     def infer(self, dataset: BaseADDataset, net: BaseNet):
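Note: `all_indices` is extended inside the epoch loop, so `self.train_indices` ends up with one entry per sample per epoch (`n_epochs * len(dataset)` entries, in shuffled batch order). If downstream consumers expect one entry per sample, something like this sketch recovers it — both helpers are hypothetical, not part of the diff:

```python
import numpy as np

def last_epoch_order(train_indices: np.ndarray, n_epochs: int) -> np.ndarray:
    """Return only the final epoch's (shuffled) sample order."""
    per_epoch = len(train_indices) // n_epochs
    return train_indices[-per_epoch:]

def unique_samples(train_indices: np.ndarray) -> np.ndarray:
    """Deduplicate across epochs, keeping order of first occurrence."""
    _, first_pos = np.unique(train_indices, return_index=True)
    return train_indices[np.sort(first_pos)]
```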
@@ -162,9 +187,14 @@ class DeepSADTrainer(BaseTrainer):
         all_outputs = np.zeros((len(inference_loader.dataset), 1024), dtype=np.float32)
         scores = []
         net.eval()
+
+        all_indices = []
+        all_file_ids = []
+        all_frame_ids = []
+
         with torch.no_grad():
             for data in inference_loader:
-                inputs, idx = data
+                inputs, idx, (file_id, frame_id) = data
 
                 inputs = inputs.to(self.device)
                 idx = idx.to(self.device)
@@ -177,6 +207,11 @@ class DeepSADTrainer(BaseTrainer):
                 dist = torch.sum((outputs - self.c) ** 2, dim=1)
                 scores += dist.cpu().data.numpy().tolist()
+
+                # Store indices
+                all_indices.extend(idx.cpu().numpy().tolist())
+                all_file_ids.extend(file_id.cpu().numpy().tolist())
+                all_frame_ids.extend(frame_id.cpu().numpy().tolist())
 
                 n_batches += 1
 
         self.inference_time = time.time() - start_time
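Note: for context on the `dist` line above, the per-sample anomaly score in Deep SAD is the squared Euclidean distance of the embedding to the fixed hypersphere center c, i.e. s(x) = ||phi(x; W) - c||^2, which is exactly what `torch.sum((outputs - self.c) ** 2, dim=1)` computes; larger scores mean more anomalous.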
@@ -185,6 +220,17 @@ class DeepSADTrainer(BaseTrainer):
         logger.info("Inference Time: {:.3f}s".format(self.inference_time))
         logger.info("Finished inference.")
 
+        # Store all inference indices
+        self.inference_indices = np.array(all_indices)
+        self.inference_file_ids = np.array(all_file_ids)
+        self.inference_frame_ids = np.array(all_frame_ids)
+
+        self.inference_index_mapping = {
+            "indices": self.inference_indices,
+            "file_ids": self.inference_file_ids,
+            "frame_ids": self.inference_frame_ids,
+        }
+
         return np.array(scores), all_outputs
 
     def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
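Note: a short usage sketch for the new `inference_index_mapping`, mapping the highest anomaly scores back to their source file and frame. It relies only on the fact that `scores` and the mapping arrays are filled in the same batch order; the `top_anomalies` helper is illustrative, not part of the diff:

```python
import numpy as np

def top_anomalies(scores: np.ndarray, mapping: dict, k: int = 10):
    """Return (file_id, frame_id, score) for the k highest anomaly scores."""
    order = np.argsort(scores)[::-1][:k]  # indices sorted by descending score
    return [
        (mapping["file_ids"][i], mapping["frame_ids"][i], scores[i])
        for i in order
    ]

# scores, all_outputs = trainer.infer(dataset, net)
# top_anomalies(scores, trainer.inference_index_mapping, k=5)
```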
@@ -210,15 +256,34 @@ class DeepSADTrainer(BaseTrainer):
         epoch_loss = 0.0
         n_batches = 0
         start_time = time.time()
-        idx_label_score = []
+        idx_label_score_exp = []
+        idx_label_score_manual = []
+        all_labels_exp = []
+        all_labels_manual = []
+        all_scores = []
+        all_idx = []
+
+        # Lists to collect all indices during testing
+        all_indices = []
+        all_file_ids = []
+        all_frame_ids = []
+
         net.eval()
+        net.summary(receptive_field=True)
         with torch.no_grad():
             for data in test_loader:
-                inputs, labels, semi_targets, idx, _ = data
+                (
+                    inputs,
+                    labels_exp_based,
+                    labels_manual_based,
+                    semi_targets,
+                    idx,
+                    (file_id, frame_id),
+                ) = data
 
                 inputs = inputs.to(self.device)
-                labels = labels.to(self.device)
+                labels_exp_based = labels_exp_based.to(self.device)
+                labels_manual_based = labels_manual_based.to(self.device)
                 semi_targets = semi_targets.to(self.device)
                 idx = idx.to(self.device)
 
@@ -232,34 +297,161 @@ class DeepSADTrainer(BaseTrainer):
                 loss = torch.mean(losses)
                 scores = dist
 
-                # Save triples of (idx, label, score) in a list
-                idx_label_score += list(
+                # Save for evaluation
+                idx_label_score_exp += list(
                     zip(
                         idx.cpu().data.numpy().tolist(),
-                        labels.cpu().data.numpy().tolist(),
+                        labels_exp_based.cpu().data.numpy().tolist(),
                         scores.cpu().data.numpy().tolist(),
                     )
                 )
+                idx_label_score_manual += list(
+                    zip(
+                        idx.cpu().data.numpy().tolist(),
+                        labels_manual_based.cpu().data.numpy().tolist(),
+                        scores.cpu().data.numpy().tolist(),
+                    )
+                )
+                all_labels_exp.append(labels_exp_based.cpu().numpy())
+                all_labels_manual.append(labels_manual_based.cpu().numpy())
+                all_scores.append(scores.cpu().numpy())
+                all_idx.append(idx.cpu().numpy())
+
+                # Store indices
+                all_indices.extend(idx.cpu().numpy().tolist())
+                all_file_ids.extend(file_id.cpu().numpy().tolist())
+                all_frame_ids.extend(frame_id.cpu().numpy().tolist())
 
                 epoch_loss += loss.item()
                 n_batches += 1
 
         self.test_time = time.time() - start_time
-        self.test_scores = idx_label_score
+        self.test_scores_exp_based = idx_label_score_exp
+        self.test_scores_manual_based = idx_label_score_manual
 
-        # Compute AUC
-        _, labels, scores = zip(*idx_label_score)
-        labels = np.array(labels)
-        scores = np.array(scores)
-        self.test_auc = roc_auc_score(labels, scores)
-        self.test_roc = roc_curve(labels, scores)
-        self.test_prc = precision_recall_curve(labels, scores)
-        self.test_ap = average_precision_score(labels, scores)
+        # Flatten arrays for counting and evaluation
+        all_labels_exp = np.concatenate(all_labels_exp)
+        all_labels_manual = np.concatenate(all_labels_manual)
+        all_scores = np.concatenate(all_scores)
+        all_idx = np.concatenate(all_idx)
+
+        # Count and log label stats for exp_based
+        n_exp_normal = np.sum(all_labels_exp == 1)
+        n_exp_anomaly = np.sum(all_labels_exp == -1)
+        n_exp_unknown = np.sum(all_labels_exp == 0)
+        logger.info(
+            f"Exp-based labels: normal(1)={n_exp_normal}, "
+            f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
+        )
+
+        # Count and log label stats for manual_based
+        n_manual_normal = np.sum(all_labels_manual == 1)
+        n_manual_anomaly = np.sum(all_labels_manual == -1)
+        n_manual_unknown = np.sum(all_labels_manual == 0)
+        logger.info(
+            f"Manual-based labels: normal(1)={n_manual_normal}, "
+            f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
+        )
+
+        # --- Evaluation for exp_based (only labeled samples) ---
+        idxs_exp, labels_exp, scores_exp = zip(*idx_label_score_exp)
+        labels_exp = np.array(labels_exp)
+        scores_exp = np.array(scores_exp)
+
+        # Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
+        valid_mask_exp = labels_exp != 0
+        if np.any(valid_mask_exp):
+            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
+            labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
+            scores_exp_valid = scores_exp[valid_mask_exp]
+
+            self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
+            self.test_roc_exp_based = roc_curve(labels_exp_binary, scores_exp_valid)
+            self.test_prc_exp_based = precision_recall_curve(
+                labels_exp_binary, scores_exp_valid
+            )
+            self.test_ap_exp_based = average_precision_score(
+                labels_exp_binary, scores_exp_valid
+            )
+
+            logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
+            logger.info(
+                "Test AUC (exp_based): {:.2f}%".format(100.0 * self.test_auc_exp_based)
+            )
+        else:
+            logger.info("Test AUC (exp_based): N/A (no labeled samples)")
+            self.test_auc_exp_based = None
+            self.test_roc_exp_based = None
+            self.test_prc_exp_based = None
+            self.test_ap_exp_based = None
 
-        # Log results
-        logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
-        logger.info("Test AUC: {:.2f}%".format(100.0 * self.test_auc))
         logger.info("Test Time: {:.3f}s".format(self.test_time))
 
+        # --- Evaluation for manual_based (only labeled samples) ---
+        idxs_manual, labels_manual, scores_manual = zip(*idx_label_score_manual)
+        labels_manual = np.array(labels_manual)
+        scores_manual = np.array(scores_manual)
+
+        # Filter out unknown labels and convert to binary for ROC
+        valid_mask_manual = labels_manual != 0
+        if np.any(valid_mask_manual):
+            # Convert to binary labels for ROC (-1 → 1, 1 → 0)
+            labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
+            scores_manual_valid = scores_manual[valid_mask_manual]
+
+            self.test_auc_manual_based = roc_auc_score(
+                labels_manual_binary, scores_manual_valid
+            )
+            self.test_roc_manual_based = roc_curve(
+                labels_manual_binary, scores_manual_valid
+            )
+            self.test_prc_manual_based = precision_recall_curve(
+                labels_manual_binary, scores_manual_valid
+            )
+            self.test_ap_manual_based = average_precision_score(
+                labels_manual_binary, scores_manual_valid
+            )
+
+            logger.info(
+                "Test AUC (manual_based): {:.2f}%".format(
+                    100.0 * self.test_auc_manual_based
+                )
+            )
+        else:
+            self.test_auc_manual_based = None
+            self.test_roc_manual_based = None
+            self.test_prc_manual_based = None
+            self.test_ap_manual_based = None
+            logger.info("Test AUC (manual_based): N/A (no labeled samples)")
+
+        # Store all test indices
+        self.test_indices = np.array(all_indices)
+        self.test_file_ids = np.array(all_file_ids)
+        self.test_frame_ids = np.array(all_frame_ids)
+
+        # Add logging for indices
+        logger.info(f"Number of test samples: {len(self.test_indices)}")
+        logger.info(f"Number of unique files: {len(np.unique(self.test_file_ids))}")
+
+        # Create a mapping of indices to their file/frame information
+        self.test_index_mapping = {
+            "indices": self.test_indices,
+            "file_ids": self.test_file_ids,
+            "frame_ids": self.test_frame_ids,
+            "exp_based": {
+                "indices": np.array(idxs_exp),
+                "labels": np.array(labels_exp),
+                "scores": np.array(scores_exp),
+                "valid_mask": valid_mask_exp,
+            },
+            "manual_based": {
+                "indices": np.array(idxs_manual),
+                "labels": np.array(labels_manual),
+                "scores": np.array(scores_manual),
+                "valid_mask": valid_mask_manual,
+            },
+        }
 
         logger.info("Finished testing.")
 
     def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
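Note: as a sanity check on the masking and relabeling used in both evaluation branches — labels follow {1: normal, -1: anomaly, 0: unknown}, unknowns are dropped, and the rest are flipped so anomalies become the positive class — a tiny self-contained example with made-up values:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

labels = np.array([1, -1, 0, 1, -1, 0])            # 1=normal, -1=anomaly, 0=unknown
scores = np.array([0.2, 3.1, 9.9, 0.4, 2.7, 0.1])  # distance-to-center scores

valid = labels != 0                          # drop the unknowns entirely
binary = (labels[valid] == -1).astype(int)   # -1 → 1 (anomaly), 1 → 0 (normal)
print(binary)                                # [0 1 0 1]
print(roc_auc_score(binary, scores[valid]))  # 1.0: anomalies score highest here
```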
@@ -272,7 +464,7 @@ class DeepSADTrainer(BaseTrainer):
         with torch.no_grad():
             for data in train_loader:
                 # get the inputs of the batch
-                inputs, _, _, _, _ = data
+                inputs, _, _, _, _, _ = data
                 inputs = inputs.to(self.device)
                 outputs = net(inputs)
                 n_samples += outputs.shape[0]