wip
This commit is contained in:
@@ -366,7 +366,9 @@ class DeepSADTrainer(BaseTrainer):
|
||||
scores_exp_valid = scores_exp[valid_mask_exp]
|
||||
|
||||
self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
|
||||
self.test_roc_exp_based = roc_curve(labels_exp_binary, scores_exp_valid)
|
||||
self.test_roc_exp_based = roc_curve(
|
||||
labels_exp_binary, scores_exp_valid, drop_intermediate=False
|
||||
)
|
||||
self.test_prc_exp_based = precision_recall_curve(
|
||||
labels_exp_binary, scores_exp_valid
|
||||
)
|
||||
@@ -403,7 +405,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
labels_manual_binary, scores_manual_valid
|
||||
)
|
||||
self.test_roc_manual_based = roc_curve(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
labels_manual_binary, scores_manual_valid, drop_intermediate=False
|
||||
)
|
||||
self.test_prc_manual_based = precision_recall_curve(
|
||||
labels_manual_binary, scores_manual_valid
|
||||
|
||||
Reference in New Issue
Block a user