Full upload so as not to lose anything important
This commit is contained in:
127
tools/evaluate_prc.py
Normal file
127
tools/evaluate_prc.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy.stats import sem, t
|
||||
from sklearn.metrics import PrecisionRecallDisplay, auc
|
||||
|
||||
|
||||
def confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t interval.

    Args:
        data: Sequence of scores (this script passes one value per fold).
        confidence: Two-sided confidence level; defaults to 0.95.

    Returns:
        A ``(mean, half_width)`` tuple; the interval is ``mean ± half_width``.
    """
    degrees_of_freedom = len(data) - 1
    # Two-sided interval: the upper quantile leaves (1 - confidence) / 2
    # of the probability mass in each tail.
    upper_quantile = (1 + confidence) / 2.0
    critical_value = t.ppf(upper_quantile, degrees_of_freedom)
    # Half-width = standard error of the mean scaled by the t critical value.
    return np.mean(data), sem(data) * critical_value
|
||||
|
||||
|
||||
# 1) LOAD PRECISION-RECALL DATA
# Per-fold containers; later sections of this script read them by name.
prc_data = []  # (precision, recall) pair for each DeepSAD fold
ap_scores = []  # area under the PR curve for each DeepSAD fold

isoforest_prc_data = []  # (precision, recall) pair for each IsoForest fold
isoforest_ap_scores = []  # area under the PR curve for each IsoForest fold

results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_800_3000_new"
)

# Five folds are assumed (adjust the range if you have a different number).
for i in range(5):
    # For every fold, read the DeepSAD result file first and the IsoForest
    # one second, appending to the matching pair of lists.
    for file_name, curves, areas in (
        (f"results_{i}.pkl", prc_data, ap_scores),
        (f"results_isoforest_{i}.pkl", isoforest_prc_data, isoforest_ap_scores),
    ):
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files you produced yourself.
        with (results_path / file_name).open("rb") as f:
            data = pickle.load(f)
            # "test_prc" holds (precision, recall, thresholds); the
            # thresholds are not needed here.
            precision, recall, _ = data["test_prc"]
            curves.append((precision, recall))
            # The per-fold score is the area under the (recall, precision)
            # curve, which this script reports as the Average Precision (AP).
            areas.append(auc(recall, precision))

# 2) CALCULATE PER-FOLD STATISTICS
# Mean score across folds plus the half-width of its confidence interval.
mean_ap, ap_ci = confidence_interval(ap_scores)
isoforest_mean_ap, isoforest_ap_ci = confidence_interval(isoforest_ap_scores)
|
||||
|
||||
# 3) INTERPOLATE EACH FOLD'S PRC ON A COMMON RECALL GRID FOR MEAN CURVE
mean_recall = np.linspace(0, 1, 100)


def _precision_on_grid(precision, recall, recall_grid):
    """Return precision linearly interpolated at every recall in recall_grid.

    np.interp(x, xp, fp) evaluates fp as a function of xp, so recall must be
    passed as xp and precision as fp; swapping them yields interpolated
    recall instead of precision. np.interp also requires xp to be increasing.

    Args:
        precision: Precision values of one fold's PR curve.
        recall: Recall values matching ``precision`` element by element.
        recall_grid: Recall positions at which precision is wanted.

    Returns:
        Array of interpolated precision values, one per grid position.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    # NOTE(review): "test_prc" presumably comes from
    # sklearn.metrics.precision_recall_curve, which returns recall in
    # decreasing order - confirm against the code that writes the pickles.
    # Reverse such curves so the x-coordinates handed to np.interp increase.
    if recall.size > 1 and recall[0] > recall[-1]:
        precision, recall = precision[::-1], recall[::-1]
    return np.interp(recall_grid, recall, precision)


# -- DeepSAD
deep_sad_precisions_interp = [
    _precision_on_grid(precision, recall, mean_recall)
    for precision, recall in prc_data
]

# Point-wise mean and (population) standard deviation across folds.
mean_precision = np.mean(deep_sad_precisions_interp, axis=0)
std_precision = np.std(deep_sad_precisions_interp, axis=0)

# -- IsoForest
isoforest_precisions_interp = [
    _precision_on_grid(precision, recall, mean_recall)
    for precision, recall in isoforest_prc_data
]

isoforest_mean_precision = np.mean(isoforest_precisions_interp, axis=0)
isoforest_std_precision = np.std(isoforest_precisions_interp, axis=0)
|
||||
|
||||
# 4) CREATE PLOT USING PrecisionRecallDisplay
fig, ax = plt.subplots(figsize=(8, 6))

# Transparency shared by the per-fold curves and by the ±1 std bands.
fold_alpha = 0.3
band_alpha = 0.2

# (A) Per-fold DeepSAD curves (optional). Only the first fold gets a label
# so the legend does not fill up with one entry per fold.
for i, (precision, recall) in enumerate(prc_data):
    fold_label = f"DeepSAD Fold {i+1}" if i == 0 else None
    PrecisionRecallDisplay(precision=precision, recall=recall).plot(
        ax=ax, color="b", alpha=fold_alpha, label=fold_label
    )

# (B) Per-fold IsoForest curves (optional), labelled the same way.
for i, (precision, recall) in enumerate(isoforest_prc_data):
    fold_label = f"IsoForest Fold {i+1}" if i == 0 else None
    PrecisionRecallDisplay(precision=precision, recall=recall).plot(
        ax=ax, color="r", alpha=fold_alpha, label=fold_label
    )

# (C) DeepSAD mean curve, followed by its ±1 standard deviation band.
deepsad_label = f"DeepSAD Mean PR (AP={mean_ap:.2f} ± {ap_ci:.2f})"
mean_disp_deepsad = PrecisionRecallDisplay(precision=mean_precision, recall=mean_recall)
mean_disp_deepsad.plot(ax=ax, color="b", label=deepsad_label)
deepsad_lower = mean_precision - std_precision
deepsad_upper = mean_precision + std_precision
ax.fill_between(mean_recall, deepsad_lower, deepsad_upper, color="b", alpha=band_alpha)

# (D) IsoForest mean curve, followed by its ±1 standard deviation band.
isoforest_label = (
    f"IsoForest Mean PR (AP={isoforest_mean_ap:.2f} ± {isoforest_ap_ci:.2f})"
)
mean_disp_isoforest = PrecisionRecallDisplay(
    precision=isoforest_mean_precision, recall=mean_recall
)
mean_disp_isoforest.plot(ax=ax, color="r", label=isoforest_label)
isoforest_lower = isoforest_mean_precision - isoforest_std_precision
isoforest_upper = isoforest_mean_precision + isoforest_std_precision
ax.fill_between(
    mean_recall, isoforest_lower, isoforest_upper, color="r", alpha=band_alpha
)

# 5) FINAL PLOT ADJUSTMENTS
# Set after all curves are drawn so these labels and the legend position
# are the ones that end up in the saved figure.
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve with 5-Fold Cross-Validation")
ax.legend(loc="upper right")

plt.savefig("pr_curve_800_3000_2.png")
|
||||
Reference in New Issue
Block a user