102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
|
|
import torch
|
||
|
|
from thop import profile
|
||
|
|
|
||
|
|
from networks.subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
|
||
|
|
from networks.subter_LeNet_rf import SubTer_Efficient_AE, SubTer_EfficientEncoder
|
||
|
|
|
||
|
|
# Configuration
# Latent sizes (``rep_dim``) to profile; each one becomes a row of the table.
LATENT_DIMS = [
    32, 64, 128, 256,
    512, 768, 1024,
]
# A single sample is enough: the counts are per forward pass.
BATCH_SIZE = 1
# Shape of the random profiling input; the leading entry is the batch size.
# NOTE(review): presumably (N, C, H, W) = single-channel 32x2048 input as
# expected by the SubTer encoders -- confirm against the network definitions.
INPUT_SHAPE = (BATCH_SIZE, 1, 32, 2048)
|
||
|
|
|
||
|
|
|
||
|
|
def count_parameters(model, input_shape):
    """Count MACs and parameters for a model with ``thop.profile``.

    The model is switched to eval mode and profiled with one forward pass
    (under ``torch.no_grad``) on a random tensor of shape ``input_shape``.

    Args:
        model: The ``torch.nn.Module`` to profile.
        input_shape: Shape of the random input tensor fed to ``model``.

    Returns:
        dict with keys ``"MACs"`` and ``"Parameters"`` holding the counts
        reported by ``thop.profile``.
    """
    model.eval()
    # Build the input on the same device as the model's weights so a model
    # that was already moved (e.g. to a GPU) can still be profiled; models
    # without parameters fall back to torch's default device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    with torch.no_grad():
        input_tensor = torch.randn(input_shape, device=device)
        # verbose=False: thop otherwise prints "[INFO] Register ..." lines to
        # stdout, the same stream this script uses for its LaTeX table.
        macs, params = profile(model, inputs=(input_tensor,), verbose=False)
    return {"MACs": macs, "Parameters": params}
|
||
|
|
|
||
|
|
|
||
|
|
def format_number(num: float) -> str:
    """Format ``num`` with two decimals and a K/M/B/T (powers of 1000) suffix.

    The suffix is chosen from the value as it will be *displayed*. A number
    that rounds up to 1000 at two decimals therefore moves to the next unit
    (``999_999`` -> ``"1.00M"`` rather than ``"1000.00K"``). Values of 1000T
    and above stay in ``T``.

    Examples:
        >>> format_number(1234)
        '1.23K'
        >>> format_number(999_999)
        '1.00M'
    """
    for unit in ("", "K", "M", "B", "T"):
        # Compare the rounded value so the chosen unit agrees with what the
        # two-decimal formatting below actually prints.
        if abs(round(num, 2)) < 1000.0 or unit == "T":
            return f"{num:.2f}{unit}"
        num /= 1000.0
|
||
|
|
|
||
|
|
|
||
|
|
# Column order of the table: encoders first, then autoencoders; within each
# group LeNet before Efficient. Each entry is (result-key prefix, model class).
_VARIANTS = (
    ("lenet_enc", SubTer_LeNet),
    ("eff_enc", SubTer_EfficientEncoder),
    ("lenet_ae", SubTer_LeNet_Autoencoder),
    ("eff_ae", SubTer_Efficient_AE),
)

# Table preamble. It uses the ``tabularx`` environment, so the including
# LaTeX document needs ``\usepackage{tabularx}``.
_TABLE_HEADER = (
    "\\begin{table}[!ht]\n"
    "\\centering\n"
    "\\renewcommand{\\arraystretch}{1.15}\n"
    "\\begin{tabularx}{\\linewidth}{lXXXXXXXX}\n"
    "\\hline\n"
    " & \\multicolumn{4}{c}{\\textbf{Encoders}} & "
    "\\multicolumn{4}{c}{\\textbf{Autoencoders}} \\\\\n"
    "\\cline{2-9}\n"
    "\\textbf{Latent $z$} & "
    "\\textbf{LeNet Params} & \\textbf{LeNet MACs} & "
    "\\textbf{Eff. Params} & \\textbf{Eff. MACs} & "
    "\\textbf{LeNet Params} & \\textbf{LeNet MACs} & "
    "\\textbf{Eff. Params} & \\textbf{Eff. MACs} \\\\\n"
    "\\hline\n"
)

_TABLE_FOOTER = (
    "\\hline\n"
    "\\end{tabularx}\n"
    "\\caption{Parameter and MAC counts for SubTer variants across latent dimensionalities.}\n"
    "\\label{tab:subter_counts}\n"
    "\\end{table}\n"
)


def _profile_dim(dim, input_shape):
    """Instantiate and profile every variant in ``_VARIANTS`` for one latent size.

    Returns a dict mapping ``"<prefix>_params"`` and ``"<prefix>_macs"`` to
    the counts already formatted by ``format_number``.
    """
    stats = {}
    for prefix, model_cls in _VARIANTS:
        counts = count_parameters(model_cls(rep_dim=dim), input_shape)
        stats[f"{prefix}_params"] = format_number(counts["Parameters"])
        stats[f"{prefix}_macs"] = format_number(counts["MACs"])
    return stats


def _render_latex_table(rows_data):
    """Render ``(dim, stats)`` pairs as the LaTeX table, one row per pair in order."""
    rows = []
    for dim, stats in rows_data:
        cells = [f"{dim}"]
        for prefix, _ in _VARIANTS:
            cells.append(stats[f"{prefix}_params"])
            cells.append(stats[f"{prefix}_macs"])
        rows.append(" & ".join(cells) + " \\\\")
    return _TABLE_HEADER + "\n".join(rows) + "\n" + _TABLE_FOOTER


def main(latent_dims=None, input_shape=None):
    """Profile every SubTer variant per latent size and print a LaTeX table.

    Args:
        latent_dims: Latent sizes (``rep_dim``) to profile, one table row
            each, in the given order. Defaults to ``LATENT_DIMS``.
        input_shape: Shape of the random input used for profiling.
            Defaults to ``INPUT_SHAPE``.
    """
    if latent_dims is None:
        latent_dims = LATENT_DIMS
    if input_shape is None:
        input_shape = INPUT_SHAPE

    rows_data = [(dim, _profile_dim(dim, input_shape)) for dim in latent_dims]
    print(_render_latex_table(rows_data))
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Profile the models and print the LaTeX table only when this file is
    # executed as a script, not when it is imported.
    main()
|