changed from thundersvm to sklearn ocsvm
This commit is contained in:
@@ -11,7 +11,7 @@ from sklearn.metrics import (
|
|||||||
roc_auc_score,
|
roc_auc_score,
|
||||||
roc_curve,
|
roc_curve,
|
||||||
)
|
)
|
||||||
from thundersvm import OneClassSVM
|
from sklearn.svm import OneClassSVM
|
||||||
|
|
||||||
from base.base_dataset import BaseADDataset
|
from base.base_dataset import BaseADDataset
|
||||||
from networks.main import build_autoencoder
|
from networks.main import build_autoencoder
|
||||||
@@ -27,7 +27,7 @@ class OCSVM(object):
|
|||||||
self.rho = None
|
self.rho = None
|
||||||
self.gamma = None
|
self.gamma = None
|
||||||
|
|
||||||
self.model = OneClassSVM(kernel=kernel, nu=nu, verbose=True, max_mem_size=4048)
|
self.model = OneClassSVM(kernel=kernel, nu=nu)
|
||||||
|
|
||||||
self.hybrid = hybrid
|
self.hybrid = hybrid
|
||||||
self.latent_space_dim = latent_space_dim
|
self.latent_space_dim = latent_space_dim
|
||||||
@@ -166,8 +166,6 @@ class OCSVM(object):
|
|||||||
kernel=self.kernel,
|
kernel=self.kernel,
|
||||||
nu=self.nu,
|
nu=self.nu,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
verbose=True,
|
|
||||||
max_mem_size=4048,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
@@ -198,7 +196,7 @@ class OCSVM(object):
|
|||||||
# If hybrid, also train a model with linear kernel
|
# If hybrid, also train a model with linear kernel
|
||||||
if self.hybrid:
|
if self.hybrid:
|
||||||
self.linear_model = OneClassSVM(
|
self.linear_model = OneClassSVM(
|
||||||
kernel="linear", nu=self.nu, max_mem_size=4048
|
kernel="linear", nu=self.nu
|
||||||
)
|
)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
self.linear_model.fit(X)
|
self.linear_model.fit(X)
|
||||||
@@ -479,14 +477,15 @@ class OCSVM(object):
|
|||||||
self.ae_net.to(torch.device(device))
|
self.ae_net.to(torch.device(device))
|
||||||
self.ae_net.eval()
|
self.ae_net.eval()
|
||||||
|
|
||||||
def save_model(self, export_path: Path):
|
def save_model(self, export_path):
|
||||||
"""Save OC-SVM model to export_path."""
|
"""Save OC-SVM model to export_path."""
|
||||||
self.model.save_to_file(export_path)
|
with open(export_path, "wb") as fp:
|
||||||
|
pickle.dump(self.model, fp)
|
||||||
|
|
||||||
def load_model(self, import_path: Path):
|
def load_model(self, import_path):
|
||||||
"""Load OC-SVM model from import_path."""
|
"""Load OC-SVM model from import_path."""
|
||||||
self.model.save_to_file(import_path)
|
with open(import_path, "rb") as fp:
|
||||||
pass
|
self.model = pickle.load(fp)
|
||||||
|
|
||||||
def save_results(self, export_pkl):
|
def save_results(self, export_pkl):
|
||||||
with open(export_pkl, "wb") as fp:
|
with open(export_pkl, "wb") as fp:
|
||||||
|
|||||||
@@ -597,7 +597,7 @@ def main(
|
|||||||
deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
|
deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
|
||||||
if train_ocsvm:
|
if train_ocsvm:
|
||||||
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
|
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
|
||||||
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin")
|
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.pkl")
|
||||||
if train_isoforest:
|
if train_isoforest:
|
||||||
Isoforest.save_results(
|
Isoforest.save_results(
|
||||||
export_pkl=xp_path + "/results_isoforest.pkl"
|
export_pkl=xp_path + "/results_isoforest.pkl"
|
||||||
@@ -616,7 +616,7 @@ def main(
|
|||||||
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
|
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
|
||||||
)
|
)
|
||||||
ocsvm.save_model(
|
ocsvm.save_model(
|
||||||
export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin"
|
export_path=xp_path + f"/model_ocsvm_{fold_idx}.pkl"
|
||||||
)
|
)
|
||||||
if train_isoforest:
|
if train_isoforest:
|
||||||
Isoforest.save_results(
|
Isoforest.save_results(
|
||||||
|
|||||||
Reference in New Issue
Block a user