add torchscan for summary and receptive field (wip)

This commit is contained in:
Jan Kowalczyk
2025-06-04 09:45:24 +02:00
parent 3a0f35f21d
commit 3538b15073
5 changed files with 189 additions and 10 deletions

View File

@@ -6,10 +6,10 @@ from base.base_net import BaseNet
class SubTer_LeNet(BaseNet):
def __init__(self, rep_dim=1024):
super().__init__()
self.input_dim = (1, 32, 2048) # Input dimension for the network
self.rep_dim = rep_dim
self.pool = nn.MaxPool2d(2, 2)
@@ -31,7 +31,6 @@ class SubTer_LeNet(BaseNet):
class SubTer_LeNet_Decoder(BaseNet):
def __init__(self, rep_dim=1024):
super().__init__()
@@ -56,10 +55,10 @@ class SubTer_LeNet_Decoder(BaseNet):
class SubTer_LeNet_Autoencoder(BaseNet):
def __init__(self, rep_dim=1024):
super().__init__()
self.input_dim = (1, 32, 2048) # Input dimension for the network
self.rep_dim = rep_dim
self.encoder = SubTer_LeNet(rep_dim=rep_dim)
self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim)