# File: mt/tools/evaluate_prc_2.py
# Last modified: 2025-03-14 18:02:23 +01:00
# Size: 136 lines, 3.9 KiB (Python)

import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem, t
from sklearn.metrics import auc
# Confidence interval function
def confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t interval.

    The interval is ``mean ± h`` where ``h`` is the standard error of the
    mean scaled by the two-sided Student-t quantile with ``n - 1`` degrees
    of freedom.

    Args:
        data: Sequence of numeric samples (e.g. one score per fold).
        confidence: Two-sided confidence level, strictly between 0 and 1.

    Returns:
        A ``(mean, h)`` tuple. ``h`` is NaN when only one sample is given,
        because the standard error is undefined for a single observation.

    Raises:
        ValueError: If ``data`` is empty or ``confidence`` is not in (0, 1).
    """
    if not 0.0 < confidence < 1.0:
        raise ValueError(
            f"confidence must be strictly between 0 and 1, got {confidence!r}"
        )
    n = len(data)
    if n == 0:
        raise ValueError("data must contain at least one sample")
    mean = np.mean(data)
    if n < 2:
        # sem() uses ddof=1 and would divide by zero here; report the
        # undefined half-width explicitly instead of via a runtime warning.
        return mean, float("nan")
    std_err = sem(data)
    h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, h
# Load per-fold precision-recall (PR) curves and compute average precision (AP)
NUM_FOLDS = 5
results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_3000_800_2"
)


def _average_precision(precision, recall):
    """Return the non-interpolated average precision of one PR curve.

    AP is the step-wise sum ``sum_n (R_n - R_{n-1}) * P_n``. The trapezoidal
    area under the curve (``sklearn.metrics.auc``) interpolates linearly
    between operating points and can be too optimistic, so it is not used
    for a value that is reported as "AP".

    Assumes the arrays follow the ordering produced by scikit-learn's
    ``precision_recall_curve`` (recall decreasing towards 0) -- TODO confirm
    against the code that writes ``test_prc``. Curves stored with increasing
    recall are reversed first so the same formula applies.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    if recall.size > 1 and recall[0] < recall[-1]:
        precision, recall = precision[::-1], recall[::-1]
    return float(-np.sum(np.diff(recall) * precision[:-1]))


def _load_fold_curves(results_dir, file_stem, num_folds):
    """Load ``<file_stem>_<fold>.pkl`` for every fold.

    Returns:
        ``(curves, scores)``: a list of ``(precision, recall)`` tuples and
        the matching list of per-fold AP values.
    """
    curves = []
    scores = []
    for fold in range(num_folds):
        # NOTE: pickle.load executes arbitrary code embedded in the file;
        # only point this script at result files you produced yourself.
        with (results_dir / f"{file_stem}_{fold}.pkl").open("rb") as f:
            fold_results = pickle.load(f)
        # fold_results["test_prc"] should be (precision, recall, thresholds)
        precision, recall, _ = fold_results["test_prc"]
        curves.append((precision, recall))
        scores.append(_average_precision(precision, recall))
    return curves, scores


prc_data, ap_scores = _load_fold_curves(results_path, "results", NUM_FOLDS)
isoforest_prc_data, isoforest_ap_scores = _load_fold_curves(
    results_path, "results_isoforest", NUM_FOLDS
)
# Summarise the DeepSAD AP scores: mean and confidence-interval half-width
mean_ap, ap_ci = confidence_interval(ap_scores)
# Common recall grid on which every fold's precision is resampled
mean_recall = np.linspace(0, 1, 100)
# np.interp expects ascending x values, so each curve is reversed first
# (assumes recall is stored in descending order -- TODO confirm).
precisions = [
    np.interp(mean_recall, np.flip(fold_recall), np.flip(fold_precision))
    for fold_precision, fold_recall in prc_data
]
# Point-wise mean and standard deviation across folds for DeepSAD
mean_precision = np.mean(precisions, axis=0)
std_precision = np.std(precisions, axis=0)
# Same summary for the IsoForest AP scores
isoforest_mean_ap, isoforest_ap_ci = confidence_interval(isoforest_ap_scores)
# Resample the IsoForest curves onto the same recall grid
isoforest_precisions = [
    np.interp(mean_recall, np.flip(fold_recall), np.flip(fold_precision))
    for fold_precision, fold_recall in isoforest_prc_data
]
# Point-wise mean and standard deviation across folds for IsoForest
isoforest_mean_precision = np.mean(isoforest_precisions, axis=0)
isoforest_std_precision = np.std(isoforest_precisions, axis=0)
# Draw the mean PR curves with ±1 standard-deviation bands, then the raw folds
fig, ax = plt.subplots(figsize=(8, 6))

# DeepSAD: mean curve and spread across folds
ax.plot(
    mean_recall,
    mean_precision,
    color="b",
    label=f"DeepSAD Mean PR (AP = {mean_ap:.2f} ± {ap_ci:.2f})",
)
deepsad_lower = mean_precision - std_precision
deepsad_upper = mean_precision + std_precision
ax.fill_between(
    mean_recall,
    deepsad_lower,
    deepsad_upper,
    color="b",
    alpha=0.2,
    label="DeepSAD ± 1 std. dev.",
)

# IsoForest: mean curve and spread across folds
ax.plot(
    mean_recall,
    isoforest_mean_precision,
    color="r",
    label=f"IsoForest Mean PR (AP = {isoforest_mean_ap:.2f} ± {isoforest_ap_ci:.2f})",
)
isoforest_lower = isoforest_mean_precision - isoforest_std_precision
isoforest_upper = isoforest_mean_precision + isoforest_std_precision
ax.fill_between(
    mean_recall,
    isoforest_lower,
    isoforest_upper,
    color="r",
    alpha=0.2,
    label="IsoForest ± 1 std. dev.",
)

# Faint per-fold curves; only the first fold of each model gets a legend
# entry (an empty label keeps the remaining folds out of the legend).
fold_style = {"lw": 1, "alpha": 0.3}
for fold, (fold_precision, fold_recall) in enumerate(prc_data):
    fold_label = f"DeepSAD Fold {fold + 1} PR" if fold == 0 else ""
    ax.plot(fold_recall, fold_precision, color="b", label=fold_label, **fold_style)
for fold, (fold_precision, fold_recall) in enumerate(isoforest_prc_data):
    fold_label = f"IsoForest Fold {fold + 1} PR" if fold == 0 else ""
    ax.plot(fold_recall, fold_precision, color="r", label=fold_label, **fold_style)

# Axis labelling, legend and output file
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve with 5-Fold Cross-Validation")
ax.legend(loc="upper right")
fig.savefig("pr_curve_800_3000_4.png")