2nd subter network arch
@@ -1,6 +1,5 @@
import logging

import numpy as np
import torch.nn as nn
import torchscan

@@ -32,8 +31,5 @@ class BaseNet(nn.Module):
                "Input dimension is not set. Please set input_dim before calling summary."
            )
            return
        self.logger.info(
            torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
        )
        module_info = torchscan.crawl_module(self, self.input_dim)
        pass
        self.logger.info("torchscan:\n")
        torchscan.summary(self, self.input_dim, receptive_field=receptive_field)

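Editor's note on the hunk above (an inference, not stated in the commit): torchscan.summary prints its table to stdout and returns None, so the old self.logger.info(torchscan.summary(...)) wrapper logged only "None"; the rewrite logs a "torchscan:" marker and then calls torchscan.summary directly, and the unused module_info assignment and stray pass go away.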
@@ -55,6 +55,7 @@ from utils.visualization.plot_images_grid import plot_images_grid
        "mnist_LeNet",
        "elpv_LeNet",
        "subter_LeNet",
        "subter_efficient",
        "subter_LeNet_Split",
        "fmnist_LeNet",
        "cifar10_LeNet",

@@ -5,6 +5,7 @@ from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
from .mlp import MLP, MLP_Autoencoder
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
from .subter_LeNet_rf import SubTer_Efficient_AE, SubTer_EfficientEncoder
from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
from .vae import VariationalAutoencoder

@@ -16,6 +17,7 @@ def build_network(net_name, rep_dim, ae_net=None):
        "mnist_LeNet",
        "elpv_LeNet",
        "subter_LeNet",
        "subter_efficient",
        "subter_LeNet_Split",
        "mnist_DGM_M2",
        "mnist_DGM_M1M2",

@@ -48,6 +50,9 @@ def build_network(net_name, rep_dim, ae_net=None):
    if net_name == "subter_LeNet":
        net = SubTer_LeNet(rep_dim=rep_dim)

    if net_name == "subter_efficient":
        net = SubTer_EfficientEncoder(rep_dim=rep_dim)

    if net_name == "subter_LeNet_Split":
        net = SubTer_LeNet_Split()

@@ -135,6 +140,7 @@ def build_autoencoder(net_name, rep_dim):
    implemented_networks = (
        "elpv_LeNet",
        "subter_LeNet",
        "subter_efficient",
        "subter_LeNet_Split",
        "mnist_LeNet",
        "mnist_DGM_M1M2",

@@ -160,6 +166,9 @@ def build_autoencoder(net_name, rep_dim):
    if net_name == "subter_LeNet":
        ae_net = SubTer_LeNet_Autoencoder(rep_dim=rep_dim)

    if net_name == "subter_efficient":
        ae_net = SubTer_Efficient_AE(rep_dim=rep_dim)

    if net_name == "subter_LeNet_Split":
        ae_net = SubTer_LeNet_Split_Autoencoder()

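For context, a minimal usage sketch of the two factory functions extended above (editor's illustration, not part of the commit; assumes Deep-SAD-PyTorch/src is on PYTHONPATH):

from networks.main import build_network, build_autoencoder

net = build_network("subter_efficient", rep_dim=512)         # SubTer_EfficientEncoder
ae_net = build_autoencoder("subter_efficient", rep_dim=512)  # SubTer_Efficient_AE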
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field

from base.base_net import BaseNet

@@ -29,6 +30,13 @@ class SubTer_LeNet(BaseNet):
        x = self.fc1(x)
        return x

    def summary(self, receptive_field: bool = False):
        # first run super method to log parameters and structure
        super().summary(receptive_field=receptive_field)
        self.logger.info("torch_receptive_field:")
        torch_receptive_field.receptive_field(self, input_size=self.input_dim)
        # torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))


class SubTer_LeNet_Decoder(BaseNet):
    def __init__(self, rep_dim=1024):

Deep-SAD-PyTorch/src/networks/subter_LeNet_rf.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field

from base.base_net import BaseNet


# ---------------------- helper ---------------------------------------------


def circ_pad_x(x, pad_w):
    """Circular pad on width (azimuth) only."""
    return F.pad(x, (pad_w, pad_w, 0, 0), mode="circular")

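A quick demonstration of circ_pad_x (editor's sketch, not part of the commit): only the width axis wraps around, the height axis is untouched.

x = torch.arange(6.0).view(1, 1, 1, 6)            # (N, C, H, W)
print(F.pad(x, (2, 2, 0, 0), mode="circular"))
# tensor([[[[4., 5., 0., 1., 2., 3., 4., 5., 0., 1.]]]])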
class DWSeparableConv(nn.Module):
    """Depth‑wise separable 3×17 conv + 1×1 point‑wise + optional channel shuffle."""

    def __init__(self, c_in: int, c_out: int, shuffle: bool = False):
        super().__init__()
        self.dw = nn.Conv2d(
            c_in, c_in, kernel_size=(3, 17), padding=(1, 0), groups=c_in, bias=False
        )
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out, eps=1e-4, affine=False)
        self.shuffle = shuffle

    def _shuffle(self, x):
        if x.size(1) % 2 != 0:
            return x  # can't shuffle odd channels
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2, h, w)
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(b, c, h, w)

    def forward(self, x):
        x = circ_pad_x(x, 8)
        x = self.dw(x)
        x = self.pw(x)
        if self.shuffle:
            x = self._shuffle(x)
        return F.leaky_relu(self.bn(x), 0.1)

    def summary(self, receptive_field: bool = False):
        # first run super method to log parameters and structure
        super().summary(receptive_field=receptive_field)
        self.logger.info("torch_receptive_field:")
        torch_receptive_field.receptive_field(self, input_size=self.input_dim)
        # torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))

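Two editor's remarks on DWSeparableConv (not part of the commit). First, the summary method above relies on self.logger and self.input_dim, which BaseNet provides but a plain nn.Module does not, so as committed it would fail if called on this class and presumably belongs on the BaseNet subclasses below. Second, _shuffle is a two-group channel shuffle; a small trace makes the interleaving visible:

x = torch.arange(4.0).view(1, 4, 1, 1)  # channels [0, 1, 2, 3]
y = x.view(1, 2, 2, 1, 1).transpose(1, 2).contiguous().view(1, 4, 1, 1)
print(y.flatten())  # tensor([0., 2., 1., 3.]), i.e. the two groups are interleaved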
# ---------------------- encoder --------------------------------------------
class SubTer_EfficientEncoder(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.input_dim = (1, 32, 2048)
        self.rep_dim = rep_dim

        self.conv1 = DWSeparableConv(1, 16)
        self.pool_h4 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))  # 2048 ➔ 512
        self.conv2 = DWSeparableConv(16, 32, shuffle=True)
        self.pool2 = nn.MaxPool2d(2, 2)  # 32 ➔ 16 vertically, 512 ➔ 256 horizontally
        self.pool3 = nn.MaxPool2d(2, 2)  # 16 ➔ 8, 256 ➔ 128
        self.squeeze = nn.Conv2d(32, 8, 1, bias=False)
        self.fc = nn.Linear(8 * 8 * 128, rep_dim, bias=False)

    def forward(self, x):
        x = x.view(-1, 1, 32, 2048)
        x = self.conv1(x)
        x = self.pool_h4(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.pool3(x)
        x = self.squeeze(x)
        return self.fc(x.flatten(1))

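The 8 * 8 * 128 flatten size checks out (editor's sketch, not part of the commit): the 3×17 kernels never shrink the width because circ_pad_x(x, 8) adds back exactly the 16 columns they consume, so 32×2048 goes to 32×512 after pool_h4, 16×256 after pool2, and 8×128 after pool3, with squeeze mapping 32 channels down to 8.

enc = SubTer_EfficientEncoder(rep_dim=512)
z = enc(torch.randn(2, 1, 32, 2048))
print(z.shape)  # torch.Size([2, 512])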
# ---------------------- decoder (NN upsample) ------------------------------
class SubTer_EfficientDecoder(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.fc = nn.Linear(rep_dim, 8 * 8 * 128, bias=False)
        self.expand = nn.Conv2d(8, 32, 1, bias=False)
        self.rep_dim = rep_dim

        # Nearest‑neighbour upsampling layers
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            DWSeparableConv(32, 32, shuffle=True),
        )  # 8×128 ➔ 16×256

        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=(1, 4), mode="nearest"),
            DWSeparableConv(32, 16),
        )  # 16×256 ➔ 16×1024

        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 2), mode="nearest"),
            DWSeparableConv(16, 8),
        )  # 16×1024 ➔ 32×2048

        self.out_conv = nn.Conv2d(8, 1, kernel_size=(3, 17), padding=(1, 0), bias=False)

    def forward(self, x):
        x = self.fc(x).view(x.size(0), 8, 8, 128)
        x = self.expand(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = circ_pad_x(x, 8)
        return torch.sigmoid(self.out_conv(x))

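The decoder mirrors that trace (editor's sketch, not part of the commit): fc reshapes the latent to 8 channels of 8×128, expand lifts it to 32 channels, the three upsampling stages walk 8×128 ➔ 16×256 ➔ 16×1024 ➔ 32×2048, and the final 3×17 conv keeps 32×2048 thanks to the circular width padding.

dec = SubTer_EfficientDecoder(rep_dim=512)
x_hat = dec(torch.randn(2, 512))
print(x_hat.shape)  # torch.Size([2, 1, 32, 2048])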
# ---------------------- autoencoder wrapper -------------------------------
class SubTer_Efficient_AE(BaseNet):
    def __init__(self, rep_dim: int = 512):
        super().__init__()
        self.input_dim = (1, 32, 2048)  # Input dimension for the network
        self.rep_dim = rep_dim
        self.encoder = SubTer_EfficientEncoder(rep_dim)
        self.decoder = SubTer_EfficientDecoder(rep_dim)

    def forward(self, x):
        return self.decoder(self.encoder(x))
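Finally, a minimal reconstruction-training sketch for the wrapper (editor's illustration, not part of the commit; inputs are assumed scaled to [0, 1] to match the sigmoid output):

ae = SubTer_Efficient_AE(rep_dim=512)
x = torch.rand(4, 1, 32, 2048)
loss = F.mse_loss(ae(x), x)
loss.backward()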