136 lines
3.9 KiB
Python
136 lines
3.9 KiB
Python
import pickle
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from scipy.stats import sem, t
|
|
from sklearn.metrics import auc
|
|
|
|
|
|
# Confidence interval function
|
|
def confidence_interval(data, confidence=0.95):
    """Return the mean of *data* and the half-width of its confidence interval.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student's t critical value with ``len(data) - 1`` degrees of
    freedom, so the interval is ``mean ± half_width``.

    Args:
        data: Sized collection of numeric samples (e.g. one score per fold).
        confidence: Two-sided confidence level; defaults to 0.95.

    Returns:
        Tuple ``(mean, half_width)``.
    """
    n = len(data)
    mean = np.mean(data)
    std_err = sem(data)
    # (1 + confidence) / 2 is the upper quantile of a two-sided interval,
    # e.g. 0.975 for a 95 % interval.
    # NOTE(review): n < 2 leaves zero degrees of freedom; callers are expected
    # to pass at least two samples — confirm.
    h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, h
|
|
|
|
|
|
# Load the per-fold precision-recall curves (PRC) and compute the average
# precision (AP) of every fold, for DeepSAD and for IsoForest.
N_FOLDS = 5  # folds are stored as results_<i>.pkl / results_isoforest_<i>.pkl

prc_data = []
ap_scores = []
isoforest_prc_data = []
isoforest_ap_scores = []

results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_3000_800_2"
)


def _load_fold_prc(pickle_path):
    """Load one fold's PR curve from *pickle_path* and compute its AP.

    Args:
        pickle_path: ``pathlib.Path`` of a pickled results dict that has a
            ``"test_prc"`` entry.

    Returns:
        Tuple ``(precision, recall, ap)`` where ``ap`` is
        ``auc(recall, precision)``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only use this on result files you produced yourself.
    with pickle_path.open("rb") as f:
        data = pickle.load(f)

    # data["test_prc"] should be (precision, recall, thresholds)
    precision, recall, _ = data["test_prc"]

    # AP is taken as the area under the PR curve. sklearn.metrics.auc uses the
    # trapezoidal rule, so the value can differ from sklearn's step-wise
    # average_precision_score.
    return precision, recall, auc(recall, precision)


for i in range(N_FOLDS):
    precision, recall, ap = _load_fold_prc(results_path / f"results_{i}.pkl")
    prc_data.append((precision, recall))
    ap_scores.append(ap)

    precision, recall, ap = _load_fold_prc(
        results_path / f"results_isoforest_{i}.pkl"
    )
    isoforest_prc_data.append((precision, recall))
    isoforest_ap_scores.append(ap)
|
|
|
|
def _interpolate_precisions(curves, recall_grid):
    """Resample every (precision, recall) curve in *curves* onto *recall_grid*.

    Both arrays are reversed before interpolating because ``np.interp``
    expects increasing x-coordinates. This assumes recall is stored in
    decreasing order, as sklearn's ``precision_recall_curve`` returns it —
    TODO confirm for the pickled data.

    Args:
        curves: Iterable of ``(precision, recall)`` pairs, one per fold.
        recall_grid: 1-D array of recall values to evaluate precision at.

    Returns:
        List with one interpolated precision array per fold.
    """
    return [
        np.interp(recall_grid, np.flip(recall), np.flip(precision))
        for precision, recall in curves
    ]


# Common recall grid on which the per-fold curves are averaged.
mean_recall = np.linspace(0, 1, 100)

# DeepSAD: mean AP with its confidence-interval half-width, then mean and
# standard deviation of the interpolated precision across folds.
mean_ap, ap_ci = confidence_interval(ap_scores)
precisions = _interpolate_precisions(prc_data, mean_recall)
mean_precision = np.mean(precisions, axis=0)
std_precision = np.std(precisions, axis=0)

# IsoForest: same statistics on the same recall grid.
isoforest_mean_ap, isoforest_ap_ci = confidence_interval(isoforest_ap_scores)
isoforest_precisions = _interpolate_precisions(isoforest_prc_data, mean_recall)
isoforest_mean_precision = np.mean(isoforest_precisions, axis=0)
isoforest_std_precision = np.std(isoforest_precisions, axis=0)
|
|
|
|
def _plot_mean_curve(
    name, color, recall_grid, precision_mean, precision_std, ap_mean, ap_half_width
):
    """Draw one method's mean PR curve and its ± 1 std. dev. band.

    Args:
        name: Method name used as the prefix of both legend labels.
        color: Matplotlib colour used for the line and the band.
        recall_grid: Recall values (x-axis) shared by mean and band.
        precision_mean: Mean precision at each point of *recall_grid*.
        precision_std: Standard deviation of precision at each grid point.
        ap_mean: Mean AP shown in the legend label.
        ap_half_width: Confidence-interval half-width shown after "±".
    """
    plt.plot(
        recall_grid,
        precision_mean,
        color=color,
        label=f"{name} Mean PR (AP = {ap_mean:.2f} ± {ap_half_width:.2f})",
    )
    plt.fill_between(
        recall_grid,
        precision_mean - precision_std,
        precision_mean + precision_std,
        color=color,
        alpha=0.2,
        label=f"{name} ± 1 std. dev.",
    )


def _plot_fold_curves(name, color, curves):
    """Draw each fold's raw PR curve in *curves* as a thin, faint line."""
    for i, (precision, recall) in enumerate(curves):
        plt.plot(
            recall,
            precision,
            lw=1,
            alpha=0.3,
            color=color,
            # Only the first fold carries a label; the remaining folds get an
            # empty one so the per-fold lines are not each labelled.
            label=f"{name} Fold {i + 1} PR" if i == 0 else "",
        )


# Plot Precision-Recall curves with confidence margins
plt.figure(figsize=(8, 6))

# Mean curves and bands first, then the per-fold curves (same draw order as
# before, so legend order and layering are unchanged).
_plot_mean_curve(
    "DeepSAD", "b", mean_recall, mean_precision, std_precision, mean_ap, ap_ci
)
_plot_mean_curve(
    "IsoForest",
    "r",
    mean_recall,
    isoforest_mean_precision,
    isoforest_std_precision,
    isoforest_mean_ap,
    isoforest_ap_ci,
)

# Optional: plot each fold's curve for both methods
_plot_fold_curves("DeepSAD", "b", prc_data)
_plot_fold_curves("IsoForest", "r", isoforest_prc_data)

plt.xlabel("Recall")
plt.ylabel("Precision")
# The fold count in the title follows the number of loaded DeepSAD folds.
plt.title(f"Precision-Recall Curve with {len(prc_data)}-Fold Cross-Validation")
plt.legend(loc="upper right")
plt.savefig("pr_curve_800_3000_4.png")
|