ocsvm working
This commit is contained in:
@@ -9,7 +9,7 @@ from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencod
|
||||
from .vae import VariationalAutoencoder
|
||||
|
||||
|
||||
def build_network(net_name, ae_net=None, rep_dim=1024):
|
||||
def build_network(net_name, rep_dim, ae_net=None):
|
||||
"""Builds the neural network."""
|
||||
|
||||
implemented_networks = (
|
||||
@@ -129,7 +129,7 @@ def build_network(net_name, ae_net=None, rep_dim=1024):
|
||||
return net
|
||||
|
||||
|
||||
def build_autoencoder(net_name):
|
||||
def build_autoencoder(net_name, rep_dim):
|
||||
"""Builds the corresponding autoencoder network."""
|
||||
|
||||
implemented_networks = (
|
||||
@@ -158,7 +158,7 @@ def build_autoencoder(net_name):
|
||||
ae_net = MNIST_LeNet_Autoencoder()
|
||||
|
||||
if net_name == "subter_LeNet":
|
||||
ae_net = SubTer_LeNet_Autoencoder()
|
||||
ae_net = SubTer_LeNet_Autoencoder(rep_dim=rep_dim)
|
||||
|
||||
if net_name == "subter_LeNet_Split":
|
||||
ae_net = SubTer_LeNet_Split_Autoencoder()
|
||||
|
||||
Reference in New Issue
Block a user