36 lines
1.3 KiB
Python
36 lines
1.3 KiB
Python
|
|
import torch
|
||
|
|
import torch.onnx
|
||
|
|
from networks.mnist_LeNet import MNIST_LeNet_Autoencoder
|
||
|
|
|
||
|
|
|
||
|
|
def export_model_to_onnx(model, filepath, input_shape=(1, 1, 28, 28), *, opset_version=11):
    """Export a PyTorch module to ONNX by tracing it with a random dummy input.

    Axis 0 of both the graph input ("input") and output ("output") is declared
    dynamic ("batch_size"), so the exported graph accepts any batch size.

    Args:
        model: The ``torch.nn.Module`` to export.
        filepath: Destination path or writable file-like object for the ONNX model.
        input_shape: Shape of the random tensor used for tracing. The default
            ``(1, 1, 28, 28)`` describes one single-channel 28x28 sample.
        opset_version: ONNX opset version to target (keyword-only; default 11).

    The training/eval flag of every submodule is restored after the export,
    even if the export raises, so the caller's model is left as it was found.
    """
    # Remember each submodule's own mode; restoring only the top-level flag via
    # model.train(...) would overwrite any mixed per-submodule state.
    previous_modes = [(module, module.training) for module in model.modules()]
    # eval() switches layers such as dropout and batch norm to inference behavior.
    model.eval()
    try:
        # Create the dummy input on the same device (and floating dtype) as the
        # model's parameters. A default CPU float32 tensor would make tracing
        # fail for a model that lives on a GPU or runs in reduced precision.
        param = next(model.parameters(), None)
        device = param.device if param is not None else None
        dtype = param.dtype if param is not None and param.is_floating_point() else None
        dummy_input = torch.randn(input_shape, device=device, dtype=dtype)

        torch.onnx.export(
            model,  # model being run
            dummy_input,  # model input (or a tuple for multiple inputs)
            filepath,  # where to save the model (can be a file or file-like object)
            export_params=True,  # store the trained parameter weights inside the model file
            opset_version=opset_version,  # the ONNX opset version to export the model to
            do_constant_folding=True,  # whether to execute constant folding for optimization
            input_names=["input"],  # the model's input names
            output_names=["output"],  # the model's output names
            dynamic_axes={
                "input": {0: "batch_size"},  # variable length axes
                "output": {0: "batch_size"},
            },
        )
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Location of the ONNX file produced by this script.
    output_path = "mnist_lenet_autoencoder.onnx"

    # Build the LeNet-style MNIST autoencoder (32-dimensional representation)
    # and export it straight away.
    # NOTE(review): no checkpoint/state_dict is loaded here, so the exported
    # file holds the freshly constructed weights — confirm this is intended.
    export_model_to_onnx(MNIST_LeNet_Autoencoder(rep_dim=32), output_path)

    print(f"Model has been exported to {output_path}")
|