tool updates
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
from networks.mnist_LeNet import MNIST_LeNet_Autoencoder
|
||||
|
||||
from networks.subter_LeNet import SubTer_LeNet_Autoencoder
|
||||
from networks.subter_LeNet_rf import SubTer_Efficient_AE
|
||||
|
||||
|
||||
def export_model_to_onnx(model, filepath, input_shape=(1, 1, 28, 28)):
|
||||
def export_model_to_onnx(model, filepath):
|
||||
model.eval() # Set the model to evaluation mode
|
||||
dummy_input = torch.randn(input_shape) # Create a dummy input tensor
|
||||
dummy_input = torch.randn(model.input_dim) # Create a dummy input tensor
|
||||
torch.onnx.export(
|
||||
model, # model being run
|
||||
dummy_input, # model input (or a tuple for multiple inputs)
|
||||
@@ -23,13 +27,17 @@ def export_model_to_onnx(model, filepath, input_shape=(1, 1, 28, 28)):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize the autoencoder model
|
||||
autoencoder = MNIST_LeNet_Autoencoder(rep_dim=32)
|
||||
output_folder_path = Path("./onnx_models")
|
||||
output_folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Define the file path where the ONNX model will be saved
|
||||
onnx_file_path = "mnist_lenet_autoencoder.onnx"
|
||||
models_to_visualize = [
|
||||
(
|
||||
SubTer_LeNet_Autoencoder(rep_dim=32),
|
||||
output_folder_path / "subter_lenet_ae.onnx",
|
||||
),
|
||||
(SubTer_Efficient_AE(rep_dim=32), output_folder_path / "subter_ef_ae.onnx"),
|
||||
]
|
||||
|
||||
# Export the model
|
||||
export_model_to_onnx(autoencoder, onnx_file_path)
|
||||
|
||||
print(f"Model has been exported to {onnx_file_path}")
|
||||
for model, output_path in models_to_visualize:
|
||||
export_model_to_onnx(model, output_path)
|
||||
print(f"Model has been exported to {output_path}")
|
||||
|
||||
Reference in New Issue
Block a user