wip
This commit is contained in:
@@ -51,9 +51,16 @@ class IsoForest(object):
|
||||
self.results = {
|
||||
"train_time": None,
|
||||
"test_time": None,
|
||||
"test_auc": None,
|
||||
"test_roc": None,
|
||||
"test_scores": None,
|
||||
"test_auc_exp_based": None,
|
||||
"test_roc_exp_based": None,
|
||||
"test_prc_exp_based": None,
|
||||
"test_ap_exp_based": None,
|
||||
"test_scores_exp_based": None,
|
||||
"test_auc_manual_based": None,
|
||||
"test_roc_manual_based": None,
|
||||
"test_prc_manual_based": None,
|
||||
"test_ap_manual_based": None,
|
||||
"test_scores_manual_based": None,
|
||||
}
|
||||
|
||||
def train(
|
||||
@@ -89,7 +96,7 @@ class IsoForest(object):
|
||||
# Get data from loader
|
||||
X = ()
|
||||
for data in train_loader:
|
||||
inputs, _, _, _, _ = data
|
||||
inputs, _, _, _, _, _ = data
|
||||
inputs = inputs.to(device)
|
||||
if self.hybrid:
|
||||
inputs = self.ae_net.encoder(
|
||||
@@ -133,28 +140,50 @@ class IsoForest(object):
|
||||
)
|
||||
|
||||
# Get data from loader
|
||||
idx_label_score = []
|
||||
idx_label_score_exp = []
|
||||
idx_label_score_manual = []
|
||||
X = ()
|
||||
idxs = []
|
||||
labels = []
|
||||
labels_exp = []
|
||||
labels_manual = []
|
||||
|
||||
for data in test_loader:
|
||||
inputs, label_batch, _, idx, _ = data
|
||||
inputs, label_batch, idx = (
|
||||
inputs, label_exp, label_manual, _, idx, _ = data
|
||||
inputs, label_exp, label_manual, idx = (
|
||||
inputs.to(device),
|
||||
label_batch.to(device),
|
||||
label_exp.to(device),
|
||||
label_manual.to(device),
|
||||
idx.to(device),
|
||||
)
|
||||
if self.hybrid:
|
||||
inputs = self.ae_net.encoder(
|
||||
inputs
|
||||
) # in hybrid approach, take code representation of AE as features
|
||||
X_batch = inputs.view(
|
||||
inputs.size(0), -1
|
||||
) # X_batch.shape = (batch_size, n_channels * height * width)
|
||||
inputs = self.ae_net.encoder(inputs)
|
||||
X_batch = inputs.view(inputs.size(0), -1)
|
||||
X += (X_batch.cpu().data.numpy(),)
|
||||
idxs += idx.cpu().data.numpy().astype(np.int64).tolist()
|
||||
labels += label_batch.cpu().data.numpy().astype(np.int64).tolist()
|
||||
labels_exp += label_exp.cpu().data.numpy().astype(np.int64).tolist()
|
||||
labels_manual += label_manual.cpu().data.numpy().astype(np.int64).tolist()
|
||||
|
||||
X = np.concatenate(X)
|
||||
labels_exp = np.array(labels_exp)
|
||||
labels_manual = np.array(labels_manual)
|
||||
|
||||
# Count and log label stats for exp_based
|
||||
n_exp_normal = np.sum(labels_exp == 1)
|
||||
n_exp_anomaly = np.sum(labels_exp == -1)
|
||||
n_exp_unknown = np.sum(labels_exp == 0)
|
||||
logger.info(
|
||||
f"Exp-based labels: normal(1)={n_exp_normal}, "
|
||||
f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
|
||||
)
|
||||
|
||||
# Count and log label stats for manual_based
|
||||
n_manual_normal = np.sum(labels_manual == 1)
|
||||
n_manual_anomaly = np.sum(labels_manual == -1)
|
||||
n_manual_unknown = np.sum(labels_manual == 0)
|
||||
logger.info(
|
||||
f"Manual-based labels: normal(1)={n_manual_normal}, "
|
||||
f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
|
||||
)
|
||||
|
||||
# Testing
|
||||
logger.info("Starting testing...")
|
||||
@@ -163,21 +192,72 @@ class IsoForest(object):
|
||||
self.results["test_time"] = time.time() - start_time
|
||||
scores = scores.flatten()
|
||||
|
||||
# Save triples of (idx, label, score) in a list
|
||||
idx_label_score += list(zip(idxs, labels, scores.tolist()))
|
||||
self.results["test_scores"] = idx_label_score
|
||||
# Save triples of (idx, label, score) in a list for both label types
|
||||
idx_label_score_exp += list(zip(idxs, labels_exp.tolist(), scores.tolist()))
|
||||
idx_label_score_manual += list(
|
||||
zip(idxs, labels_manual.tolist(), scores.tolist())
|
||||
)
|
||||
|
||||
# Compute AUC
|
||||
_, labels, scores = zip(*idx_label_score)
|
||||
labels = np.array(labels)
|
||||
scores = np.array(scores)
|
||||
self.results["test_auc"] = roc_auc_score(labels, scores)
|
||||
self.results["test_roc"] = roc_curve(labels, scores)
|
||||
self.results["test_prc"] = precision_recall_curve(labels, scores)
|
||||
self.results["test_ap"] = average_precision_score(labels, scores)
|
||||
self.results["test_scores_exp_based"] = idx_label_score_exp
|
||||
self.results["test_scores_manual_based"] = idx_label_score_manual
|
||||
|
||||
# --- Evaluation for exp_based (only labeled samples) ---
|
||||
# Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
|
||||
valid_mask_exp = labels_exp != 0
|
||||
if np.any(valid_mask_exp):
|
||||
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||
labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
|
||||
scores_exp_valid = scores[valid_mask_exp]
|
||||
|
||||
self.results["test_auc_exp_based"] = roc_auc_score(
|
||||
labels_exp_binary, scores_exp_valid
|
||||
)
|
||||
self.results["test_roc_exp_based"] = roc_curve(
|
||||
labels_exp_binary, scores_exp_valid
|
||||
)
|
||||
self.results["test_prc_exp_based"] = precision_recall_curve(
|
||||
labels_exp_binary, scores_exp_valid
|
||||
)
|
||||
self.results["test_ap_exp_based"] = average_precision_score(
|
||||
labels_exp_binary, scores_exp_valid
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Test AUC (exp_based): {:.2f}%".format(
|
||||
100.0 * self.results["test_auc_exp_based"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info("Test AUC (exp_based): N/A (no labeled samples)")
|
||||
|
||||
# --- Evaluation for manual_based (only labeled samples) ---
|
||||
valid_mask_manual = labels_manual != 0
|
||||
if np.any(valid_mask_manual):
|
||||
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||
labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
|
||||
scores_manual_valid = scores[valid_mask_manual]
|
||||
|
||||
self.results["test_auc_manual_based"] = roc_auc_score(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
)
|
||||
self.results["test_roc_manual_based"] = roc_curve(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
)
|
||||
self.results["test_prc_manual_based"] = precision_recall_curve(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
)
|
||||
self.results["test_ap_manual_based"] = average_precision_score(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Test AUC (manual_based): {:.2f}%".format(
|
||||
100.0 * self.results["test_auc_manual_based"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info("Test AUC (manual_based): N/A (no labeled samples)")
|
||||
|
||||
# Log results
|
||||
logger.info("Test AUC: {:.2f}%".format(100.0 * self.results["test_auc"]))
|
||||
logger.info("Test Time: {:.3f}s".format(self.results["test_time"]))
|
||||
logger.info("Finished testing.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user