Add 2nd subter network architecture

Jan Kowalczyk
2025-06-17 07:26:03 +02:00
parent 9298dea329
commit bbd093da0c
9 changed files with 248 additions and 30 deletions

View File

@@ -1,6 +1,5 @@
 import logging
 
-import numpy as np
 import torch.nn as nn
 import torchscan
@@ -32,8 +31,5 @@ class BaseNet(nn.Module):
                 "Input dimension is not set. Please set input_dim before calling summary."
             )
             return
-        self.logger.info(
-            torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
-        )
-        module_info = torchscan.crawl_module(self, self.input_dim)
-        pass
+        self.logger.info("torchscan:\n")
+        torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
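
Net effect: summary() now logs a short "torchscan:" header and lets torchscan.summary() print its own report; since torchscan.summary() prints its table rather than returning it, the old logger.info wrapper only logged the None return value, and the dead crawl_module/pass lines are dropped. A minimal call-site sketch, assuming a BaseNet subclass with input_dim set:

    net.summary(receptive_field=True)  # logs "torchscan:", then torchscan prints its report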

View File

@@ -55,6 +55,7 @@ from utils.visualization.plot_images_grid import plot_images_grid
"mnist_LeNet",
"elpv_LeNet",
"subter_LeNet",
"subter_efficient",
"subter_LeNet_Split",
"fmnist_LeNet",
"cifar10_LeNet",

View File

@@ -5,6 +5,7 @@ from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
 from .mlp import MLP, MLP_Autoencoder
 from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
 from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
+from .subter_LeNet_rf import SubTer_Efficient_AE, SubTer_EfficientEncoder
 from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
 from .vae import VariationalAutoencoder
@@ -16,6 +17,7 @@ def build_network(net_name, rep_dim, ae_net=None):
         "mnist_LeNet",
         "elpv_LeNet",
         "subter_LeNet",
+        "subter_efficient",
         "subter_LeNet_Split",
         "mnist_DGM_M2",
         "mnist_DGM_M1M2",
@@ -48,6 +50,9 @@ def build_network(net_name, rep_dim, ae_net=None):
     if net_name == "subter_LeNet":
         net = SubTer_LeNet(rep_dim=rep_dim)
 
+    if net_name == "subter_efficient":
+        net = SubTer_EfficientEncoder(rep_dim=rep_dim)
+
     if net_name == "subter_LeNet_Split":
         net = SubTer_LeNet_Split()
@@ -135,6 +140,7 @@ def build_autoencoder(net_name, rep_dim):
     implemented_networks = (
         "elpv_LeNet",
         "subter_LeNet",
+        "subter_efficient",
         "subter_LeNet_Split",
         "mnist_LeNet",
         "mnist_DGM_M1M2",
@@ -160,6 +166,9 @@ def build_autoencoder(net_name, rep_dim):
     if net_name == "subter_LeNet":
         ae_net = SubTer_LeNet_Autoencoder(rep_dim=rep_dim)
 
+    if net_name == "subter_efficient":
+        ae_net = SubTer_Efficient_AE(rep_dim=rep_dim)
+
     if net_name == "subter_LeNet_Split":
         ae_net = SubTer_LeNet_Split_Autoencoder()
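
Call-site sketch for the new registry entries (illustrative; the rep_dim value is arbitrary):

    net = build_network("subter_efficient", rep_dim=512)         # -> SubTer_EfficientEncoder
    ae_net = build_autoencoder("subter_efficient", rep_dim=512)  # -> SubTer_Efficient_AE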

View File

@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_receptive_field
 
 from base.base_net import BaseNet
@@ -29,6 +30,13 @@ class SubTer_LeNet(BaseNet):
         x = self.fc1(x)
         return x
 
+    def summary(self, receptive_field: bool = False):
+        # first run the super method to log parameters and structure
+        super().summary(receptive_field=receptive_field)
+
+        self.logger.info("torch_receptive_field:")
+        torch_receptive_field.receptive_field(self, input_size=self.input_dim)
+        # torch_receptive_field.receptive_field_for_unit(rf, "2", (2, 2))
 
 class SubTer_LeNet_Decoder(BaseNet):
     def __init__(self, rep_dim=1024):
View File

@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field

from base.base_net import BaseNet


# ---------------------- helper ---------------------------------------------
def circ_pad_x(x, pad_w):
    """Circular pad on width (azimuth) only."""
    return F.pad(x, (pad_w, pad_w, 0, 0), mode="circular")
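
# Wrap-around check (illustrative, not part of the committed file): padding
# the row [0, 1, 2, 3] by one column on each side yields [3, 0, 1, 2, 3, 0];
# the height axis is left untouched.
#
#     row = torch.arange(4.0).view(1, 1, 1, 4)
#     circ_pad_x(row, 1)  # tensor([[[[3., 0., 1., 2., 3., 0.]]]])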
class DWSeparableConv(nn.Module):
    """Depthwise separable 3×17 conv + 1×1 pointwise + optional channel shuffle."""

    def __init__(self, c_in: int, c_out: int, shuffle: bool = False):
        super().__init__()
        self.dw = nn.Conv2d(
            c_in, c_in, kernel_size=(3, 17), padding=(1, 0), groups=c_in, bias=False
        )
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out, eps=1e-4, affine=False)
        self.shuffle = shuffle

    def _shuffle(self, x):
        if x.size(1) % 2 != 0:
            return x  # can't shuffle an odd number of channels
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2, h, w)
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(b, c, h, w)

    def forward(self, x):
        x = circ_pad_x(x, 8)
        x = self.dw(x)
        x = self.pw(x)
        if self.shuffle:
            x = self._shuffle(x)
        return F.leaky_relu(self.bn(x), 0.1)
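
# For scale (illustrative, not part of the committed file): a dense 3×17 conv
# mapping 16 -> 32 channels needs 16 * 32 * 3 * 17 = 26112 weights, while the
# separable pair above needs 16 * 3 * 17 + 16 * 32 = 1328, roughly 20× fewer.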
# ---------------------- encoder --------------------------------------------
class SubTer_EfficientEncoder(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.input_dim = (1, 32, 2048)
        self.rep_dim = rep_dim

        self.conv1 = DWSeparableConv(1, 16)
        self.pool_h4 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))  # 2048 ➔ 512
        self.conv2 = DWSeparableConv(16, 32, shuffle=True)
        self.pool2 = nn.MaxPool2d(2, 2)  # 32 ➔ 16 vertically, 512 ➔ 256 horizontally
        self.pool3 = nn.MaxPool2d(2, 2)  # 16 ➔ 8, 256 ➔ 128
        self.squeeze = nn.Conv2d(32, 8, 1, bias=False)
        self.fc = nn.Linear(8 * 8 * 128, rep_dim, bias=False)

    def forward(self, x):
        x = x.view(-1, 1, 32, 2048)
        x = self.conv1(x)
        x = self.pool_h4(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.pool3(x)
        x = self.squeeze(x)
        return self.fc(x.flatten(1))

    def summary(self, receptive_field: bool = False):
        # first run the super method to log parameters and structure;
        # this method lives on the encoder (a BaseNet with logger and
        # input_dim), matching the pattern in SubTer_LeNet
        super().summary(receptive_field=receptive_field)

        self.logger.info("torch_receptive_field:")
        torch_receptive_field.receptive_field(self, input_size=self.input_dim)
        # torch_receptive_field.receptive_field_for_unit(rf, "2", (2, 2))
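
# Shape walk for the encoder above (illustrative): (N, 1, 32, 2048) -> conv1
# keeps 32×2048 with 16 channels -> pool_h4 gives 32×512 -> conv2 gives 32
# channels -> pool2 16×256 -> pool3 8×128 -> squeeze 8 channels, so fc sees
# 8 * 8 * 128 = 8192 features and maps them to rep_dim.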
# ---------------------- decoder (NN upsample) ------------------------------
class SubTer_EfficientDecoder(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.fc = nn.Linear(rep_dim, 8 * 8 * 128, bias=False)
        self.expand = nn.Conv2d(8, 32, 1, bias=False)
        self.rep_dim = rep_dim

        # Nearest-neighbour upsampling layers
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            DWSeparableConv(32, 32, shuffle=True),
        )  # 8×128 ➔ 16×256
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=(1, 4), mode="nearest"),
            DWSeparableConv(32, 16),
        )  # 16×256 ➔ 16×1024
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 2), mode="nearest"),
            DWSeparableConv(16, 8),
        )  # 16×1024 ➔ 32×2048
        self.out_conv = nn.Conv2d(8, 1, kernel_size=(3, 17), padding=(1, 0), bias=False)

    def forward(self, x):
        x = self.fc(x).view(x.size(0), 8, 8, 128)
        x = self.expand(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = circ_pad_x(x, 8)
        return torch.sigmoid(self.out_conv(x))
# ---------------------- autoencoder wrapper -------------------------------
class SubTer_Efficient_AE(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.input_dim = (1, 32, 2048)  # input dimension for the network
        self.rep_dim = rep_dim
        self.encoder = SubTer_EfficientEncoder(rep_dim)
        self.decoder = SubTer_EfficientDecoder(rep_dim)

    def forward(self, x):
        return self.decoder(self.encoder(x))
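
A minimal round-trip sanity check for the new module (not part of the commit; the networks package path is assumed from the relative import in the builder above):

    import torch
    from networks.subter_LeNet_rf import SubTer_Efficient_AE  # path assumed

    ae = SubTer_Efficient_AE(rep_dim=512)
    x = torch.randn(2, 1, 32, 2048)   # two single-channel 32×2048 scans
    z = ae.encoder(x)
    x_hat = ae(x)
    assert z.shape == (2, 512)        # latent code of size rep_dim
    assert x_hat.shape == x.shape     # reconstruction matches the input grid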