2025-08-13 15:04:44 +02:00
|
|
|
|
# subter_lenet_arch.py
|
|
|
|
|
|
# Run from a subdirectory of the PlotNeuralNet repo (so that ../ is the repo root containing pycore/), e.g.: python3 ../subter_lenet_arch.py
|
|
|
|
|
|
import sys, argparse
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.append("../") # import pycore from repo root
|
|
|
|
|
|
|
|
|
|
|
|
from pycore.tikzeng import *
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--rep_dim", type=int, default=1024, help="latent size for FC")

# Only consume sys.argv when this file is executed as a script. When it is
# imported (e.g. by a test runner or another figure script), parse an empty
# argument list instead, so the importer's own command line is not parsed
# (which would make argparse exit with "unrecognized arguments").
args = parser.parse_args() if __name__ == "__main__" else parser.parse_args([])

# type=int on the argument already yields an int (and the default is an int),
# so no extra int() conversion is needed.
REP = args.rep_dim
# NOTE(review): REP is not referenced by the layer definitions in this script
# (the latent box is labelled "latent dim"), so --rep_dim currently has no
# visible effect on the generated figure — confirm whether that is intended.
|
|
|
|
|
|
|
|
|
|
|
|
# Visual scales so the huge width doesn't dominate the figure.
# These are drawing sizes, not tensor sizes: the suffix names the dimension a
# box stands for, the value is how large that box is drawn.

# Box heights.
H32 = 26
H16 = 18
H8 = 12
H1 = 1

# Box depths.
D2048 = 52
D1024 = 36
D512 = 24
D256 = 12
D128 = 6
D1 = 1

# Box widths (paired with the n_filer count shown on each box).
# NOTE(review): W4 and W8 are both drawn at 2, and W4 is not used by the
# layer list in this script — confirm this is intentional.
W1 = 1
W4 = 2
W8 = 2
W16 = 4
W32 = 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Layers of the figure, in drawing order. Every entry is produced by a
# pycore.tikzeng helper. Boxes are placed relative to the previous box via
# to="(<name>-east)" plus an offset; no to_connection arrows are drawn.
# NOTE(review): to_fc, n_filer on to_UnPool, and the zlabeloffset /
# captionshift keywords are not part of upstream PlotNeuralNet's tikzeng API —
# presumably this relies on a locally patched pycore; confirm.
arch = [
    # ".." is presumably the relative path from the generated .tex file to the
    # PlotNeuralNet root (where its TikZ layer styles live) — confirm.
    to_head(".."),
    to_cor(),
    to_begin(),
    # ---------------------------- LATENT ----------------------------
    to_Conv(
        "latent",
        n_filer="",
        s_filer="latent dim",
        offset="(2,0,0)",
        to="(0,0,0)",
        height=H8 * 1.6,
        depth=D1,
        width=W1,
        caption="Latent Space",
        captionshift=0,
    ),
    # --------------------------- DECODER ---------------------------
    # FC layer, labelled 8×128×8 (= 8192 values).
    to_fc(
        "fc3",
        n_filer="{{8×128×8}}",
        zlabeloffset=0.5,
        offset="(2,-.5,0)",
        to="(latent-east)",
        height=H1,
        depth=D512,
        width=W1,
        caption="FC",  # plain string; the original f-string had no placeholders
        captionshift=20,
    ),
    # Unsqueeze, drawn as a 32-filter box at 128×8.
    to_Conv(
        "unsqueeze",
        s_filer="{{128×8}}",
        zlabeloffset=0.4,
        n_filer=32,
        offset="(1.4,0,0)",
        to="(fc3-east)",
        height=H8,
        depth=D128,
        width=W32,
        caption="Unsqueeze",
    ),
    # Each upsampling stage below is drawn as: an UnPool box, a thin box
    # labelled with 1 filter, then a wider box labelled with the stage's
    # filter count and output size.
    # Stage 1: 128×8 -> 256×16, 32 filters.
    to_UnPool(
        "up1",
        offset="(1.2,0,0)",
        n_filer=32,
        to="(unsqueeze-east)",
        height=H16,
        depth=D256,
        width=W32,
        caption="",
    ),
    to_Conv(
        "dwdeconv1",
        s_filer="",
        n_filer=1,
        offset="(0,0,0)",
        to="(up1-east)",
        height=H16,
        depth=D256,
        width=W1,
        caption="Deconv1",
    ),
    to_Conv(
        "dwdeconv2",
        s_filer="{{256×16}}",
        zlabeloffset=0.4,
        n_filer=32,
        offset="(0,0,0)",
        to="(dwdeconv1-east)",
        height=H16,
        depth=D256,
        width=W32,
        caption="",
    ),
    # Stage 2: 256×16 -> 1024×16, 16 filters.
    to_UnPool(
        "up2",
        offset="(2,0,0)",
        to="(dwdeconv2-east)",
        n_filer=32,
        height=H16,
        depth=D1024,
        width=W32,
        caption="Deconv2",
        captionshift=20,
    ),
    to_Conv(
        "dwdeconv3",
        s_filer="",
        n_filer=1,
        offset="(0,0,0)",
        to="(up2-east)",
        height=H16,
        depth=D1024,
        width=W1,
        caption="",
    ),
    to_Conv(
        "dwdeconv4",
        s_filer="{{1024×16}}",
        zlabeloffset=0.17,
        n_filer=16,
        offset="(0,0,0)",
        to="(dwdeconv3-east)",
        height=H16,
        depth=D1024,
        width=W16,
        caption="",
    ),
    # Stage 3: 1024×16 -> 2048×32, 8 filters.
    to_UnPool(
        "up3",
        offset="(2,0,0)",
        n_filer=16,
        to="(dwdeconv4-east)",
        height=H32,
        depth=D2048,
        width=W16,
        caption="Deconv3",
        captionshift=10,
    ),
    to_Conv(
        "dwdeconv5",
        s_filer="",
        n_filer=1,
        offset="(0,0,0)",
        to="(up3-east)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="",
    ),
    to_Conv(
        "dwdeconv6",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=8,
        offset="(0,0,0)",
        to="(dwdeconv5-east)",
        height=H32,
        depth=D2048,
        width=W8,
        caption="",
    ),
    # Final single-filter conv at 2048×32.
    to_Conv(
        "outconv",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=1,
        offset="(1.5,0,0)",
        to="(dwdeconv6-east)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="Deconv4",
    ),
    # Output, 1 × 2048×32.
    to_Conv(
        "out",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=1,
        offset="(1.5,0,0)",
        to="(outconv-east)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="Output",
        captionshift=5,
    ),
    to_end(),
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the ``arch`` layer list into ``subter_lenet_arch.tex``.

    The output path is relative, so (assuming ``to_generate`` opens it as
    given) the file is written to the current working directory.
    """
    stem = "subter_lenet_arch"
    to_generate(arch, f"{stem}.tex")


# Only build the figure when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|