Files
mt/thesis/third_party/PlotNeuralNet/deepsad/arch_lenet_encoder.py
Jan Kowalczyk 5ff56994c0 wip
2025-08-28 18:36:02 +02:00

118 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# subter_lenet_arch.py
# Requires running from inside the PlotNeuralNet repo, like: python3 ../subter_lenet_arch.py
import sys, argparse
sys.path.append("../") # import pycore from repo root
from pycore.tikzeng import *
# Command-line options, parsed as soon as this module is executed/imported,
# so any unrecognised sys.argv entry makes argparse exit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--rep_dim",
    type=int,
    default=1024,
    help="latent size for FC",
)
args = parser.parse_args()
# NOTE(review): REP is not referenced by the `arch` definition below; the
# latent box is labelled with the literal text "latent dim" instead.
REP = int(args.rep_dim)
# Box dimensions for the diagram. They are visual scales, not true-to-size
# values, so the very wide 2048-column input does not dominate the figure.
# Each name encodes the tensor extent the box stands for.

# Passed as `height`: the row extent (32 / 16 / 8 rows).
H32 = 26
H16 = 18
H8 = 12

# Passed as `depth`: the column extent (2048 / 1024 / 512 columns).
D2048 = 52
D1024 = 36
D512 = 24

# Passed as `width`: the channel count (1 / 4 / 8 channels).
W1 = 1
W4 = 2
W8 = 4
# Ordered layer list for the diagram; main() passes it to to_generate(),
# which writes it to a .tex file. Box sizes come from the H*/D*/W* scale
# constants. The s_filer / n_filer arguments below are label strings (not
# numbers); the "×" inside them is the Unicode multiplication sign.
arch = [
    to_head(".."),
    to_cor(),
    to_begin(),
    # --------------------------- ENCODER ---------------------------
    # Input: 1 channel, 32 rows × 2048 columns. The size label is written
    # columns×rows ("2048×32").
    to_Conv(
        "input",
        s_filer="{{2048×32}}",
        n_filer=1,
        offset="(0,0,0)",
        to="(0,0,0)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="Input",
    ),
    # Conv1 (5x5, same): 1 -> 8 channels, stays 32×2048. Its size label shows
    # 1024×16, i.e. the size *after* Pool1 (Pool1 is drawn flush against it
    # and has an empty caption and no size label of its own).
    to_Conv(
        "conv1",
        s_filer="{{1024×16}}",
        n_filer=8,
        offset="(2,0,0)",
        to="(input-east)",
        height=H32,
        depth=D2048,
        width=W8,
        caption="Conv1",
    ),
    # Pool1 2×2: 32×2048 -> 16×1024 (zero offset, attached to conv1-east).
    # to_connection("input", "conv1"),
    to_Pool(
        "pool1",
        offset="(0,0,0)",
        to="(conv1-east)",
        height=H16,
        depth=D1024,
        width=W8,
        caption="",
    ),
    # Conv2 (5x5, same): 8 -> 4 channels, stays 16×1024. As with Conv1, the
    # size label shows the post-pooling size (512×8, after Pool2).
    to_Conv(
        "conv2",
        s_filer="{{512×8}}",
        n_filer=4,
        offset="(2,0,0)",
        to="(pool1-east)",
        height=H16,
        depth=D1024,
        width=W4,
        caption="Conv2",
    ),
    # Pool2 2×2: 16×1024 -> 8×512 (zero offset, attached to conv2-east).
    # to_connection("pool1", "conv2"),
    to_Pool(
        "pool2",
        offset="(0,0,0)",
        to="(conv2-east)",
        height=H8,
        depth=D512,
        width=W4,
        caption="",
    ),
    # FC on the flattened 4×512×8 feature map. n_filer is a label string
    # here. height=1.3 is far below the H* scales, so it renders as a thin bar.
    to_fc(
        "fc1",
        n_filer="{{4×512×8}}",
        offset="(2,0,0)",
        to="(pool2-east)",
        height=1.3,
        depth=D512,
        width=W1,
        caption="FC",
    ),
    # to_connection("pool2", "fc1"),
    # --------------------------- LATENT ---------------------------
    # Latent box labelled with the literal text "latent dim"; the --rep_dim
    # value is not shown in the figure.
    to_Conv(
        "latent",
        n_filer="",
        s_filer="latent dim",
        offset="(2,0,0)",
        to="(fc1-east)",
        height=H8 * 1.6,
        depth=1.3,
        width=W1,
        caption="Latent Space",
    ),
    to_end(),
]
def main():
    """Render `arch` to ``subter_lenet_arch.tex``.

    The path is relative, so the file lands in the current working directory.
    """
    stem = "subter_lenet_arch"
    to_generate(arch, f"{stem}.tex")


if __name__ == "__main__":
    main()