black formatted files before changes
This commit is contained in:
@@ -2,10 +2,17 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_semisupervised_setting(labels, normal_classes, outlier_classes, known_outlier_classes,
|
||||
ratio_known_normal, ratio_known_outlier, ratio_pollution):
|
||||
def create_semisupervised_setting(
|
||||
labels,
|
||||
normal_classes,
|
||||
outlier_classes,
|
||||
known_outlier_classes,
|
||||
ratio_known_normal,
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
):
|
||||
"""
|
||||
Create a semi-supervised data setting.
|
||||
Create a semi-supervised data setting.
|
||||
:param labels: np.array with labels of all dataset samples
|
||||
:param normal_classes: tuple with normal class labels
|
||||
:param outlier_classes: tuple with anomaly class labels
|
||||
@@ -17,15 +24,31 @@ def create_semisupervised_setting(labels, normal_classes, outlier_classes, known
|
||||
"""
|
||||
idx_normal = np.argwhere(np.isin(labels, normal_classes)).flatten()
|
||||
idx_outlier = np.argwhere(np.isin(labels, outlier_classes)).flatten()
|
||||
idx_known_outlier_candidates = np.argwhere(np.isin(labels, known_outlier_classes)).flatten()
|
||||
idx_known_outlier_candidates = np.argwhere(
|
||||
np.isin(labels, known_outlier_classes)
|
||||
).flatten()
|
||||
|
||||
n_normal = len(idx_normal)
|
||||
|
||||
# Solve system of linear equations to obtain respective number of samples
|
||||
a = np.array([[1, 1, 0, 0],
|
||||
[(1-ratio_known_normal), -ratio_known_normal, -ratio_known_normal, -ratio_known_normal],
|
||||
[-ratio_known_outlier, -ratio_known_outlier, -ratio_known_outlier, (1-ratio_known_outlier)],
|
||||
[0, -ratio_pollution, (1-ratio_pollution), 0]])
|
||||
a = np.array(
|
||||
[
|
||||
[1, 1, 0, 0],
|
||||
[
|
||||
(1 - ratio_known_normal),
|
||||
-ratio_known_normal,
|
||||
-ratio_known_normal,
|
||||
-ratio_known_normal,
|
||||
],
|
||||
[
|
||||
-ratio_known_outlier,
|
||||
-ratio_known_outlier,
|
||||
-ratio_known_outlier,
|
||||
(1 - ratio_known_outlier),
|
||||
],
|
||||
[0, -ratio_pollution, (1 - ratio_pollution), 0],
|
||||
]
|
||||
)
|
||||
b = np.array([n_normal, 0, 0, 0])
|
||||
x = np.linalg.solve(a, b)
|
||||
|
||||
@@ -41,9 +64,13 @@ def create_semisupervised_setting(labels, normal_classes, outlier_classes, known
|
||||
perm_known_outlier = np.random.permutation(len(idx_known_outlier_candidates))
|
||||
|
||||
idx_known_normal = idx_normal[perm_normal[:n_known_normal]].tolist()
|
||||
idx_unlabeled_normal = idx_normal[perm_normal[n_known_normal:n_known_normal+n_unlabeled_normal]].tolist()
|
||||
idx_unlabeled_normal = idx_normal[
|
||||
perm_normal[n_known_normal : n_known_normal + n_unlabeled_normal]
|
||||
].tolist()
|
||||
idx_unlabeled_outlier = idx_outlier[perm_outlier[:n_unlabeled_outlier]].tolist()
|
||||
idx_known_outlier = idx_known_outlier_candidates[perm_known_outlier[:n_known_outlier]].tolist()
|
||||
idx_known_outlier = idx_known_outlier_candidates[
|
||||
perm_known_outlier[:n_known_outlier]
|
||||
].tolist()
|
||||
|
||||
# Get original class labels
|
||||
labels_known_normal = labels[idx_known_normal].tolist()
|
||||
@@ -53,14 +80,32 @@ def create_semisupervised_setting(labels, normal_classes, outlier_classes, known
|
||||
|
||||
# Get semi-supervised setting labels
|
||||
semi_labels_known_normal = np.ones(n_known_normal).astype(np.int32).tolist()
|
||||
semi_labels_unlabeled_normal = np.zeros(n_unlabeled_normal).astype(np.int32).tolist()
|
||||
semi_labels_unlabeled_outlier = np.zeros(n_unlabeled_outlier).astype(np.int32).tolist()
|
||||
semi_labels_unlabeled_normal = (
|
||||
np.zeros(n_unlabeled_normal).astype(np.int32).tolist()
|
||||
)
|
||||
semi_labels_unlabeled_outlier = (
|
||||
np.zeros(n_unlabeled_outlier).astype(np.int32).tolist()
|
||||
)
|
||||
semi_labels_known_outlier = (-np.ones(n_known_outlier).astype(np.int32)).tolist()
|
||||
|
||||
# Create final lists
|
||||
list_idx = idx_known_normal + idx_unlabeled_normal + idx_unlabeled_outlier + idx_known_outlier
|
||||
list_labels = labels_known_normal + labels_unlabeled_normal + labels_unlabeled_outlier + labels_known_outlier
|
||||
list_semi_labels = (semi_labels_known_normal + semi_labels_unlabeled_normal + semi_labels_unlabeled_outlier
|
||||
+ semi_labels_known_outlier)
|
||||
list_idx = (
|
||||
idx_known_normal
|
||||
+ idx_unlabeled_normal
|
||||
+ idx_unlabeled_outlier
|
||||
+ idx_known_outlier
|
||||
)
|
||||
list_labels = (
|
||||
labels_known_normal
|
||||
+ labels_unlabeled_normal
|
||||
+ labels_unlabeled_outlier
|
||||
+ labels_known_outlier
|
||||
)
|
||||
list_semi_labels = (
|
||||
semi_labels_known_normal
|
||||
+ semi_labels_unlabeled_normal
|
||||
+ semi_labels_unlabeled_outlier
|
||||
+ semi_labels_known_outlier
|
||||
)
|
||||
|
||||
return list_idx, list_labels, list_semi_labels
|
||||
|
||||
Reference in New Issue
Block a user