ocsvm working

This commit is contained in:
Jan Kowalczyk
2025-06-13 10:24:54 +02:00
parent d88719e718
commit 9298dea329
6 changed files with 376 additions and 137 deletions

View File

@@ -27,9 +27,11 @@ class DeepSAD(object):
ae_results: A dictionary to save the autoencoder results.
"""
def __init__(self, eta: float = 1.0):
def __init__(self, rep_dim: int, eta: float = 1.0):
"""Inits DeepSAD with hyperparameter eta."""
self.rep_dim = rep_dim # representation dimension
self.eta = eta
self.c = None # hypersphere center c
@@ -89,10 +91,10 @@ class DeepSAD(object):
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
def set_network(self, net_name, rep_dim=1024):
def set_network(self, net_name):
"""Builds the neural network phi."""
self.net_name = net_name
self.net = build_network(net_name, rep_dim=rep_dim)
self.net = build_network(net_name, self.rep_dim)
def train(
self,
@@ -240,7 +242,7 @@ class DeepSAD(object):
"""Pretrains the weights for the Deep SAD network phi via autoencoder."""
# Set autoencoder network
self.ae_net = build_autoencoder(self.net_name)
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
# Train
self.ae_optimizer_name = optimizer_name
@@ -340,7 +342,7 @@ class DeepSAD(object):
# json.dump(self.results, fp)
pickle.dump(self.results, fp)
def save_ae_results(self, export_json):
def save_ae_results(self, export_pkl):
"""Save autoencoder results dict to a JSON-file."""
with open(export_json, "w") as fp:
json.dump(self.ae_results, fp)
with open(export_pkl, "wb") as fp:
pickle.dump(self.ae_results, fp)