This commit is contained in:
Jan Kowalczyk
2025-09-09 14:15:16 +02:00
parent ed80faf1e2
commit 86d9d96ca4
12 changed files with 725 additions and 14 deletions

View File

@@ -366,7 +366,9 @@ class DeepSADTrainer(BaseTrainer):
scores_exp_valid = scores_exp[valid_mask_exp]
self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
self.test_roc_exp_based = roc_curve(labels_exp_binary, scores_exp_valid)
self.test_roc_exp_based = roc_curve(
labels_exp_binary, scores_exp_valid, drop_intermediate=False
)
self.test_prc_exp_based = precision_recall_curve(
labels_exp_binary, scores_exp_valid
)
@@ -403,7 +405,7 @@ class DeepSADTrainer(BaseTrainer):
labels_manual_binary, scores_manual_valid
)
self.test_roc_manual_based = roc_curve(
labels_manual_binary, scores_manual_valid
labels_manual_binary, scores_manual_valid, drop_intermediate=False
)
self.test_prc_manual_based = precision_recall_curve(
labels_manual_binary, scores_manual_valid