changed from thundersvm to sklearn ocsvm
This commit is contained in:
@@ -11,7 +11,7 @@ from sklearn.metrics import (
|
||||
roc_auc_score,
|
||||
roc_curve,
|
||||
)
|
||||
from thundersvm import OneClassSVM
|
||||
from sklearn.svm import OneClassSVM
|
||||
|
||||
from base.base_dataset import BaseADDataset
|
||||
from networks.main import build_autoencoder
|
||||
@@ -27,7 +27,7 @@ class OCSVM(object):
|
||||
self.rho = None
|
||||
self.gamma = None
|
||||
|
||||
self.model = OneClassSVM(kernel=kernel, nu=nu, verbose=True, max_mem_size=4048)
|
||||
self.model = OneClassSVM(kernel=kernel, nu=nu)
|
||||
|
||||
self.hybrid = hybrid
|
||||
self.latent_space_dim = latent_space_dim
|
||||
@@ -166,8 +166,6 @@ class OCSVM(object):
|
||||
kernel=self.kernel,
|
||||
nu=self.nu,
|
||||
gamma=gamma,
|
||||
verbose=True,
|
||||
max_mem_size=4048,
|
||||
)
|
||||
|
||||
# Train
|
||||
@@ -198,7 +196,7 @@ class OCSVM(object):
|
||||
# If hybrid, also train a model with linear kernel
|
||||
if self.hybrid:
|
||||
self.linear_model = OneClassSVM(
|
||||
kernel="linear", nu=self.nu, max_mem_size=4048
|
||||
kernel="linear", nu=self.nu
|
||||
)
|
||||
start_time = time.time()
|
||||
self.linear_model.fit(X)
|
||||
@@ -479,14 +477,15 @@ class OCSVM(object):
|
||||
self.ae_net.to(torch.device(device))
|
||||
self.ae_net.eval()
|
||||
|
||||
def save_model(self, export_path: Path):
|
||||
def save_model(self, export_path):
|
||||
"""Save OC-SVM model to export_path."""
|
||||
self.model.save_to_file(export_path)
|
||||
with open(export_path, "wb") as fp:
|
||||
pickle.dump(self.model, fp)
|
||||
|
||||
def load_model(self, import_path: Path):
|
||||
def load_model(self, import_path):
|
||||
"""Load OC-SVM model from import_path."""
|
||||
self.model.save_to_file(import_path)
|
||||
pass
|
||||
with open(import_path, "rb") as fp:
|
||||
self.model = pickle.load(fp)
|
||||
|
||||
def save_results(self, export_pkl):
|
||||
with open(export_pkl, "wb") as fp:
|
||||
|
||||
@@ -597,7 +597,7 @@ def main(
|
||||
deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
|
||||
if train_ocsvm:
|
||||
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
|
||||
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin")
|
||||
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.pkl")
|
||||
if train_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=xp_path + "/results_isoforest.pkl"
|
||||
@@ -616,7 +616,7 @@ def main(
|
||||
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
|
||||
)
|
||||
ocsvm.save_model(
|
||||
export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin"
|
||||
export_path=xp_path + f"/model_ocsvm_{fold_idx}.pkl"
|
||||
)
|
||||
if train_isoforest:
|
||||
Isoforest.save_results(
|
||||
|
||||
Reference in New Issue
Block a user