added error for incorrect preprocessing
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def create_semisupervised_setting(
|
def create_semisupervised_setting(
|
||||||
@@ -58,6 +59,15 @@ def create_semisupervised_setting(
|
|||||||
n_unlabeled_outlier = int(x[2])
|
n_unlabeled_outlier = int(x[2])
|
||||||
n_known_outlier = int(x[3])
|
n_known_outlier = int(x[3])
|
||||||
|
|
||||||
|
if (
|
||||||
|
sum((n_known_normal, n_unlabeled_normal, n_unlabeled_outlier, n_known_outlier))
|
||||||
|
> labels.shape[0]
|
||||||
|
):
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.error(
|
||||||
|
"Given ratios for the semi-supervised setting are not possible due to data restraints. Please change the ratios or provide more/different data."
|
||||||
|
)
|
||||||
|
|
||||||
# Sample indices
|
# Sample indices
|
||||||
perm_normal = np.random.permutation(n_normal)
|
perm_normal = np.random.permutation(n_normal)
|
||||||
perm_outlier = np.random.permutation(len(idx_outlier))
|
perm_outlier = np.random.permutation(len(idx_outlier))
|
||||||
|
|||||||
Reference in New Issue
Block a user