# subter_lenet_arch.py
# Requires running from inside the PlotNeuralNet repo, like:
#   python3 ../subter_lenet_arch.py
"""Draw the decoder of the SubTer LeNet autoencoder with PlotNeuralNet.

Running the script writes ``subter_lenet_arch.tex`` (via ``to_generate``)
into the current working directory.
"""
import argparse
import sys

# Make ``pycore`` importable from the repo root. The path is resolved relative
# to the current working directory, hence the "run from inside the repo" note.
sys.path.append("../")
# Star import is PlotNeuralNet's usual usage pattern (to_head, to_Conv, ...).
# NOTE(review): to_fc does not appear to be part of upstream tikzeng —
# presumably added in this checkout; confirm.
from pycore.tikzeng import *  # noqa: E402,F403

# Visual scales so the huge width doesn't dominate the figure
H32, H16, H8 = 26, 18, 12
D2048, D1024, D512 = 52, 36, 24
W1, W4, W8 = 1, 2, 4


def _latent_label(rep_dim):
    """Return the z-axis label for the latent box.

    ``None`` keeps the generic "latent dim" text; otherwise the size is
    wrapped in double braces like the other size labels in the figure.
    """
    if rep_dim is None:
        return "latent dim"
    return f"{{{{{rep_dim}}}}}"  # renders as "{{<rep_dim>}}"


def build_arch(rep_dim=None):
    """Return the list of TikZ snippets that make up the decoder figure.

    Args:
        rep_dim: latent size to print on the "Latent Space" box, or ``None``
            for the placeholder label "latent dim".

    Returns:
        List of strings produced by the pycore.tikzeng helpers, in drawing
        order, suitable for ``to_generate``.
    """
    return [
        to_head(".."),
        to_cor(),
        to_begin(),
        to_Conv(
            "latent",
            n_filer="",
            s_filer=_latent_label(rep_dim),
            offset="(2,0,0)",
            to="(0,0,0)",
            height=H8 * 1.6,
            depth=1.3,
            width=W1,
            caption="Latent Space",
        ),
        # to_connection("fc1", "latent"),
        # --------------------------- DECODER ---------------------------
        # FC back to 16384
        to_fc(
            "fc3",
            n_filer="{{4×512×8}}",
            offset="(2,0,0)",
            to="(latent-east)",
            height=1.3,
            depth=D512,
            width=W1,
            caption="FC",
        ),
        # to_connection("latent", "fc3"),
        # Reshape to 4×8×512 (not drawn), then up ×2: 8×512 -> 16×1024
        # (we just draw a labeled box)
        to_UnPool(
            "up1",
            offset="(2,0,0)",
            to="(fc3-east)",
            height=H16,
            depth=D1024,
            width=W4,
            caption="",
        ),
        # DeConv1 (5×5, same): 4->8, 16×1024
        to_Conv(
            "deconv1",
            s_filer="{{1024×16}}",
            n_filer=8,
            offset="(0,0,0)",
            to="(up1-east)",
            height=H16,
            depth=D1024,
            width=W8,
            caption="Deconv1",
        ),
        # to_connection("fc3", "up1"),
        # Up ×2: 16×1024 -> 32×2048
        to_UnPool(
            "up2",
            offset="(2,0,0)",
            to="(deconv1-east)",
            height=H32,
            depth=D2048,
            width=W8,
            caption="",
        ),
        # to_connection("deconv1", "up2"),
        # DeConv2 (5×5, same): 8->1, 32×2048
        to_Conv(
            "deconv2",
            s_filer="{{2048×32}}",
            n_filer=1,
            offset="(0,0,0)",
            to="(up2-east)",
            height=H32,
            depth=D2048,
            width=W1,
            caption="Deconv2",
        ),
        # to_connection("up2", "deconv2"),
        # Output
        to_Conv(
            "out",
            s_filer="{{2048×32}}",
            n_filer=1,
            offset="(2,0,0)",
            to="(deconv2-east)",
            height=H32,
            depth=D2048,
            width=1.0,
            caption="Output",
        ),
        # to_connection("deconv2", "out"),
        to_end(),
    ]


def main(argv=None):
    """Parse CLI options and write ``subter_lenet_arch.tex``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rep_dim",
        type=int,
        default=None,
        help="latent size for FC (printed on the latent box; "
        "omit for a generic 'latent dim' label)",
    )
    args = parser.parse_args(argv)
    if args.rep_dim is not None and args.rep_dim <= 0:
        parser.error("--rep_dim must be a positive integer")

    name = "subter_lenet_arch"
    to_generate(build_arch(args.rep_dim), name + ".tex")


if __name__ == "__main__":
    main()