retest implemented and fixed missing center in save data

This commit is contained in:
Jan Kowalczyk
2025-07-01 17:22:29 +02:00
parent 24c6771576
commit 4863b91127
6 changed files with 189 additions and 1307 deletions

View File

@@ -126,6 +126,8 @@ class DeepSAD(object):
)
# Get the model
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
# Store training results including indices
self.results["train"]["time"] = self.trainer.train_time
self.results["train"]["indices"] = self.trainer.train_indices
@@ -333,7 +335,7 @@ class DeepSAD(object):
# load autoencoder parameters if specified
if load_ae:
if self.ae_net is None:
self.ae_net = build_autoencoder(self.net_name)
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
def save_results(self, export_pkl):