diff --git a/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py b/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py
index 599fdc1..57d273d 100644
--- a/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py
+++ b/Deep-SAD-PyTorch/src/optim/DeepSAD_trainer.py
@@ -83,13 +83,6 @@ class DeepSADTrainer(BaseTrainer):
         net.train()
         for epoch in range(self.n_epochs):
 
-            scheduler.step()
-            if epoch in self.lr_milestones:
-                logger.info(
-                    "  LR scheduler: new learning rate is %g"
-                    % float(scheduler.get_lr()[0])
-                )
-
             epoch_loss = 0.0
             n_batches = 0
             epoch_start_time = time.time()
@@ -117,6 +110,13 @@ class DeepSADTrainer(BaseTrainer):
                 epoch_loss += loss.item()
                 n_batches += 1
 
+            scheduler.step()
+            if epoch in self.lr_milestones:
+                logger.info(
+                    "  LR scheduler: new learning rate is %g"
+                    % float(scheduler.get_last_lr()[0])
+                )
+
             # log epoch statistics
             epoch_train_time = time.time() - epoch_start_time
             logger.info(
diff --git a/Deep-SAD-PyTorch/src/optim/ae_trainer.py b/Deep-SAD-PyTorch/src/optim/ae_trainer.py
index e612e25..e8a8d30 100644
--- a/Deep-SAD-PyTorch/src/optim/ae_trainer.py
+++ b/Deep-SAD-PyTorch/src/optim/ae_trainer.py
@@ -71,13 +71,6 @@ class AETrainer(BaseTrainer):
         ae_net.train()
         for epoch in range(self.n_epochs):
 
-            scheduler.step()
-            if epoch in self.lr_milestones:
-                logger.info(
-                    "  LR scheduler: new learning rate is %g"
-                    % float(scheduler.get_lr()[0])
-                )
-
             epoch_loss = 0.0
             n_batches = 0
             epoch_start_time = time.time()
@@ -98,6 +91,13 @@ class AETrainer(BaseTrainer):
                 epoch_loss += loss.item()
                 n_batches += 1
 
+            scheduler.step()
+            if epoch in self.lr_milestones:
+                logger.info(
+                    "  LR scheduler: new learning rate is %g"
+                    % float(scheduler.get_last_lr()[0])
+                )
+
             # log epoch statistics
             epoch_train_time = time.time() - epoch_start_time
             logger.info(
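
Note on the change: PyTorch >= 1.1.0 expects scheduler.step() to run after the epoch's optimizer.step() calls; stepping the scheduler first, as the old code did, skips the first value of the schedule and raises a UserWarning. The added logging also switches to get_last_lr(), since get_lr() is deprecated for external use and can report an incorrect value when called outside of step(). Below is a minimal sketch of the per-epoch ordering this patch adopts, assuming a MultiStepLR schedule like the one these trainers configure; the model, data, and milestone values are placeholders:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50, 100], gamma=0.1
    )

    for epoch in range(150):
        for _ in range(10):  # stand-in for iterating over train_loader
            optimizer.zero_grad()
            loss = model(torch.randn(8, 4)).pow(2).mean()
            loss.backward()
            optimizer.step()  # update parameters for this batch
        scheduler.step()  # advance the LR schedule once per epoch, after all optimizer steps
        if epoch in {50, 100}:
            # get_last_lr() reports the learning rate in effect after the step
            print("new learning rate:", scheduler.get_last_lr()[0])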