128 lines
4.3 KiB
Python
128 lines
4.3 KiB
Python
|
|
import pickle
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import numpy as np
|
||
|
|
from scipy.stats import sem, t
|
||
|
|
from sklearn.metrics import PrecisionRecallDisplay, auc
|
||
|
|
|
||
|
|
|
||
|
|
def confidence_interval(data, confidence=0.95):
    """Compute the mean and the confidence-interval half-width of ``data``.

    The half-width ``h`` is the standard error of the mean multiplied by the
    two-sided Student-t critical value with ``n - 1`` degrees of freedom, so
    the interval is ``mean ± h``.

    Args:
        data: Sequence or 1-D array of numeric scores (e.g. one per fold).
        confidence: Two-sided confidence level, strictly between 0 and 1.

    Returns:
        Tuple ``(mean, h)``. ``h`` is NaN when ``data`` holds a single value,
        because the spread of one sample is undefined.

    Raises:
        ValueError: If ``data`` is empty or ``confidence`` is not in (0, 1).
    """
    if not 0.0 < confidence < 1.0:
        # confidence == 1 would yield an infinite radius, > 1 a NaN one.
        raise ValueError(f"confidence must be in (0, 1), got {confidence!r}")

    values = np.asarray(data)
    n = len(values)
    if n == 0:
        raise ValueError("confidence_interval() requires at least one value")

    mean = np.mean(values)
    if n == 1:
        # Standard error and the t quantile (0 degrees of freedom) are
        # undefined for one sample; report NaN without scipy's warnings.
        return mean, np.nan

    # Standard error of the mean (scipy.stats.sem).
    std_err = sem(values)
    # Two-sided critical value, e.g. the 97.5th percentile for 95 % confidence.
    h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, h
|
||
|
|
|
||
|
|
|
||
|
|
# 1) LOAD PRECISION-RECALL DATA
prc_data = []  # Stores (precision, recall) for each DeepSAD fold
ap_scores = []  # Area under the PR curve for each DeepSAD fold

isoforest_prc_data = []  # Stores (precision, recall) for each IsoForest fold
isoforest_ap_scores = []  # Area under the PR curve for each IsoForest fold

results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_800_3000_new"
)

# Number of cross-validation folds. For each fold index i, the files
# results_{i}.pkl (DeepSAD) and results_isoforest_{i}.pkl are read.
N_FOLDS = 5


def _load_prc(pkl_path):
    """Load a pickled results dict and return its (precision, recall) arrays.

    The dict's ``"test_prc"`` entry is unpacked as
    (precision, recall, thresholds); the thresholds are discarded.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file -- only point this at result files produced by your own runs.
    """
    with pkl_path.open("rb") as f:
        data = pickle.load(f)
    precision, recall, _ = data["test_prc"]
    return precision, recall


for i in range(N_FOLDS):
    precision, recall = _load_prc(results_path / f"results_{i}.pkl")
    prc_data.append((precision, recall))
    # Trapezoidal area under the (recall, precision) curve. This is close to,
    # but not identical with, sklearn's step-wise average_precision_score.
    ap_scores.append(auc(recall, precision))

    precision, recall = _load_prc(results_path / f"results_isoforest_{i}.pkl")
    isoforest_prc_data.append((precision, recall))
    isoforest_ap_scores.append(auc(recall, precision))
|
||
|
|
|
||
|
|
# 2) CALCULATE PER-FOLD STATISTICS
# For each model (DeepSAD first, then IsoForest), summarise the per-fold
# scores as a mean plus the half-width of its confidence interval.
(mean_ap, ap_ci), (isoforest_mean_ap, isoforest_ap_ci) = map(
    confidence_interval, (ap_scores, isoforest_ap_scores)
)
|
||
|
|
|
||
|
|
# 3) INTERPOLATE EACH FOLD'S PRC ON A COMMON RECALL GRID FOR MEAN CURVE
mean_recall = np.linspace(0, 1, 100)


def _interpolate_precisions(prc_pairs, recall_grid):
    """Resample every (precision, recall) curve onto ``recall_grid``.

    Precision is interpolated as a function of recall. ``np.interp`` requires
    its x-coordinates (recall) to be increasing, so a curve stored in
    descending-recall order is reversed first. NOTE(review): the stored
    curves presumably come from sklearn's precision_recall_curve, which
    returns recall in decreasing order -- both orders are handled here.

    Returns:
        List with one interpolated precision array per input curve.
    """
    interpolated = []
    for precision, recall in prc_pairs:
        precision = np.asarray(precision)
        recall = np.asarray(recall)
        if recall.size and recall[0] > recall[-1]:
            precision, recall = precision[::-1], recall[::-1]
        interpolated.append(np.interp(recall_grid, recall, precision))
    return interpolated


# -- DeepSAD
deep_sad_precisions_interp = _interpolate_precisions(prc_data, mean_recall)
mean_precision = np.mean(deep_sad_precisions_interp, axis=0)
# np.std defaults to ddof=0, i.e. the population std across folds.
std_precision = np.std(deep_sad_precisions_interp, axis=0)

# -- IsoForest
isoforest_precisions_interp = _interpolate_precisions(isoforest_prc_data, mean_recall)
isoforest_mean_precision = np.mean(isoforest_precisions_interp, axis=0)
isoforest_std_precision = np.std(isoforest_precisions_interp, axis=0)
|
||
|
|
|
||
|
|
# 4) CREATE PLOT USING PrecisionRecallDisplay


def _plot_folds(ax, prc_pairs, color, name):
    """Draw each fold's raw PR curve faintly in ``color`` (optional detail).

    Only the first fold gets a legend label, so the legend shows one entry
    per model instead of one per fold.
    """
    for i, (precision, recall) in enumerate(prc_pairs):
        disp = PrecisionRecallDisplay(precision=precision, recall=recall)
        disp.plot(
            ax=ax, color=color, alpha=0.3, label=f"{name} Fold {i+1}" if i == 0 else None
        )


def _plot_mean(ax, recall_grid, mean_prec, std_prec, color, label):
    """Draw a model's mean PR curve plus a shaded mean ± std band.

    The band is clipped to [0, 1] because precision cannot leave that range;
    unclipped, mean ± std can spill outside it.
    """
    disp = PrecisionRecallDisplay(precision=mean_prec, recall=recall_grid)
    disp.plot(ax=ax, color=color, label=label)
    ax.fill_between(
        recall_grid,
        np.clip(mean_prec - std_prec, 0.0, 1.0),
        np.clip(mean_prec + std_prec, 0.0, 1.0),
        color=color,
        alpha=0.2,
    )


fig, ax = plt.subplots(figsize=(8, 6))

# (A) / (B) Per-fold curves: DeepSAD first, then IsoForest
_plot_folds(ax, prc_data, "b", "DeepSAD")
_plot_folds(ax, isoforest_prc_data, "r", "IsoForest")

# (C) / (D) Mean curves. The "±" in the labels is the AP confidence-interval
# half-width; the shaded band uses the per-recall precision std instead.
_plot_mean(
    ax,
    mean_recall,
    mean_precision,
    std_precision,
    "b",
    f"DeepSAD Mean PR (AP={mean_ap:.2f} ± {ap_ci:.2f})",
)
_plot_mean(
    ax,
    mean_recall,
    isoforest_mean_precision,
    isoforest_std_precision,
    "r",
    f"IsoForest Mean PR (AP={isoforest_mean_ap:.2f} ± {isoforest_ap_ci:.2f})",
)

# 5) FINAL PLOT ADJUSTMENTS
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
# Derive the fold count from the loaded data so the title cannot go stale.
ax.set_title(f"Precision-Recall Curve with {len(prc_data)}-Fold Cross-Validation")
ax.legend(loc="upper right")

fig.savefig("pr_curve_800_3000_2.png")
plt.close(fig)
|