diff --git a/Deep-SAD-PyTorch/src/baselines/ocsvm.py b/Deep-SAD-PyTorch/src/baselines/ocsvm.py index 29dfafe..fe6ca55 100644 --- a/Deep-SAD-PyTorch/src/baselines/ocsvm.py +++ b/Deep-SAD-PyTorch/src/baselines/ocsvm.py @@ -11,7 +11,7 @@ from sklearn.metrics import ( roc_auc_score, roc_curve, ) -from thundersvm import OneClassSVM +from sklearn.svm import OneClassSVM from base.base_dataset import BaseADDataset from networks.main import build_autoencoder @@ -27,7 +27,7 @@ class OCSVM(object): self.rho = None self.gamma = None - self.model = OneClassSVM(kernel=kernel, nu=nu, verbose=True, max_mem_size=4048) + self.model = OneClassSVM(kernel=kernel, nu=nu) self.hybrid = hybrid self.latent_space_dim = latent_space_dim @@ -166,8 +166,6 @@ class OCSVM(object): kernel=self.kernel, nu=self.nu, gamma=gamma, - verbose=True, - max_mem_size=4048, ) # Train @@ -198,7 +196,7 @@ class OCSVM(object): # If hybrid, also train a model with linear kernel if self.hybrid: self.linear_model = OneClassSVM( - kernel="linear", nu=self.nu, max_mem_size=4048 + kernel="linear", nu=self.nu ) start_time = time.time() self.linear_model.fit(X) @@ -479,14 +477,15 @@ class OCSVM(object): self.ae_net.to(torch.device(device)) self.ae_net.eval() - def save_model(self, export_path: Path): + def save_model(self, export_path): """Save OC-SVM model to export_path.""" - self.model.save_to_file(export_path) + with open(export_path, "wb") as fp: + pickle.dump(self.model, fp) - def load_model(self, import_path: Path): + def load_model(self, import_path): """Load OC-SVM model from import_path.""" - self.model.save_to_file(import_path) - pass + with open(import_path, "rb") as fp: + self.model = pickle.load(fp) def save_results(self, export_pkl): with open(export_pkl, "wb") as fp: diff --git a/Deep-SAD-PyTorch/src/main.py b/Deep-SAD-PyTorch/src/main.py index 72e5046..2f48168 100644 --- a/Deep-SAD-PyTorch/src/main.py +++ b/Deep-SAD-PyTorch/src/main.py @@ -597,7 +597,7 @@ def main( deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar") if train_ocsvm: ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl") - ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin") + ocsvm.save_model(export_path=xp_path + "/model_ocsvm.pkl") if train_isoforest: Isoforest.save_results( export_pkl=xp_path + "/results_isoforest.pkl" @@ -616,7 +616,7 @@ def main( export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl" ) ocsvm.save_model( - export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin" + export_path=xp_path + f"/model_ocsvm_{fold_idx}.pkl" ) if train_isoforest: Isoforest.save_results(