Files
mt/thesis/third_party/PlotNeuralNet/deepsad/arch_lenet_decoder.py
2025-09-28 14:35:10 +02:00

130 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
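# arch_lenet_decoder.py — PlotNeuralNet figure of the LeNet-style decoder.
# Requires running from a direct subdirectory of the PlotNeuralNet repo, so that
# "../" resolves to the repo root (e.g. from deepsad/: python3 arch_lenet_decoder.py).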
import sys, argparse
sys.path.append("../") # import pycore from repo root
from pycore.tikzeng import *
# Command-line options for the figure script (only --rep_dim is defined).
parser = argparse.ArgumentParser()
parser.add_argument("--rep_dim", type=int, default=1024, help="latent size for FC")
# Parse the real command line only when executed as a script. When this module
# is imported, sys.argv belongs to the importing program; parsing it here would
# typically abort the import with SystemExit ("unrecognized arguments").
if __name__ == "__main__":
    args = parser.parse_args()
else:
    args = parser.parse_args([])
# type=int already converts the value, so no extra int() cast is needed.
# NOTE(review): REP is not referenced by any layer in this script — confirm
# whether the latent-space label was meant to display it.
REP = args.rep_dim
# Visual drawing sizes, compressed so the huge input width doesn't dominate
# the figure. Each suffix names the real tensor dimension the size stands in
# for, as used by the layer list:
#   H<n>: box height for a feature map whose spatial height is n
H32 = 26
H16 = 18
H8 = 12
#   D<n>: box depth for a feature map whose spatial width is n
D2048 = 52
D1024 = 36
D512 = 24
#   W<c>: box width for a layer with c channels
W1 = 1
W4 = 2
W8 = 4
# Ordered list of figure fragments handed to to_generate() in main(). Each
# to_* helper comes from pycore.tikzeng. Every layer after the first is
# anchored to the east face of the previous one via `to=`, with `offset`
# setting the gap; a (0,0,0) offset draws the box flush against its anchor.
arch = [
    # NOTE(review): ".." is presumably the PlotNeuralNet root (layer styles),
    # relative to the working directory — confirm against pycore.tikzeng.
    to_head(".."),
    to_cor(),
    to_begin(),
    # Latent code: the decoder's input, drawn as a thin box at the origin.
    to_Conv(
        "latent",
        n_filer="",
        s_filer="latent dim",
        offset="(2,0,0)",
        to="(0,0,0)",
        height=H8 * 1.6,
        depth=1.3,
        width=W1,
        caption="Latent Space",
    ),
    # --------------------------- DECODER ---------------------------
    # FC back to 16384 = 4×512×8 features, then reshaped to 4 channels of
    # 8×512 (the reshape itself is not drawn as a separate box).
    to_fc(
        "fc3",
        n_filer="{{4×512×8}}",
        zlabeloffset=0.35,
        offset="(2,-.5,0)",
        to="(latent-east)",
        height=1.3,
        depth=D512,
        width=W1,
        caption="FC",
        captionshift=20,
    ),
    # to_connection("latent", "fc3"),
    # Up ×2: 8×512 -> 16×1024 (we just draw a labeled box).
    to_UnPool(
        "up1",
        n_filer=4,
        offset="(2.5,0,0)",
        to="(fc3-east)",
        height=H16,
        depth=D1024,
        width=W4,
        caption="",
    ),
    # to_connection("fc3", "up1"),
    # DeConv1 (5×5, same): 4 -> 8 channels at 16×1024, drawn flush against up1.
    # The stage caption "Deconv1" sits on this box; up1 has an empty caption.
    to_Conv(
        "deconv1",
        s_filer="{{1024×16}}",
        zlabeloffset=0.2,
        n_filer=8,
        offset="(0,0,0)",
        to="(up1-east)",
        height=H16,
        depth=D1024,
        width=W8,
        caption="Deconv1",
    ),
    # Up ×2: 16×1024 -> 32×2048. The stage caption "Deconv2" is attached to
    # this upsampling box; deconv2 below carries an empty caption.
    to_UnPool(
        "up2",
        offset="(2,0,0)",
        n_filer=8,
        to="(deconv1-east)",
        height=H32,
        depth=D2048,
        width=W8,
        caption="Deconv2",
        captionshift=10,
    ),
    # to_connection("deconv1", "up2"),
    # DeConv2 (5×5, same): 8 -> 1 channel at 32×2048, drawn flush against up2.
    to_Conv(
        "deconv2",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=1,
        offset="(0,0,0)",
        to="(up2-east)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="",
    ),
    # to_connection("up2", "deconv2"),
    # Output: single-channel 32×2048 map.
    to_Conv(
        "out",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=1,
        offset="(2,0,0)",
        to="(deconv2-east)",
        height=H32,
        depth=D2048,
        width=1.0,
        caption="Output",
        captionshift=5,
    ),
    # to_connection("deconv2", "out"),
    to_end(),
]
def main():
    """Write the figure described by the module-level ``arch`` list to a .tex file.

    The output path is relative ("subter_lenet_arch.tex"), so it presumably
    lands in the current working directory.
    """
    # NOTE(review): this basename matches the header's old script name, not
    # this decoder script's; if a sibling encoder script writes the same
    # basename, the two figures overwrite each other — confirm.
    basename = "subter_lenet_arch"
    tex_path = f"{basename}.tex"
    to_generate(arch, tex_path)
# Render the figure only when executed directly, not when imported.
if __name__ == "__main__":
    main()