44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.onnx
|
|
|
|
from networks.subter_LeNet import SubTer_LeNet_Autoencoder
|
|
from networks.subter_LeNet_rf import SubTer_Efficient_AE
|
|
|
|
|
|
def export_model_to_onnx(model, filepath, *, opset_version=11):
    """Export ``model`` to an ONNX graph written to ``filepath``.

    The model is switched to evaluation mode (and left in it) and traced with
    a random dummy input of shape ``model.input_dim``. Axis 0 of both the
    input and the output is exported as a dynamic ``batch_size`` axis, so the
    resulting graph is not tied to the dummy input's batch size.

    Args:
        model: A ``torch.nn.Module`` exposing an ``input_dim`` attribute that
            is passed directly to ``torch.randn`` as the dummy-input shape.
            NOTE(review): presumably a full shape whose first entry is the
            batch dimension (axis 0 is declared dynamic below) — confirm
            against the network definitions.
        filepath: Where to write the ONNX model; anything accepted by
            ``torch.onnx.export`` (a file path or a writable file-like object).
        opset_version: ONNX opset to target. Defaults to 11, the value this
            script has always used.
    """
    # Trace the inference behaviour (e.g. dropout / batch-norm layers, if the
    # network has any, run in eval mode).
    model.eval()

    # Create the dummy input on the same device and with the same floating
    # dtype as the model's weights. A default CPU float32 tensor would make
    # tracing fail for a model moved to GPU or cast to half/double precision.
    reference_param = next(
        (p for p in model.parameters() if p.is_floating_point()), None
    )
    if reference_param is None:
        # No floating-point weights to match: fall back to torch defaults.
        dummy_input = torch.randn(model.input_dim)
    else:
        dummy_input = torch.randn(
            model.input_dim,
            device=reference_param.device,
            dtype=reference_param.dtype,
        )

    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        filepath,  # where to save the model (file path or file-like object)
        export_params=True,  # store the trained parameter weights in the file
        opset_version=opset_version,  # the ONNX opset to export the model to
        do_constant_folding=True,  # pre-compute constant sub-graphs
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {0: "batch_size"},  # variable-length batch axis
            "output": {0: "batch_size"},
        },
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Export every autoencoder variant to its own ONNX file under ./onnx_models.
    onnx_dir = Path("./onnx_models")
    onnx_dir.mkdir(parents=True, exist_ok=True)

    # All networks are instantiated up front, before any export runs, so each
    # one is built in the same order relative to the dummy-input sampling.
    export_jobs = (
        (SubTer_LeNet_Autoencoder(rep_dim=32), "subter_lenet_ae.onnx"),
        (SubTer_Efficient_AE(rep_dim=32), "subter_ef_ae.onnx"),
    )

    for network, file_name in export_jobs:
        output_path = onnx_dir / file_name
        export_model_to_onnx(network, output_path)
        print(f"Model has been exported to {output_path}")
|