add torchscan for summary and receptive field (wip)

This commit is contained in:
Jan Kowalczyk
2025-06-04 09:45:24 +02:00
parent 3a0f35f21d
commit 3538b15073
5 changed files with 189 additions and 10 deletions

View File

@@ -212,6 +212,7 @@ class DeepSADTrainer(BaseTrainer):
start_time = time.time()
idx_label_score = []
net.eval()
net.summary(receptive_field=True)
with torch.no_grad():
for data in test_loader:
inputs, labels, semi_targets, idx, _ = data
@@ -267,6 +268,7 @@ class DeepSADTrainer(BaseTrainer):
c = torch.zeros(net.rep_dim, device=self.device)
net.eval()
net.summary(receptive_field=True)
with torch.no_grad():
for data in train_loader:
# get the inputs of the batch