full upload so not to lose anything important

This commit is contained in:
Jan Kowalczyk
2025-03-14 18:02:23 +01:00
parent 35fcfb7d5a
commit b824ff7482
33 changed files with 3539 additions and 353 deletions

82
tools/evaluate_roc.py Normal file
View File

@@ -0,0 +1,82 @@
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem, t
from sklearn.metrics import auc
# Confidence interval function
def confidence_interval(data, confidence=0.95):
n = len(data)
mean = np.mean(data)
std_err = sem(data)
h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
return mean, h
# Load ROC and AUC values from pickle files
roc_data = []
auc_scores = []
isoforest_roc_data = []
isoforest_auc_scores = []
results_path = Path(
"/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_0_0"
)
for i in range(5):
with (results_path / f"results_{i}.pkl").open("rb") as f:
data = pickle.load(f)
roc_data.append(data["test_roc"])
auc_scores.append(data["test_auc"])
with (results_path / f"results.isoforest_{i}.pkl").open("rb") as f:
data = pickle.load(f)
isoforest_roc_data.append(data["test_roc"])
isoforest_auc_scores.append(data["test_auc"])
# Calculate mean and confidence interval for AUC scores
mean_auc, auc_ci = confidence_interval(auc_scores)
# Combine ROC curves
mean_fpr = np.linspace(0, 1, 100)
tprs = []
for fpr, tpr, _ in roc_data:
interp_tpr = np.interp(mean_fpr, fpr, tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
std_tpr = np.std(tprs, axis=0)
# Plot ROC curves with confidence margins
plt.figure()
plt.plot(
mean_fpr,
mean_tpr,
color="b",
label=f"Mean ROC (AUC = {mean_auc:.2f} ± {auc_ci:.2f})",
)
plt.fill_between(
mean_fpr,
mean_tpr - std_tpr,
mean_tpr + std_tpr,
color="b",
alpha=0.2,
label="± 1 std. dev.",
)
# Plot each fold's ROC curve (optional)
for i, (fpr, tpr, _) in enumerate(roc_data):
plt.plot(fpr, tpr, lw=1, alpha=0.3, label=f"Fold {i + 1} ROC")
# Labels and legend
plt.plot([0, 1], [0, 1], "k--", label="Chance")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve with 5-Fold Cross-Validation")
plt.legend(loc="lower right")
plt.savefig("roc_curve_0_0.png")