2nd subter network arch
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_receptive_field
|
||||
|
||||
from base.base_net import BaseNet
|
||||
|
||||
@@ -29,6 +30,13 @@ class SubTer_LeNet(BaseNet):
|
||||
x = self.fc1(x)
|
||||
return x
|
||||
|
||||
def summary(self, receptive_field: bool = False):
|
||||
# first run super method to log parameters and structure
|
||||
super().summary(receptive_field=receptive_field)
|
||||
self.logger.info("torch_receptive_field:")
|
||||
torch_receptive_field.receptive_field(self, input_size=self.input_dim)
|
||||
# torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))
|
||||
|
||||
|
||||
class SubTer_LeNet_Decoder(BaseNet):
|
||||
def __init__(self, rep_dim=1024):
|
||||
|
||||
Reference in New Issue
Block a user