implemented inference

This commit is contained in:
Jan Kowalczyk
2024-07-04 15:36:01 +02:00
parent 745efbb8f5
commit 5014c41b24
13 changed files with 384 additions and 177 deletions

View File

@@ -96,7 +96,9 @@ class IsoForest(object):
"""Tests the Isolation Forest model on the test data."""
logger = logging.getLogger()
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
# Get data from loader
idx_label_score = []

View File

@@ -108,7 +108,9 @@ class KDE(object):
"""Tests the Kernel Density Estimation model on the test data."""
logger = logging.getLogger()
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
# Get data from loader
idx_label_score = []

View File

@@ -77,7 +77,9 @@ class OCSVM(object):
best_auc = 0.0
# Sample hold-out set from test set
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
X_test = ()
labels = []
@@ -163,7 +165,9 @@ class OCSVM(object):
"""Tests the OC-SVM model on the test data."""
logger = logging.getLogger()
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
# Get data from loader
idx_label_score = []

View File

@@ -91,7 +91,9 @@ class SSAD(object):
best_auc = 0.0
# Sample hold-out set from test set
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
X_test = ()
labels = []
@@ -190,7 +192,9 @@ class SSAD(object):
"""Tests the SSAD model on the test data."""
logger = logging.getLogger()
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
_, test_loader, _ = dataset.loaders(
batch_size=128, num_workers=n_jobs_dataloader
)
# Get data from loader
idx_label_score = []