From 61424bf0532701213dc534f6a0ba10fb451e572e Mon Sep 17 00:00:00 2001 From: Jan Kowalczyk Date: Wed, 3 Jul 2024 17:39:32 +0200 Subject: [PATCH] added error for incorrect preprocessing --- Deep-SAD-PyTorch/src/datasets/preprocessing.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/Deep-SAD-PyTorch/src/datasets/preprocessing.py b/Deep-SAD-PyTorch/src/datasets/preprocessing.py index cc98b8c..1329313 100644 --- a/Deep-SAD-PyTorch/src/datasets/preprocessing.py +++ b/Deep-SAD-PyTorch/src/datasets/preprocessing.py @@ -1,5 +1,6 @@ import torch import numpy as np +import logging def create_semisupervised_setting( @@ -58,6 +59,15 @@ def create_semisupervised_setting( n_unlabeled_outlier = int(x[2]) n_known_outlier = int(x[3]) + if ( + sum((n_known_normal, n_unlabeled_normal, n_unlabeled_outlier, n_known_outlier)) + > labels.shape[0] + ): + logger = logging.getLogger() + logger.error( + "Given ratios for the semi-supervised setting are not possible due to data restraints. Please change the ratios or provide more/different data." + ) + # Sample indices perm_normal = np.random.permutation(n_normal) perm_outlier = np.random.permutation(len(idx_outlier))