mt/Deep-SAD-PyTorch/src/onnx_export.py


from pathlib import Path

import torch
import torch.onnx

from networks.subter_LeNet import SubTer_LeNet_Autoencoder
from networks.subter_LeNet_rf import SubTer_Efficient_AE


def export_model_to_onnx(model, filepath):
    model.eval()  # Set the model to evaluation mode
    dummy_input = torch.randn(model.input_dim)  # Create a dummy input tensor
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        filepath,                  # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],     # the model's input names
        output_names=["output"],   # the model's output names
        dynamic_axes={
            "input": {0: "batch_size"},  # variable length axes
            "output": {0: "batch_size"},
        },
    )


if __name__ == "__main__":
    output_folder_path = Path("./onnx_models")
    output_folder_path.mkdir(parents=True, exist_ok=True)

    models_to_visualize = [
        (
            SubTer_LeNet_Autoencoder(rep_dim=32),
            output_folder_path / "subter_lenet_ae.onnx",
        ),
        (SubTer_Efficient_AE(rep_dim=32), output_folder_path / "subter_ef_ae.onnx"),
    ]

    for model, output_path in models_to_visualize:
        export_model_to_onnx(model, output_path)
        print(f"Model has been exported to {output_path}")