ae elbow work

This commit is contained in:
Jan Kowalczyk
2025-06-10 13:58:38 +02:00
parent 156b6d2ac1
commit d88719e718
4 changed files with 126 additions and 82 deletions

View File

@@ -89,10 +89,10 @@ class DeepSAD(object):
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
def set_network(self, net_name):
def set_network(self, net_name, rep_dim=1024):
"""Builds the neural network phi."""
self.net_name = net_name
self.net = build_network(net_name)
self.net = build_network(net_name, rep_dim=rep_dim)
def train(
self,
@@ -256,15 +256,42 @@ class DeepSAD(object):
)
self.ae_net = self.ae_trainer.train(dataset, self.ae_net, k_fold_idx=k_fold_idx)
# Get train results
self.ae_results["train_time"] = self.ae_trainer.train_time
# Test
self.ae_trainer.test(dataset, self.ae_net, k_fold_idx=k_fold_idx)
# Get test results
self.ae_results["test_auc"] = self.ae_trainer.test_auc
self.ae_results["test_time"] = self.ae_trainer.test_time
# Get train results
self.ae_results = {
"train": {
"time": self.ae_trainer.train_time,
"indices": self.ae_trainer.train_indices,
"labels_exp_based": self.ae_trainer.train_labels_exp_based,
"labels_manual_based": self.ae_trainer.train_labels_manual_based,
"semi_targets": self.ae_trainer.train_semi_targets,
"file_ids": self.ae_trainer.train_file_ids,
"frame_ids": self.ae_trainer.train_frame_ids,
"scores": self.ae_trainer.train_scores,
"loss": self.ae_trainer.train_loss,
"file_names": {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.ae_trainer.train_file_ids)
},
},
"test": {
"time": self.ae_trainer.test_time,
"indices": self.ae_trainer.test_indices,
"labels_exp_based": self.ae_trainer.test_labels_exp_based,
"labels_manual_based": self.ae_trainer.test_labels_manual_based,
"semi_targets": self.ae_trainer.test_semi_targets,
"file_ids": self.ae_trainer.test_file_ids,
"frame_ids": self.ae_trainer.test_frame_ids,
"scores": self.ae_trainer.test_scores,
"loss": self.ae_trainer.test_loss,
"file_names": {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.ae_trainer.test_file_ids)
},
},
}
# Initialize Deep SAD network weights from pre-trained encoder
self.init_network_weights_from_pretraining()