ocsvm working
This commit is contained in:
@@ -27,9 +27,11 @@ class DeepSAD(object):
|
||||
ae_results: A dictionary to save the autoencoder results.
|
||||
"""
|
||||
|
||||
def __init__(self, eta: float = 1.0):
|
||||
def __init__(self, rep_dim: int, eta: float = 1.0):
|
||||
"""Inits DeepSAD with hyperparameter eta."""
|
||||
|
||||
self.rep_dim = rep_dim # representation dimension
|
||||
|
||||
self.eta = eta
|
||||
self.c = None # hypersphere center c
|
||||
|
||||
@@ -89,10 +91,10 @@ class DeepSAD(object):
|
||||
|
||||
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
|
||||
|
||||
def set_network(self, net_name, rep_dim=1024):
|
||||
def set_network(self, net_name):
|
||||
"""Builds the neural network phi."""
|
||||
self.net_name = net_name
|
||||
self.net = build_network(net_name, rep_dim=rep_dim)
|
||||
self.net = build_network(net_name, self.rep_dim)
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -240,7 +242,7 @@ class DeepSAD(object):
|
||||
"""Pretrains the weights for the Deep SAD network phi via autoencoder."""
|
||||
|
||||
# Set autoencoder network
|
||||
self.ae_net = build_autoencoder(self.net_name)
|
||||
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
|
||||
|
||||
# Train
|
||||
self.ae_optimizer_name = optimizer_name
|
||||
@@ -340,7 +342,7 @@ class DeepSAD(object):
|
||||
# json.dump(self.results, fp)
|
||||
pickle.dump(self.results, fp)
|
||||
|
||||
def save_ae_results(self, export_json):
|
||||
def save_ae_results(self, export_pkl):
|
||||
"""Save autoencoder results dict to a JSON-file."""
|
||||
with open(export_json, "w") as fp:
|
||||
json.dump(self.ae_results, fp)
|
||||
with open(export_pkl, "wb") as fp:
|
||||
pickle.dump(self.ae_results, fp)
|
||||
|
||||
Reference in New Issue
Block a user