From ea41207a83f7dca3bc5cd8cfc5c079a4ec071d23 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Tue, 30 Apr 2024 14:43:53 +0900 Subject: [PATCH] Fix to preserve domain and ir_version --- snc4onnx/__init__.py | 2 +- snc4onnx/onnx_network_combine.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/snc4onnx/__init__.py b/snc4onnx/__init__.py index 02be022..874321b 100644 --- a/snc4onnx/__init__.py +++ b/snc4onnx/__init__.py @@ -1,3 +1,3 @@ from snc4onnx.onnx_network_combine import combine, main -__version__ = '1.0.12' +__version__ = '1.0.13' diff --git a/snc4onnx/onnx_network_combine.py b/snc4onnx/onnx_network_combine.py index 9e9b2cc..13fee9a 100644 --- a/snc4onnx/onnx_network_combine.py +++ b/snc4onnx/onnx_network_combine.py @@ -230,11 +230,15 @@ def has_duplicates(seq): ## 1. ONNX load tmp_onnx_graphs = [] custom_domain_check_onnx_nodes = [] + max_ir_version: int = 0 if len(onnx_graphs) > 0: for onnx_graph in onnx_graphs: + domain: str = onnx_graph.domain + ir_version: int = onnx_graph.ir_version + max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version gs_graph = gs.import_onnx(onnx_graph) gs_graph.cleanup().toposort() - tmp_onnx_graphs.append(gs.export_onnx(gs_graph)) + tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})) custom_domain_check_onnx_nodes = \ custom_domain_check_onnx_nodes + \ [ @@ -243,9 +247,13 @@ def has_duplicates(seq): ] else: for onnx_path in input_onnx_file_paths: - gs_graph = gs.import_onnx(onnx.load(onnx_path)) + onnx_graph = onnx.load(onnx_path) + domain: str = onnx_graph.domain + ir_version: int = onnx_graph.ir_version + max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version + gs_graph = gs.import_onnx(onnx_graph) gs_graph.cleanup().toposort() - tmp_onnx_graphs.append(gs.export_onnx(gs_graph)) + tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})) custom_domain_check_onnx_graph = onnx.load(onnx_path) custom_domain_check_onnx_nodes = \ custom_domain_check_onnx_nodes + \ @@ -436,7 +444,7 @@ def has_duplicates(seq): # Cleaning src_gs_model.cleanup().toposort() - combined_model = gs.export_onnx(src_gs_model) + combined_model = gs.export_onnx(src_gs_model, do_type_check=False, **{'ir_version': max_ir_version}) ## Output of onnx files in the process of fusion if output_of_onnx_file_in_the_process_of_fusion and output_onnx_file_path: @@ -484,7 +492,7 @@ def has_duplicates(seq): replaced_output_names.append(tmp_replaced_output_name) gs_combined_model.cleanup().toposort() - combined_model = gs.export_onnx(gs_combined_model) + combined_model = gs.export_onnx(gs_combined_model, do_type_check=False, **{'ir_version': max_ir_version}) ## 4. Optimize try: