implemented inference
This commit is contained in:
@@ -96,7 +96,9 @@ class IsoForest(object):
|
||||
"""Tests the Isolation Forest model on the test data."""
|
||||
logger = logging.getLogger()
|
||||
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Get data from loader
|
||||
idx_label_score = []
|
||||
|
||||
@@ -108,7 +108,9 @@ class KDE(object):
|
||||
"""Tests the Kernel Density Estimation model on the test data."""
|
||||
logger = logging.getLogger()
|
||||
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Get data from loader
|
||||
idx_label_score = []
|
||||
|
||||
@@ -77,7 +77,9 @@ class OCSVM(object):
|
||||
best_auc = 0.0
|
||||
|
||||
# Sample hold-out set from test set
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
X_test = ()
|
||||
labels = []
|
||||
@@ -163,7 +165,9 @@ class OCSVM(object):
|
||||
"""Tests the OC-SVM model on the test data."""
|
||||
logger = logging.getLogger()
|
||||
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Get data from loader
|
||||
idx_label_score = []
|
||||
|
||||
@@ -91,7 +91,9 @@ class SSAD(object):
|
||||
best_auc = 0.0
|
||||
|
||||
# Sample hold-out set from test set
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
X_test = ()
|
||||
labels = []
|
||||
@@ -190,7 +192,9 @@ class SSAD(object):
|
||||
"""Tests the SSAD model on the test data."""
|
||||
logger = logging.getLogger()
|
||||
|
||||
_, test_loader = dataset.loaders(batch_size=128, num_workers=n_jobs_dataloader)
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=128, num_workers=n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Get data from loader
|
||||
idx_label_score = []
|
||||
|
||||
Reference in New Issue
Block a user