ocsvm working
This commit is contained in:
@@ -85,6 +85,7 @@ class AETrainer(BaseTrainer):
|
||||
logger.info("Starting pretraining...")
|
||||
start_time = time.time()
|
||||
ae_net.train()
|
||||
ae_net.summary(receptive_field=True) # Add network summary before training
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
epoch_loss = 0.0
|
||||
@@ -197,6 +198,8 @@ class AETrainer(BaseTrainer):
|
||||
n_batches = 0
|
||||
start_time = time.time()
|
||||
ae_net.eval()
|
||||
ae_net.summary(receptive_field=True) # Add network summary before testing
|
||||
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user