Full upload so as not to lose anything important
This commit is contained in:
133
tools/evaluate_roc_all.py
Normal file
133
tools/evaluate_roc_all.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy.stats import sem, t
|
||||
from sklearn.metrics import auc
|
||||
|
||||
|
||||
# Confidence interval function
|
||||
def confidence_interval(data, confidence=0.95):
    """Return the mean of *data* and the half-width of its confidence interval.

    The half-width is the standard error of the mean (``scipy.stats.sem``)
    multiplied by the two-sided Student-t quantile with ``n - 1`` degrees of
    freedom, so the interval is ``mean ± h``.

    Args:
        data: Sequence of numeric samples (``len(data)`` must be defined).
        confidence: Two-sided confidence level; defaults to 0.95.

    Returns:
        A ``(mean, h)`` tuple. ``h`` is NaN when fewer than two samples are
        given, and ``mean`` is NaN as well when *data* is empty.
    """
    n = len(data)
    if n < 2:
        # The standard error and the t quantile are undefined for n < 2.
        # Return NaN explicitly instead of reaching it through SciPy/NumPy
        # runtime warnings.
        nan = float("nan")
        return (np.mean(data) if n else nan), nan
    mean = np.mean(data)
    std_err = sem(data)
    h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, h
|
||||
|
||||
|
||||
# Load ROC and AUC values from pickle files
results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_800_3000_new"
)

# Number of per-fold result files read for each method (indices 0..N_FOLDS-1).
N_FOLDS = 5


def _load_fold_results(results_dir, stem, n_folds):
    """Read ``<stem>_<fold>.pkl`` for every fold and collect its ROC and AUC.

    Args:
        results_dir: Directory (``pathlib.Path``) holding the pickle files.
        stem: File-name prefix placed before ``_<fold>.pkl``.
        n_folds: Number of folds; files for indices ``0 .. n_folds - 1`` are read.

    Returns:
        A ``(roc_curves, auc_values)`` tuple of lists with one entry per fold,
        taken from the ``"test_roc"`` and ``"test_auc"`` keys of each pickle.

    Raises:
        FileNotFoundError: If a fold's file does not exist.
        KeyError: If a pickle lacks ``"test_roc"`` or ``"test_auc"``.
    """
    roc_curves = []
    auc_values = []
    for fold in range(n_folds):
        with (results_dir / f"{stem}_{fold}.pkl").open("rb") as f:
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file; only point this script at result files you produced.
            fold_data = pickle.load(f)
        roc_curves.append(fold_data["test_roc"])
        auc_values.append(fold_data["test_auc"])
    return roc_curves, auc_values


roc_data, auc_scores = _load_fold_results(results_path, "results", N_FOLDS)
isoforest_roc_data, isoforest_auc_scores = _load_fold_results(
    results_path, "results_isoforest", N_FOLDS
)
|
||||
|
||||
# Calculate mean and confidence interval of the AUC scores and combine the
# per-fold ROC curves, once for DeepSAD and once for IsoForest
def _average_roc(roc_curves, fpr_grid):
    """Interpolate every fold's ROC curve onto *fpr_grid* and average them.

    Args:
        roc_curves: Non-empty sequence whose items unpack to
            ``(fpr, tpr, third)``; the third element is ignored.
            Presumably the output of ``sklearn.metrics.roc_curve`` — confirm
            against the code that writes the pickles. ``np.interp`` expects
            each ``fpr`` to be increasing.
        fpr_grid: 1-D array of false-positive rates to evaluate the curves at.

    Returns:
        A ``(tprs, mean_tpr, std_tpr)`` tuple: the list of interpolated
        per-fold TPR arrays, their point-wise mean and their point-wise
        standard deviation.

    Raises:
        ValueError: If *roc_curves* is empty.
    """
    if len(roc_curves) == 0:
        raise ValueError("roc_curves must contain at least one ROC curve")

    fold_tprs = []
    for fpr, tpr, _ in roc_curves:
        interp_tpr = np.interp(fpr_grid, fpr, tpr)
        # Pin the first grid point to TPR = 0 for every fold.
        interp_tpr[0] = 0.0
        fold_tprs.append(interp_tpr)

    mean_curve = np.mean(fold_tprs, axis=0)
    # Pin the last grid point of the averaged curve to TPR = 1.
    mean_curve[-1] = 1.0
    return fold_tprs, mean_curve, np.std(fold_tprs, axis=0)


# DeepSAD
mean_auc, auc_ci = confidence_interval(auc_scores)
mean_fpr = np.linspace(0, 1, 100)
tprs, mean_tpr, std_tpr = _average_roc(roc_data, mean_fpr)

# IsoForest
isoforest_mean_auc, isoforest_auc_ci = confidence_interval(isoforest_auc_scores)
isoforest_mean_fpr = np.linspace(0, 1, 100)
isoforest_tprs, isoforest_mean_tpr, isoforest_std_tpr = _average_roc(
    isoforest_roc_data, isoforest_mean_fpr
)
|
||||
|
||||
# Plot ROC curves with ±1 std. dev. margins for DeepSAD and IsoForest
def _plot_mean_roc(fpr_grid, tpr_mean, tpr_std, auc_mean, auc_halfwidth, name, color):
    """Draw one method's mean ROC curve and its ±1 std. dev. band.

    Draws on the current matplotlib figure. The band is clipped to [0, 1]
    because a true-positive rate cannot leave that range.
    """
    plt.plot(
        fpr_grid,
        tpr_mean,
        color=color,
        label=f"{name} Mean ROC (AUC = {auc_mean:.2f} ± {auc_halfwidth:.2f})",
    )
    plt.fill_between(
        fpr_grid,
        np.maximum(tpr_mean - tpr_std, 0.0),
        np.minimum(tpr_mean + tpr_std, 1.0),
        color=color,
        alpha=0.2,
        label=f"{name} ± 1 std. dev.",
    )


def _plot_fold_rocs(roc_curves, name, color):
    """Draw each fold's raw ROC curve as a thin, faint line.

    Only the first fold is given a label; the remaining folds get an empty one.
    """
    for i, (fpr, tpr, _) in enumerate(roc_curves):
        plt.plot(
            fpr,
            tpr,
            lw=1,
            alpha=0.3,
            color=color,
            label=f"{name} Fold {i+1} ROC" if i == 0 else "",
        )


plt.figure(figsize=(8, 6))

# Mean curves first, then the per-fold curves, DeepSAD before IsoForest
# each time (same draw order as before).
_plot_mean_roc(mean_fpr, mean_tpr, std_tpr, mean_auc, auc_ci, "DeepSAD", "b")
_plot_mean_roc(
    isoforest_mean_fpr,
    isoforest_mean_tpr,
    isoforest_std_tpr,
    isoforest_mean_auc,
    isoforest_auc_ci,
    "IsoForest",
    "r",
)
_plot_fold_rocs(roc_data, "DeepSAD", "b")
_plot_fold_rocs(isoforest_roc_data, "IsoForest", "r")

# Labels and legend
plt.plot([0, 1], [0, 1], "k--", label="Chance")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve with 5-Fold Cross-Validation")
plt.legend(loc="lower right")
# Save before show(): the file is written to the current working directory.
plt.savefig("roc_curve_800_3000_isoforest.png")
plt.show()
|
||||
Reference in New Issue
Block a user