ae elbow work
This commit is contained in:
@@ -89,10 +89,10 @@ class DeepSAD(object):
|
||||
|
||||
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
|
||||
|
||||
def set_network(self, net_name):
|
||||
def set_network(self, net_name, rep_dim=1024):
|
||||
"""Builds the neural network phi."""
|
||||
self.net_name = net_name
|
||||
self.net = build_network(net_name)
|
||||
self.net = build_network(net_name, rep_dim=rep_dim)
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -256,15 +256,42 @@ class DeepSAD(object):
|
||||
)
|
||||
self.ae_net = self.ae_trainer.train(dataset, self.ae_net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get train results
|
||||
self.ae_results["train_time"] = self.ae_trainer.train_time
|
||||
|
||||
# Test
|
||||
self.ae_trainer.test(dataset, self.ae_net, k_fold_idx=k_fold_idx)
|
||||
|
||||
# Get test results
|
||||
self.ae_results["test_auc"] = self.ae_trainer.test_auc
|
||||
self.ae_results["test_time"] = self.ae_trainer.test_time
|
||||
# Get train results
|
||||
self.ae_results = {
|
||||
"train": {
|
||||
"time": self.ae_trainer.train_time,
|
||||
"indices": self.ae_trainer.train_indices,
|
||||
"labels_exp_based": self.ae_trainer.train_labels_exp_based,
|
||||
"labels_manual_based": self.ae_trainer.train_labels_manual_based,
|
||||
"semi_targets": self.ae_trainer.train_semi_targets,
|
||||
"file_ids": self.ae_trainer.train_file_ids,
|
||||
"frame_ids": self.ae_trainer.train_frame_ids,
|
||||
"scores": self.ae_trainer.train_scores,
|
||||
"loss": self.ae_trainer.train_loss,
|
||||
"file_names": {
|
||||
file_id: dataset.get_file_name_from_idx(file_id)
|
||||
for file_id in np.unique(self.ae_trainer.train_file_ids)
|
||||
},
|
||||
},
|
||||
"test": {
|
||||
"time": self.ae_trainer.test_time,
|
||||
"indices": self.ae_trainer.test_indices,
|
||||
"labels_exp_based": self.ae_trainer.test_labels_exp_based,
|
||||
"labels_manual_based": self.ae_trainer.test_labels_manual_based,
|
||||
"semi_targets": self.ae_trainer.test_semi_targets,
|
||||
"file_ids": self.ae_trainer.test_file_ids,
|
||||
"frame_ids": self.ae_trainer.test_frame_ids,
|
||||
"scores": self.ae_trainer.test_scores,
|
||||
"loss": self.ae_trainer.test_loss,
|
||||
"file_names": {
|
||||
file_id: dataset.get_file_name_from_idx(file_id)
|
||||
for file_id in np.unique(self.ae_trainer.test_file_ids)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize Deep SAD network weights from pre-trained encoder
|
||||
self.init_network_weights_from_pretraining()
|
||||
|
||||
Reference in New Issue
Block a user