Files
mt/tools/evaluate_prc.py

128 lines
4.3 KiB
Python
Raw Normal View History

import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem, t
from sklearn.metrics import PrecisionRecallDisplay, auc
def confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` for a sample of scores.

    ``half_width`` is the standard error of the mean (``scipy.stats.sem``)
    scaled by the two-sided Student-t critical value at the requested
    ``confidence`` level with ``len(data) - 1`` degrees of freedom, so the
    interval is ``mean ± half_width``. With fewer than two scores the degrees
    of freedom are not positive and the half-width is not meaningful.
    """
    dof = len(data) - 1
    # Upper-tail quantile for a two-sided interval, e.g. 0.975 for 95 %.
    t_crit = t.ppf((1 + confidence) / 2.0, dof)
    return np.mean(data), sem(data) * t_crit
# 1) LOAD PRECISION-RECALL DATA
prc_data = []  # (precision, recall) per DeepSAD fold
ap_scores = []  # trapezoidal area under each DeepSAD fold's PR curve ("AP")
isoforest_prc_data = []  # (precision, recall) per IsoForest fold
isoforest_ap_scores = []  # trapezoidal area under each IsoForest fold's PR curve
results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_800_3000_new"
)


def _load_fold_curve(pickle_path):
    """Read one fold's pickled results; return ``((precision, recall), area)``.

    The pickle's ``"test_prc"`` entry is unpacked as
    ``(precision, recall, thresholds)``; thresholds are discarded. ``area`` is
    ``sklearn.metrics.auc(recall, precision)``, a trapezoidal area, which is
    not the same quantity as sklearn's step-wise average precision.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only use it on
    # result files produced by this project.
    with pickle_path.open("rb") as handle:
        precision, recall, _thresholds = pickle.load(handle)["test_prc"]
    return (precision, recall), auc(recall, precision)


# Five folds are assumed; adjust the range if the experiment used another k.
for fold in range(5):
    # DeepSAD first, then IsoForest, for each fold (same order as before).
    for pkl_name, curves, areas in (
        (f"results_{fold}.pkl", prc_data, ap_scores),
        (f"results_isoforest_{fold}.pkl", isoforest_prc_data, isoforest_ap_scores),
    ):
        curve, area = _load_fold_curve(results_path / pkl_name)
        curves.append(curve)
        areas.append(area)
# 2) CALCULATE PER-FOLD STATISTICS
# Mean area across folds plus the half-width of its 95 % Student-t interval,
# one (mean, half_width) pair per method.
(mean_ap, ap_ci), (isoforest_mean_ap, isoforest_ap_ci) = (
    confidence_interval(ap_scores),
    confidence_interval(isoforest_ap_scores),
)
# 3) INTERPOLATE EACH FOLD'S PRC ON A COMMON RECALL GRID FOR MEAN CURVE
mean_recall = np.linspace(0, 1, 100)


def _interpolate_precision(recall_grid, precision, recall):
    """Resample one PR curve's precision onto ``recall_grid``.

    ``np.interp(x, xp, fp)`` evaluates ``fp`` as a function of ``xp``, so
    recall must be passed as ``xp`` and precision as ``fp`` (the previous code
    had them swapped and interpolated recall as a function of precision).
    ``xp`` must also be increasing, so the points are sorted by recall first.
    NOTE(review): sklearn's ``precision_recall_curve`` returns recall in
    decreasing order; this assumes ``test_prc`` came from it -- confirm.
    """
    precision = np.asarray(precision)
    recall = np.asarray(recall)
    # Stable sort keeps the original relative order of equal-recall points.
    order = np.argsort(recall, kind="stable")
    return np.interp(recall_grid, recall[order], precision[order])


# -- DeepSAD
deep_sad_precisions_interp = [
    _interpolate_precision(mean_recall, precision, recall)
    for precision, recall in prc_data
]
mean_precision = np.mean(deep_sad_precisions_interp, axis=0)
std_precision = np.std(deep_sad_precisions_interp, axis=0)
# -- IsoForest
isoforest_precisions_interp = [
    _interpolate_precision(mean_recall, precision, recall)
    for precision, recall in isoforest_prc_data
]
isoforest_mean_precision = np.mean(isoforest_precisions_interp, axis=0)
isoforest_std_precision = np.std(isoforest_precisions_interp, axis=0)
# 4) CREATE PLOT USING PrecisionRecallDisplay
fig, ax = plt.subplots(figsize=(8, 6))
# (A)/(B) Faint per-fold curves: all DeepSAD folds in blue, then all
# IsoForest folds in red. Only fold 1 of each method gets a legend label;
# later folds pass label=None to keep the legend short.
for fold_curves, colour, method in (
    (prc_data, "b", "DeepSAD"),
    (isoforest_prc_data, "r", "IsoForest"),
):
    for i, (precision, recall) in enumerate(fold_curves):
        disp = PrecisionRecallDisplay(precision=precision, recall=recall)
        disp.plot(
            ax=ax,
            color=colour,
            alpha=0.3,
            label=None if i else f"{method} Fold {i + 1}",
        )
# (C)/(D) Mean interpolated PR curve per method (DeepSAD blue, IsoForest red)
# with a shaded +/- 1 standard-deviation band. The band is not clipped, so it
# may extend outside [0, 1].
for curve_mean, curve_std, colour, legend_label in (
    (
        mean_precision,
        std_precision,
        "b",
        f"DeepSAD Mean PR (AP={mean_ap:.2f} ± {ap_ci:.2f})",
    ),
    (
        isoforest_mean_precision,
        isoforest_std_precision,
        "r",
        f"IsoForest Mean PR (AP={isoforest_mean_ap:.2f} ± {isoforest_ap_ci:.2f})",
    ),
):
    PrecisionRecallDisplay(precision=curve_mean, recall=mean_recall).plot(
        ax=ax, color=colour, label=legend_label
    )
    ax.fill_between(
        mean_recall,
        curve_mean - curve_std,
        curve_mean + curve_std,
        color=colour,
        alpha=0.2,
    )
# 5) FINAL PLOT ADJUSTMENTS
ax.set(
    xlabel="Recall",
    ylabel="Precision",
    title="Precision-Recall Curve with 5-Fold Cross-Validation",
)
# Legend anchored in the upper-right corner, then write the figure to disk.
ax.legend(loc="upper right")
plt.savefig("pr_curve_800_3000_2.png")