add torchscan for summary and receptive field (wip)

Author: Jan Kowalczyk
Date: 2025-06-04 09:45:24 +02:00
parent 3a0f35f21d
commit 3538b15073
5 changed files with 189 additions and 10 deletions

View File

@@ -1,6 +1,8 @@
import logging
import torch.nn as nn
import numpy as np
import torch.nn as nn
import torchscan
class BaseNet(nn.Module):
@@ -10,6 +12,7 @@ class BaseNet(nn.Module):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.rep_dim = None  # representation dimensionality, i.e. dim of the code layer or last layer
        self.input_dim = None  # input dimensionality, i.e. dim of the input layer
    def forward(self, *input):
        """
@@ -18,9 +21,17 @@ class BaseNet(nn.Module):
"""
raise NotImplementedError
def summary(self):
def summary(self, receptive_field: bool = False):
"""Network summary."""
net_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in net_parameters])
self.logger.info("Trainable parameters: {}".format(params))
self.logger.info(self)
# net_parameters = filter(lambda p: p.requires_grad, self.parameters())
# params = sum([np.prod(p.size()) for p in net_parameters])
# self.logger.info("Trainable parameters: {}".format(params))
# self.logger.info(self)
if not self.input_dim:
self.logger.warning(
"Input dimension is not set. Please set input_dim before calling summary."
)
return
self.logger.info(
torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
)
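Note on the call above: in the torchscan releases I'm aware of, torchscan.summary prints its table to stdout and returns None, so wrapping it in self.logger.info(...) would log None rather than the table. A minimal, hypothetical sketch of routing the printed output through the logger instead (the log_summary helper name is illustrative and not part of this commit; it assumes net.input_dim is already set):

import io
from contextlib import redirect_stdout

import torchscan

def log_summary(net, logger, receptive_field=False):
    # torchscan.summary prints the layer table to stdout; capture it so it can be logged.
    buf = io.StringIO()
    with redirect_stdout(buf):
        torchscan.summary(net, net.input_dim, receptive_field=receptive_field)
    logger.info(buf.getvalue())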

View File

@@ -6,10 +6,10 @@ from base.base_net import BaseNet
class SubTer_LeNet(BaseNet):
    def __init__(self, rep_dim=1024):
        super().__init__()
        self.input_dim = (1, 32, 2048)  # Input dimension for the network
        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)
@@ -31,7 +31,6 @@ class SubTer_LeNet(BaseNet):
class SubTer_LeNet_Decoder(BaseNet):
    def __init__(self, rep_dim=1024):
        super().__init__()
@@ -56,10 +55,10 @@ class SubTer_LeNet_Decoder(BaseNet):
class SubTer_LeNet_Autoencoder(BaseNet):
    def __init__(self, rep_dim=1024):
        super().__init__()
        self.input_dim = (1, 32, 2048)  # Input dimension for the network
        self.rep_dim = rep_dim
        self.encoder = SubTer_LeNet(rep_dim=rep_dim)
        self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim)
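For context, torchscan.summary expects the per-sample input shape without the batch dimension, which is what (1, 32, 2048) encodes here (channels, height, width). An illustrative call, not part of the commit, assuming SubTer_LeNet is in scope:

net = SubTer_LeNet(rep_dim=1024)
net.summary()                       # layer/parameter overview for a (1, 32, 2048) input
net.summary(receptive_field=True)   # additionally reports per-layer receptive field estimates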

View File

@@ -212,6 +212,7 @@ class DeepSADTrainer(BaseTrainer):
        start_time = time.time()
        idx_label_score = []
        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in test_loader:
                inputs, labels, semi_targets, idx, _ = data
@@ -267,6 +268,7 @@ class DeepSADTrainer(BaseTrainer):
        c = torch.zeros(net.rep_dim, device=self.device)
        net.eval()
        net.summary(receptive_field=True)
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
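A hedged sketch of the pattern the trainer now follows before its loops; the dummy batch below is illustrative only (it just prepends a batch dimension to input_dim) and assumes the autoencoder's forward takes a single tensor:

import torch

net = SubTer_LeNet_Autoencoder(rep_dim=1024)
net.eval()
net.summary(receptive_field=True)          # only warns and returns if input_dim were unset
with torch.no_grad():
    _ = net(torch.randn(2, 1, 32, 2048))   # dummy batch of 2 samples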