diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst index 859e422a6f0..2009aea813d 100644 --- a/docs/source/onnx_torchscript.rst +++ b/docs/source/onnx_torchscript.rst @@ -715,5 +715,5 @@ Classes :template: classtemplate.rst JitScalarType - torch.onnx.verification.GraphInfo - torch.onnx.verification.VerificationOptions + verification.GraphInfo + verification.VerificationOptions diff --git a/mypy.ini b/mypy.ini index b66464d21f5..5ab02361d61 100644 --- a/mypy.ini +++ b/mypy.ini @@ -165,9 +165,6 @@ ignore_missing_imports = True [mypy-tensorboard.*] ignore_missing_imports = True -[mypy-onnx.*] -ignore_missing_imports = True - [mypy-matplotlib.*] ignore_missing_imports = True @@ -301,5 +298,14 @@ ignore_missing_imports = True # Third party dependencies that are optional. # +[mypy-onnx.*] +ignore_missing_imports = True + +[mypy-onnxruntime.*] +ignore_missing_imports = True + +[mypy-onnxscript.*] +ignore_missing_imports = True + [mypy-redis] ignore_missing_imports = True \ No newline at end of file diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index b5822508dfa..c6e8cd47cfe 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -163,222 +163,5 @@ class TestDynamoExportAPI(common_utils.TestCase): ) -class TestONNXExportWithDynamo(common_utils.TestCase): - def test_args_normalization_with_no_kwargs(self): - exported_program = torch.export.export( - SampleModelTwoInputs(), - ( - torch.randn(1, 1, 2), - torch.randn(1, 1, 2), - ), - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2) - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_args_is_tensor_not_tuple(self): - exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),)) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, torch.randn(1, 1, 2) - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModel(), torch.randn(1, 1, 2), dynamo=True - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_args_normalization_with_kwargs(self): - exported_program = torch.export.export( - SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_args_normalization_with_empty_dict_at_the_tail(self): - exported_program = torch.export.export( - SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - 
onnx_program_from_old_exporter.model_proto, - ) - - def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): - exported_program = torch.export.export( - SampleModelForDynamicShapes(), - ( - torch.randn(2, 2, 3), - torch.randn(2, 2, 3), - ), - dynamic_shapes={ - "x": { - 0: torch.export.Dim("customx_dim_0"), - 1: torch.export.Dim("customx_dim_1"), - 2: torch.export.Dim("customx_dim_2"), - }, - "b": { - 0: torch.export.Dim("customb_dim_0"), - 1: torch.export.Dim("customb_dim_1"), - 2: torch.export.Dim("customb_dim_2"), - }, - }, - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, - torch.randn(2, 2, 3), - b=torch.randn(2, 2, 3), - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ - "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, - "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, - }, - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): - exported_program = torch.export.export( - SampleModelForDynamicShapes(), - ( - torch.randn(2, 2, 3), - torch.randn(2, 2, 3), - ), - dynamic_shapes={ - "x": { - 0: torch.export.Dim("customx_dim_0"), - 1: torch.export.Dim("customx_dim_1"), - 2: torch.export.Dim("customx_dim_2"), - }, - "b": { - 0: torch.export.Dim("customb_dim_0"), - 1: torch.export.Dim("customb_dim_1"), - 2: torch.export.Dim("customb_dim_2"), - }, - }, - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, - torch.randn(2, 2, 3), - b=torch.randn(2, 2, 3), - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ - "x": [0, 1, 2], - "b": [0, 1, 2], - }, - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_dynamic_axes_supports_partial_dynamic_shapes(self): - exported_program = torch.export.export( - SampleModelForDynamicShapes(), - ( - torch.randn(2, 2, 3), - torch.randn(2, 2, 3), - ), - dynamic_shapes={ - "x": None, - "b": { - 0: torch.export.Dim("customb_dim_0"), - 1: torch.export.Dim("customb_dim_1"), - 2: torch.export.Dim("customb_dim_2"), - }, - }, - ) - onnx_program_from_new_exporter = torch.onnx.dynamo_export( - exported_program, - torch.randn(2, 2, 3), - b=torch.randn(2, 2, 3), - ) - onnx_program_from_old_exporter = torch.onnx.export( - SampleModelForDynamicShapes(), - (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), - dynamic_axes={ - "b": [0, 1, 2], - }, - dynamo=True, - ) - self.assertEqual( - onnx_program_from_new_exporter.model_proto, - onnx_program_from_old_exporter.model_proto, - ) - - def test_dynamic_shapes_hit_constraints_in_dynamo(self): - # SampleModelTwoInputs has constraints becuse of add of two inputs, - # so the two input shapes are related. 
- with self.assertRaisesRegex( - torch._dynamo.exc.UserError, - "Constraints violated", - ): - _ = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(2, 2, 3), torch.randn(2, 2, 3)), - dynamic_axes={ - "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"}, - "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"}, - }, - dynamo=True, - ) - - def test_saved_f_exists_after_export(self): - with common_utils.TemporaryFileName(suffix=".onnx") as path: - _ = torch.onnx.export( - SampleModel(), torch.randn(1, 1, 2), path, dynamo=True - ) - self.assertTrue(os.path.exists(path)) - - def test_raises_error_when_input_is_script_module(self): - class ScriptModule(torch.jit.ScriptModule): - def forward(self, x): - return x - - with self.assertRaisesRegex( - TypeError, - "Dynamo export does not support ScriptModule or ScriptFunction.", - ): - _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True) - - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/exporter/README.md b/test/onnx/exporter/README.md new file mode 100644 index 00000000000..7ad65ca338b --- /dev/null +++ b/test/onnx/exporter/README.md @@ -0,0 +1 @@ +Directory for all ExportedProgram exporter logic. diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py new file mode 100644 index 00000000000..157ea1197b6 --- /dev/null +++ b/test/onnx/exporter/test_api.py @@ -0,0 +1,120 @@ +# Owner(s): ["module: onnx"] +"""Simple API tests for the ONNX exporter.""" + +from __future__ import annotations + +import os + +import torch +from torch.onnx._internal import exporter +from torch.testing._internal import common_utils + + +class SampleModel(torch.nn.Module): + def forward(self, x): + y = x + 1 + z = y.relu() + return (y, z) + + +class SampleModelTwoInputs(torch.nn.Module): + def forward(self, x, b): + y = x + b + z = y.relu() + return (y, z) + + +class SampleModelForDynamicShapes(torch.nn.Module): + def forward(self, x, b): + return x.relu(), b.sigmoid() + + +class TestExportAPIDynamo(common_utils.TestCase): + """Tests for the ONNX exporter API when dynamo=True.""" + + def test_args_normalization_with_no_kwargs(self): + onnx_program = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_args_normalization_with_kwargs(self): + onnx_program = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_args_normalization_with_empty_dict_at_the_tail(self): + onnx_program = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): + onnx_program = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, + "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, + }, + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): + onnx_program = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "x": [0, 1, 
2], + "b": [0, 1, 2], + }, + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_dynamic_axes_supports_partial_dynamic_shapes(self): + onnx_program = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "b": [0, 1, 2], + }, + dynamo=True, + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + def test_saved_f_exists_after_export(self): + with common_utils.TemporaryFileName(suffix=".onnx") as path: + _ = torch.onnx.export( + SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True + ) + self.assertTrue(os.path.exists(path)) + + def test_export_supports_script_module(self): + class ScriptModule(torch.nn.Module): + def forward(self, x): + return x + + onnx_program = torch.onnx.export( + torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),), dynamo=True + ) + assert onnx_program + exporter.verify_onnx_program(onnx_program) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 0f5c2dc5a35..5433540aeb2 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -286,6 +286,25 @@ class TestPublicBindings(TestCase): # do not get imported by public code. private_allowlist = { "torch._inductor.codegen.cuda.cuda_kernel", + # TODO(#133647): Remove the onnx._internal entries after + # onnx and onnxscript are installed in CI. + "torch.onnx._internal.exporter", + "torch.onnx._internal.exporter._analysis", + "torch.onnx._internal.exporter._building", + "torch.onnx._internal.exporter._capture_strategies", + "torch.onnx._internal.exporter._compat", + "torch.onnx._internal.exporter._core", + "torch.onnx._internal.exporter._decomp", + "torch.onnx._internal.exporter._dispatching", + "torch.onnx._internal.exporter._fx_passes", + "torch.onnx._internal.exporter._ir_passes", + "torch.onnx._internal.exporter._isolated", + "torch.onnx._internal.exporter._onnx_program", + "torch.onnx._internal.exporter._registration", + "torch.onnx._internal.exporter._reporting", + "torch.onnx._internal.exporter._schemas", + "torch.onnx._internal.exporter._tensors", + "torch.onnx._internal.exporter._verification", "torch.onnx._internal.fx._pass", "torch.onnx._internal.fx.analysis", "torch.onnx._internal.fx.analysis.unsupported_nodes", diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 633dd112109..adb3bde6c6c 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,65 +1,5 @@ # mypy: allow-untyped-defs -from torch import _C -from torch._C import _onnx as _C_onnx -from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode - -from ._exporter_states import ExportTypes -from ._internal.onnxruntime import ( - is_onnxrt_backend_supported, - OrtBackend as _OrtBackend, - OrtBackendOptions as _OrtBackendOptions, - OrtExecutionProvider as _OrtExecutionProvider, -) -from ._type_utils import JitScalarType -from .errors import CheckerError # Backwards compatibility -from .utils import ( - _optimize_graph, - _run_symbolic_function, - _run_symbolic_method, - export, - export_to_pretty_string, - is_in_onnx_export, - register_custom_op_symbolic, - select_model_mode_for_export, - unregister_custom_op_symbolic, -) - - -from . import ( # usort: skip. 
Keep the order instead of sorting lexicographically - _deprecation, - errors, - symbolic_caffe2, - symbolic_helper, - symbolic_opset7, - symbolic_opset8, - symbolic_opset9, - symbolic_opset10, - symbolic_opset11, - symbolic_opset12, - symbolic_opset13, - symbolic_opset14, - symbolic_opset15, - symbolic_opset16, - symbolic_opset17, - symbolic_opset18, - symbolic_opset19, - symbolic_opset20, - utils, -) - - -from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import - DiagnosticOptions, - ExportOptions, - ONNXProgram, - ONNXProgramSerializer, - ONNXRuntimeOptions, - InvalidExportOptionsError, - OnnxExporterError, - OnnxRegistry, - dynamo_export, - enable_fake_mode, -) +from __future__ import annotations __all__ = [ @@ -115,6 +55,74 @@ __all__ = [ "is_onnxrt_backend_supported", ] +from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode + +from ._exporter_states import ExportTypes +from ._internal.onnxruntime import ( + is_onnxrt_backend_supported, + OrtBackend as _OrtBackend, + OrtBackendOptions as _OrtBackendOptions, + OrtExecutionProvider as _OrtExecutionProvider, +) +from ._type_utils import JitScalarType +from .errors import CheckerError # Backwards compatibility +from .utils import ( + _optimize_graph, + _run_symbolic_function, + _run_symbolic_method, + export_to_pretty_string, + is_in_onnx_export, + register_custom_op_symbolic, + select_model_mode_for_export, + unregister_custom_op_symbolic, +) + + +from . import ( # usort: skip. Keep the order instead of sorting lexicographically + _deprecation, + errors, + symbolic_caffe2, + symbolic_helper, + symbolic_opset7, + symbolic_opset8, + symbolic_opset9, + symbolic_opset10, + symbolic_opset11, + symbolic_opset12, + symbolic_opset13, + symbolic_opset14, + symbolic_opset15, + symbolic_opset16, + symbolic_opset17, + symbolic_opset18, + symbolic_opset19, + symbolic_opset20, + utils, +) + + +from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import + DiagnosticOptions, + ExportOptions, + ONNXProgram, + ONNXProgramSerializer, + ONNXRuntimeOptions, + InvalidExportOptionsError, + OnnxExporterError, + OnnxRegistry, + dynamo_export, + enable_fake_mode, +) + + +if TYPE_CHECKING: + import os + # Set namespace for exposed private names ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" @@ -137,6 +145,257 @@ producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION +def export( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...], + f: str | os.PathLike | None = None, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + opset_version: int | None = None, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool = False, + dynamo: bool = False, + # Dynamo only options + external_data: bool = True, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] 
| list[Any] | None = None, + report: bool = False, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + fallback: bool = False, + # Deprecated options + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + do_constant_folding: bool = True, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, + **_: Any, # ignored options +) -> Any | None: + r"""Exports a model into ONNX format. + + Args: + model: The model to be exported. + args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Optional example keyword inputs. + export_params: If false, parameters (weights) will not be exported. + verbose: Whether to enable verbose logging. + input_names: names to assign to the input nodes of the graph, in order. + output_names: names to assign to the output nodes of the graph, in order. + opset_version: The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"] + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + } + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to model weights) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the user inputs are added as inputs. + + Set this to True if you intend to supply model weights at runtime. + Set it to False if the weights are static to allow for better optimizations + (e.g. constant folding) by backends/runtimes. + + dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. + external_data: Whether to save the model weights as an external data file. 
+ This is required for models with large weights that exceed the ONNX file size limit (2GB). + When False, the weights are saved in the ONNX file with the model architecture. + dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to + :func:`torch.export.export` for more details. + report: Whether to generate a markdown report for the export process. + verify: Whether to verify the exported model using ONNX Runtime. + profile: Whether to profile the export process. + dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. + This is useful for debugging the exporter. + artifacts_dir: The directory to save the debugging artifacts like the report and the serialized + exported program. + fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + + training: Deprecated option. Instead, set the training mode of the model before exporting. + operator_export_type: Deprecated option. Only ONNX is supported. + do_constant_folding: Deprecated option. The exported graph is always optimized. + custom_opsets: Deprecated. + A dictionary: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + export_modules_as_functions: Deprecated option. + + Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + autograd_inlining: Deprecated. + Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. 
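+
+    Returns:
+        When ``dynamo=True`` (or ``model`` is a :class:`torch.export.ExportedProgram`),
+        the ONNX program object produced by the ExportedProgram-based exporter;
+        otherwise ``None`` (the TorchScript exporter writes the model to ``f``).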
+ """ + if dynamo is True or isinstance(model, torch.export.ExportedProgram): + from torch.onnx._internal import exporter + + if isinstance(args, torch.Tensor): + args = (args,) + return exporter.export_compat( + model, + args, + f, + kwargs=kwargs, + export_params=export_params, + verbose=verbose, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + dynamic_shapes=dynamic_shapes, + report=report, + verify=verify, + profile=profile, + dump_exported_program=dump_exported_program, + artifacts_dir=artifacts_dir, + fallback=fallback, + ) + else: + from torch.onnx.utils import export + + export( + model, + args, + f, # type: ignore[arg-type] + kwargs=kwargs, + export_params=export_params, + verbose=verbose is True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + training=training, + operator_export_type=operator_export_type, + do_constant_folding=do_constant_folding, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + return None + + # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. # Returns True iff ONNX logging is turned on. diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py new file mode 100644 index 00000000000..d08d1601221 --- /dev/null +++ b/torch/onnx/_internal/_lazy_import.py @@ -0,0 +1,38 @@ +"""Utility to lazily import modules.""" +# mypy: allow-untyped-defs +from __future__ import annotations + +import importlib +from typing import Any, TYPE_CHECKING + + +class _LazyModule: + """Lazily import a module.""" + + def __init__(self, module_name: str) -> None: + self._name = module_name + self._module: Any = None + + def __repr__(self) -> str: + return f"" + + def __getattr__(self, attr): + if self._module is None: + self._module = importlib.import_module(".", self._name) + return getattr(self._module, attr) + + +# Import the following modules during type checking to enable code intelligence features, +# such as auto-completion in tools like pylance, even when these modules are not explicitly +# imported in user code. +# NOTE: Add additional used imports here. 
+if TYPE_CHECKING: + import onnx + import onnxscript + + onnxscript_ir = onnxscript.ir + +else: + onnx = _LazyModule("onnx") + onnxscript = _LazyModule("onnxscript") + onnxscript_ir = _LazyModule("onnxscript.ir") diff --git a/torch/onnx/_internal/exporter/__init__.py b/torch/onnx/_internal/exporter/__init__.py new file mode 100644 index 00000000000..3bf21aa01dd --- /dev/null +++ b/torch/onnx/_internal/exporter/__init__.py @@ -0,0 +1,16 @@ +__all__ = [ + "ONNXRegistry", + "ONNXProgram", + "analyze", + "export", + "exported_program_to_ir", + "verify_onnx_program", + "export_compat", +] + +from ._analysis import analyze +from ._compat import export_compat +from ._core import export, exported_program_to_ir +from ._onnx_program import ONNXProgram +from ._registration import ONNXRegistry +from ._verification import verify_onnx_program diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py new file mode 100644 index 00000000000..43a65194826 --- /dev/null +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -0,0 +1,250 @@ +"""Compatibility analyzer for PyTorch models.""" + +# mypy: allow-untyped-defs +# flake8: noqa: B950 We do not need flake8 as it complains line length +from __future__ import annotations + +import dataclasses +import textwrap +import traceback +from collections import defaultdict +from typing import TYPE_CHECKING + +import onnxscript + +import torch +import torch._export.serde.schema +from torch.export import graph_signature +from torch.onnx._internal.exporter import _dispatching, _registration + + +if TYPE_CHECKING: + import torch.fx + + +@dataclasses.dataclass +class ModelInfo: + """Information about the model.""" + + parameter_count: defaultdict[torch.dtype, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + buffer_count: defaultdict[torch.dtype, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + fx_node_count: int = 0 + fx_node_op_count: defaultdict[str, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + fx_node_target_count: defaultdict[str, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field( + default_factory=list + ) + inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( + default_factory=dict + ) + outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( + default_factory=dict + ) + + +def _count_weights( + exported_program: torch.export.ExportedProgram, +) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]: + """Count the size of the parameters in the exported program.""" + + parameter_count: defaultdict[torch.dtype, int] = defaultdict(int) + buffer_count: defaultdict[torch.dtype, int] = defaultdict(int) + for parameter in exported_program.parameters(): + dtype = parameter.dtype + parameter_count[dtype] += parameter.numel() + + for buffer in exported_program.buffers(): + dtype = buffer.dtype + buffer_count[dtype] += buffer.numel() + + return parameter_count, buffer_count + + +def _format_model_info(model_info: ModelInfo) -> str: + """Format the information about the model.""" + lines = [ + textwrap.dedent( + f"""\ + PyTorch ONNX Conversion Analysis + + ## Model Information + + The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters). 
+ Number of parameters per dtype: + ```python + {model_info.parameter_count} + ``` + Number of buffers per dtype: + ```python + {model_info.buffer_count} + ``` + """ + ), + "Inputs:", + *[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()], + "", + "Outputs:", + *[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()], + "", + f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:", + ] + for op, count in model_info.fx_node_op_count.items(): + lines.append(f"- `{op}`: {count}") + lines.append("\n") + lines.append("Of the call_function nodes, the counts of operators used are:\n") + sorted_targets = sorted( + model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True + ) + for target, count in sorted_targets: + lines.append(f"- `{target}`: {count}") + + lines.append("") + lines.append("## ONNX Conversion Information") + lines.append("") + + if model_info.dispatch_failures: + lines.append( + "The model contains operators the dispatcher could not find registered ONNX decompositions for. " + "This may be due to missing implementations, decompositions not registered " + "correctly, or a bug in the dispatcher." + ) + lines.append("") + lines.append("Errors grouped by operator:\n") + + target_to_nodes = defaultdict(list) + for node, _ in model_info.dispatch_failures: + target_to_nodes[str(node.target)].append(node) + + target_to_messages = {} + for node, message in model_info.dispatch_failures: + if str(node.target) not in target_to_messages: + target_to_messages[str(node.target)] = message + + for target, nodes in sorted( + target_to_nodes.items(), key=lambda x: x[0], reverse=True + ): + message = textwrap.indent( + f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. 
All nodes: `{nodes}`", + " ", + ) + lines.append(f"- `{target}`: {message}") + else: + lines.append("All operators in the model have registered ONNX decompositions.") + + return "\n".join(lines) + + +def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]: + """Get the input and output specs of the exported program.""" + + nodes: dict[str, torch.fx.Node] = { + node.name: node for node in exported_program.graph.nodes + } + user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind == graph_signature.InputKind.USER_INPUT + ] + user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind == graph_signature.OutputKind.USER_OUTPUT + ] + inputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + outputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + for spec in user_inputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type + inputs[name] = nodes[name].meta["tensor_meta"] + for spec in user_outputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + outputs[name] = nodes[name].meta["tensor_meta"] + return inputs, outputs + + +def _count_fx_targets( + exported_program: torch.export.ExportedProgram, +) -> defaultdict[str, int]: + """Count the number of targets for each node in the exported program.""" + fx_node_target_count: defaultdict[str, int] = defaultdict(int) + for node in exported_program.graph.nodes: + if node.op == "call_function": + fx_node_target_count[str(node.target)] += 1 + return fx_node_target_count + + +def analyze( + exported_program: torch.export.ExportedProgram, + registry: _registration.ONNXRegistry | None = None, + file=None, +) -> None: + """Analyze the compatibility of the exported program.""" + # Get basic information about the model + model_info = ModelInfo() + model_info.parameter_count, model_info.buffer_count = _count_weights( + exported_program + ) + model_info.fx_node_count = len(exported_program.graph.nodes) + model_info.fx_node_target_count = _count_fx_targets(exported_program) + inputs, outputs = _get_io_specs(exported_program) + model_info.inputs = inputs + model_info.outputs = outputs + + if registry is None: + # Trigger op registration + from onnxscript.function_libs.torch_lib import ops # noqa: F401 + + del ops + registry = _registration.ONNXRegistry.from_torchlib( + onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] + ) + + # Try to find ops for every node in the graph + for node in exported_program.graph.nodes: + model_info.fx_node_op_count[node.op] += 1 + if node.op == "call_function": + try: + onnx_function, message = _dispatching.dispatch(node, registry) + except Exception as e: + message = "Critical Error in dispatcher:\n" + formatted_exception = "\n".join( + traceback.format_exception(type(e), e, e.__traceback__) + ) + message += f"```pytb\n{formatted_exception}\n```\n" + onnx_function = None + if onnx_function is None: + model_info.dispatch_failures.append((node, message)) + + # Print the results + report = _format_model_info(model_info) + print(report, file=file, flush=True) + + +def compare_ops( + program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram +) -> tuple[set[str], set[str]]: + """Compare and get unique ops in two exported programs. 
+ + Args: + program_a: The first exported program. + program_b: The second exported program. + + Returns: + A tuple of two sets, where the first set contains the unique ops in the first program + and the second set contains the unique ops in the second program. + """ + program_a_ops = set(_count_fx_targets(program_a)) + program_b_ops = set(_count_fx_targets(program_b)) + return program_a_ops - program_b_ops, program_b_ops - program_a_ops diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py new file mode 100644 index 00000000000..ddda83c3718 --- /dev/null +++ b/torch/onnx/_internal/exporter/_building.py @@ -0,0 +1,516 @@ +"""NOTES: + +We need a typing module that will handling Python to ONNX type promotion for use. +For example, if we have torch.ops.aten.add(Tensor, 1.0), we need to promote 1.0 +to the same type as Tensor. The same thing needs to work for +torch.ops.aten.add(1.0, Tensor) as well, which means we need a mechanism to` +""" + +# mypy: allow-untyped-defs +# mypy: disable-error-code=union-attr +from __future__ import annotations + +import copy +import inspect +import logging +from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union + +import onnxscript +from onnxscript import evaluator, ir +from onnxscript.ir import convenience as ir_convenience + +import torch +from torch.onnx._internal.exporter import _schemas, _tensors, errors + + +if TYPE_CHECKING: + import onnx + + +logger = logging.getLogger(__name__) + +# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes +ValidAttributeType = Union[ + ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None +] + +AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType] + + +# Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value +def _construct_named_inputs_and_attrs( + signature: _schemas.OpSignature, + args: Sequence[AllowedArgType], + kwargs: Mapping[str, AllowedArgType], +) -> tuple[dict[str, AllowedArgType], dict[str, ValidAttributeType]]: + """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. + + This function uses the OpSignature to determine which argument in args and kwargs corresponds to + which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are + stored in named_attrs. If an _optional input_ is not provided, it is filled with None. + + Args: + signature: The OpSignature for the node. + args: The positional arguments for the node. + kwargs: The keyword arguments for the node. + + Returns: + A tuple of two mappings: named_inputs and named_attrs. + + Raises: + ValueError: If a required parameter is not provided. + """ + # 1. Construct the (named_inputs, named_attrs) mapping based on (args, kwargs) and the signature. + # a. Loop over all parameters in the signature and args together + # b. Depending on param.is_input, Record named_inputs[param.name] = arg or named_attrs[param.name] = arg + # c. Handle kwargs as well + # d. 
Fill in None if the input is not provided + named_inputs = {} + named_attrs = {} + reversed_args_stack = list(reversed(args)) + for param in signature.params: + if isinstance(param, _schemas.Parameter): + # Handle inputs + if reversed_args_stack: + # First exhaust the positional arguments + if param.variadic: + # Handle variadic arguments + named_inputs[param.name] = tuple(args) + reversed_args_stack.clear() + else: + named_inputs[param.name] = reversed_args_stack.pop() # type: ignore[assignment] + elif param.name in kwargs: + named_inputs[param.name] = kwargs[param.name] # type: ignore[assignment] + elif param.required: + raise ValueError( + f"Required parameter '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional parameter '%s' is not provided. Added as None. Signature: %s", + param.name, + signature, + ) + named_inputs[param.name] = None # type: ignore[assignment] + else: + # Handle attributes + attribute: ValidAttributeType | ir.Attr + assert isinstance( + param, _schemas.AttributeParameter + ), f"Expected AttributeParameter, got {type(param)}" + if reversed_args_stack: + # First exhaust the positional arguments + attribute = reversed_args_stack.pop() # type: ignore[assignment] + elif param.name in kwargs: + attribute = kwargs[param.name] # type: ignore[assignment] + elif param.default is not None: + attribute = param.default + else: + attribute = None + + if attribute is None: + if param.required: + raise ValueError( + f"Required attribute '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional attribute '%s' is None. Dropped. Signature: %s", + param.name, + signature, + ) + continue + + if isinstance(attribute, ir.Attr): + # Turn the attribute from an default value into an actual parameter for the node + attr_copied = copy.copy(attribute) + # Make sure the name is the same as the parameter name and not the name of the default parameter + attr_copied.name = param.name + attribute = attr_copied + + if isinstance(attribute, int) and param.type == ir.AttributeType.FLOAT: + # Convert the attribute to float if needed. This happens in PyTorch + # where an attribute marked as float can be passed as an int. + attribute = float(attribute) + named_attrs[param.name] = attribute + return named_inputs, named_attrs # type: ignore[return-value] + + +def _resolve_parameter_dtypes( + signature: _schemas.OpSignature, named_inputs: Mapping[str, AllowedArgType] +) -> Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol]: + """Determine which parameter takes which type. + + Handle non-tensor input corner cases and type promotion. + + Requires: + All ir.Value in name_inputs should have type set. Their type should be + compatible with the type_constraint of the corresponding parameter in the signature. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. + + Returns: + A mapping of Constraint names to ir.TypeProtocol. + """ + # a. Create type_binding: dict[str, ir.TypeProtocol] + # b. Iterate over all named_inputs + # b0. Find the corresponding parameter in the signature + # b1. If the argument is a Python constant, skip. + # b2. If the argument is a ir.Value, Bind {constraint: arg.type}. 
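+    # Illustrative example: for ``Add(x, 1.0)`` where ``x`` is an ir.Value typed as
+    # FLOAT, the loop below binds Add's "T" type constraint to FLOAT, which later
+    # lets the Python constant ``1.0`` be materialized as a FLOAT Constant node.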
+ type_binding = {} + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance( + param, _schemas.Parameter + ), f"Expected Parameter, got {type(param)}" + if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)): + # Skip the Python constants because we do not know what dtype they should take yet + continue + elif isinstance(arg, ir.Value): + if arg.type is None: + # Skip the ir.Value if the type is not set + continue + # NOTE: We assume arg.type is compatible with the type_constraint + assert arg.type is not None, f"Expected type to be set for {arg}" + # TODO(justinchuby): Implement type promotion logic here. + type_binding[param.type_constraint] = arg.type + return type_binding + + +def _process_python_constants_and_sequences( + signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], + constant_farm: dict[ + tuple[ + bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + opset: onnxscript.values.Opset, +) -> dict[str, ir.Value | None]: + """Convert Python constants to Constant nodes and list to Sequence nodes based on the dtype information. + + The added constants will be replacing values in named_inputs in place. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. + type_binding: A mapping of Constraint names to ir.DataType. + constant_farm: A dictionary of {(py_value, ir.DataType): ir.Value} to store the deduplicated constants. + opset: The Opset to use for creating Constant nodes. + + Returns: + None + """ + # 3. Convert Python constants to Constant nodes based on the dtype information; + # construct sequences + # a. Iterate over all parameters in the signature the second time + # b. If the parameter is in to_resolve_type: + # - If param.constraint in type_binding, + # Get the constant from constant_farm (deduplicated); + # otherwise set named_inputs[param.name] = Constant(value, dtype=type_binding[param.constraint]) + # - Otherwise, set named_inputs[param.name] = Constant(value) + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance( + param, _schemas.Parameter + ), f"Expected Parameter, got {type(param)}" + + if isinstance(arg, ir.Value): + # TODO(justinchuby): Cast the ir.Value here if needed + continue + if ( + isinstance(arg, Sequence) + and len(arg) > 0 + and all(isinstance(val, ir.Value) for val in arg) + ): + # Skip the sequence of ir.Value. This is a variadic input or a Sequence input + # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants + # like `Max(0, ir.Value())` + # We need to convert the Python constants to Constant nodes + # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float] + continue + # if param.variadic: + # # FXIME: Handle variadic inputs and sequence inputs differently + # raise NotImplementedError + # TODO: Find a way to recursively build constants. Maybe extract the logic out. 
+ # FIXME: I am here + + assert isinstance( + param, _schemas.Parameter + ), f"Expected Parameter, got {type(param)}" + + if param.type_constraint in type_binding: + # A known dtype is available + dtype = type_binding[param.type_constraint].dtype + elif len(param.type_constraint.allowed_types) == 1: + # Only one type is allowed + dtype = next(iter(param.type_constraint.allowed_types)).dtype + else: + # No dtype information available. Infer from the Python constant + if isinstance(arg, bool): + dtype = ir.DataType.BOOL + elif isinstance(arg, float): + dtype = ir.DataType.FLOAT + elif isinstance(arg, int): + dtype = ir.DataType.INT64 + elif isinstance(arg, str): + dtype = ir.DataType.STRING + elif isinstance(arg, (tuple, list)) and all( + isinstance(val, int) for val in arg + ): + dtype = ir.DataType.INT64 + elif isinstance(arg, (tuple, list)) and any( + isinstance(val, float) for val in arg + ): + # NOTE: if any float is present, the dtype is float + dtype = ir.DataType.FLOAT + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + dtype = arg.dtype + elif arg is None: + dtype = ir.DataType.UNDEFINED + else: + raise TypeError( + f"Constant input '{arg}' of type '{type(arg)}' is not supported" + ) + + if arg is None: + constant_value = None + elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + # Deduplicate the constants + if isinstance(arg, (tuple, list)): + # Make the arg hashable + arg = tuple(arg) # noqa: PLW2901 + constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] + if constant_value is None: + constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] + else: + constant_value = opset.Constant(value=arg) + + named_inputs[param.name] = constant_value + return named_inputs # type: ignore[return-value] + + +def _construct_node( + signature: _schemas.OpSignature, + named_inputs: Mapping[str, ir.Value | None], + named_attrs: Mapping[str, ValidAttributeType], + opset: onnxscript.values.Opset, +) -> ir.Node: + """Construct the node with the inputs and attributes. + + Variadic inputs are flattened. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. When we + do not have the schema of an operator, we do not know the names of + the inputs, in which case the names can be anything because they + are not used in this function. The data structure is passed in for + consistency with the other functions. + named_attrs: The mapping of attribute names to their values. 
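+
+    Returns:
+        An ``ir.Node`` whose outputs are newly created ``SymbolicTensor`` values,
+        one per output declared in ``signature``.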
+ """ + inputs: list[Any] = [] + # Flatten variadic inputs + for value in named_inputs.values(): + if isinstance(value, Sequence): + inputs.extend(value) + else: + inputs.append(value) + + # Construct and filter out None attributes + attributes = [ + attr + for attr in ir_convenience.convert_attributes(named_attrs) + if attr.value is not None + ] + outputs = [_tensors.SymbolicTensor(opset) for _ in signature.outputs] + return ir.Node( + signature.domain, + signature.name, + inputs=inputs, + attributes=attributes, + outputs=outputs, + ) + + +class OpRecorder(evaluator.Evaluator): + """An onnxscript Evaluator that captures the graph into torchscript.""" + + def __init__( + self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value] + ): + self.nodes: list[ir.Node] = [] + self.opset = opset + self.functions: dict[ir.OperatorIdentifier, onnxscript.OnnxFunction] = {} + self.constant_farm = constant_farm + + def _call_op( + self, + op_signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + named_attrs: dict[str, ValidAttributeType], + ) -> Sequence[_tensors.SymbolicTensor]: + """Record nodes for the given opschema and arguments. + + Args: + op_signature: The OpSchema containing the node signature. + named_inputs: The mapping of parameter names to their arguments. + named_attrs: The mapping of attribute names to their values. + """ + type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) + try: + converted_named_inputs = _process_python_constants_and_sequences( + op_signature, named_inputs, type_binding, self.constant_farm, self.opset + ) + except Exception as e: + raise errors.GraphConstructionError( + f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " + f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." + ) from e + + try: + self.nodes.append( + node := _construct_node( + op_signature, converted_named_inputs, named_attrs, self.opset + ) + ) + except Exception as e: + raise errors.GraphConstructionError( + f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. " + f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, " + f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." + ) from e + return node.outputs # type: ignore[return-value] + + def eval( + self, + schema: onnx.defs.OpSchema, + args: Sequence[AllowedArgType], # type: ignore[override] + kwargs: Mapping[str, AllowedArgType], + ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor]: + try: + op_signature = _schemas.OpSignature.from_opschema(schema) + named_inputs, named_attrs = _construct_named_inputs_and_attrs( + op_signature, args, kwargs + ) + # TODO(justinchuby): Handle cast + if schema.name == "CastLike": + assert len(named_inputs) == 2 + # Skip CastLike if the input and output types are the same + src_input = named_inputs["input"] + target_type = named_inputs["target_type"] + + if ( + isinstance(src_input, ir.Value) + and isinstance(target_type, ir.Value) + and src_input.dtype is not None + and target_type.dtype is not None + ): + # dtypes are available + if src_input.dtype == target_type.dtype: + # Same type. 
No cast needed + return src_input # type: ignore[return-value] + else: + # Create a Cast node + return self.opset.Cast(src_input, to=target_type.dtype) # type: ignore[union-attr,return-value] + + outputs = self._call_op(op_signature, named_inputs, named_attrs) + if len(outputs) == 1: + return outputs[0] + return outputs + except Exception as e: + raise errors.GraphConstructionError( + f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}." + ) from e + + def eval_function( # type: ignore[override] + self, + function: onnxscript.OnnxFunction, + args: Sequence[AllowedArgType], + kwargs: Mapping[str, AllowedArgType], + ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int: + try: + # Special cases for handling IsScalar and Rank + if function.name == "IsScalar": + if len(args) != 1: + raise TypeError( + f"Expected 1 positional argument for function '{function}', got {len(args)}." + ) + if isinstance(args[0], _tensors.SymbolicTensor): + if args[0].rank is not None: + return args[0].rank == 0 + else: + # Fall to call add_function_call + pass + elif isinstance(args[0], Sequence): + return False + else: + # Python constants are scalars + return True + if function.name == "Rank": + if len(args) != 1: + raise TypeError( + f"Expected 1 positional argument for function '{function}', got {len(args)}." + ) + if isinstance(args[0], _tensors.SymbolicTensor): + if args[0].rank is not None: + return args[0].rank + else: + # Fall to call add_function_call + pass + elif isinstance(args[0], Sequence): + if all(isinstance(arg, (int, float)) for arg in args[0]): + return 1 + else: + # Fall to call add_function_call + pass + else: + # Python constants are scalars + return 0 + + # NOTE: signature is written to function in the registration process + # TODO: Upstream signature to ONNX Function + if hasattr(function, "signature"): + op_signature = function.signature + else: + op_signature = _schemas.OpSignature.from_function( + function, function.function_ir.domain, function.name + ) + + named_inputs, named_attrs = _construct_named_inputs_and_attrs( + op_signature, args, kwargs + ) + + # NOTE: We need to call traceable functions after the _construct_named_inputs_and_attrs + # call because it will filter out the unexpected kwargs for us. + if function.traceable: + # Trace the function call instead of adding the function as a node + return function.function(**named_inputs, **named_attrs) + + outputs = self._call_op(op_signature, named_inputs, named_attrs) + + self.functions[(function.function_ir.domain, function.name, "")] = function + if len(outputs) == 1: + return outputs[0] + return outputs + except Exception as e: + try: + source_file = inspect.getsourcefile(function.function) + _, lineno = inspect.getsourcelines(function.function) + except Exception: + source_file = lineno = None + raise errors.GraphConstructionError( + f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}." + + f" The function is defined at '{source_file}:{lineno}'." 
+ if source_file + else "" + ) from e diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py new file mode 100644 index 00000000000..dc511491d6b --- /dev/null +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -0,0 +1,335 @@ +"""Strategies for capturing ExportedPrograms.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +import abc +import dataclasses +import datetime +import pathlib +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch._export import converter as _torchscript_converter +from torch.utils import _pytree + + +if TYPE_CHECKING: + import os + + +def _verbose_printer(verbose: bool | None) -> Callable[..., None]: + """Prints messages based on `verbose`.""" + if verbose is False: + return lambda *_, **__: None + return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) + + +def _take_first_line(text: str) -> str: + """Take the first line of a text.""" + lines = text.split("\n", maxsplit=1) + first_line = lines[0] + if len(lines) > 1: + first_line += "[...]" + return first_line + + +@dataclasses.dataclass +class Result: + exported_program: torch.export.ExportedProgram | None + strategy: str + exception: Exception | None = None + + @property + def success(self) -> bool: + return self.exported_program is not None + + +class CaptureStrategy(abc.ABC): + """Strategy for capturing a module as ExportedProgram. + + To use a strategy, create an instance and call it with the model, args, kwargs, and dynamic_shapes. + Example:: + + strategy = TorchExportStrategy(verbose=True) + result = strategy(model, args, kwargs, dynamic_shapes) + """ + + def __init__( + self, + *, + verbose: bool = False, + dump: bool = False, + artifacts_dir: str | os.PathLike = ".", + timestamp: str | None = None, + ): + """Initialize the strategy. + + Args: + verbose: Whether to print verbose messages. + dump: Whether to dump the intermediate artifacts to a file. 
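+            artifacts_dir: The directory the dumped artifacts are written to.
+            timestamp: Timestamp string used in dumped file names; defaults to the
+                current time when not provided.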
+ """ + self._verbose_print = _verbose_printer(verbose) + self._dump = dump + self._artifacts_dir = pathlib.Path(artifacts_dir) + self._timestamp = timestamp or datetime.datetime.now().strftime( + "%Y-%m-%d_%H-%M-%S-%f" + ) + + def __call__( + self, + model: torch.nn.Module | torch.jit.ScriptFunction, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + dynamic_shapes, + ) -> Result: + self._enter(model) + if kwargs is None: + kwargs = {} + try: + exported_program = self._capture(model, args, kwargs, dynamic_shapes) + except Exception as e: + self._failure(model, e) + return Result( + exported_program=None, + strategy=self.__class__.__name__, + exception=e, + ) + self._success(model) + return Result(exported_program, strategy=self.__call__.__name__) + + @abc.abstractmethod + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + raise NotImplementedError + + def _enter(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: + return + + def _success(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: + return + + def _failure( + self, model: torch.nn.Module | torch.jit.ScriptFunction, e: Exception + ) -> None: + return + + +class TorchExportStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export`..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export`... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export`... ❌" + ) + + +class TorchExportNonStrictStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False + ) + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ❌" + ) + + +class JitTraceConvertStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + del dynamic_shapes # Unused + + flattened_args, spec = _pytree.tree_flatten((args, kwargs)) + flattened_args = tuple(flattened_args) + + # Since torch.jit.trace only accepts Tensors as inputs, we filter + # out non-Tensor arguments and reconstruct the arguments after entering + # the WrappedModel. 
+ tensor_placeholder = object() + non_tensor_args = [ + arg if not isinstance(arg, torch.Tensor) else tensor_placeholder + for arg in flattened_args + ] + tensor_args = tuple( + arg for arg in flattened_args if isinstance(arg, torch.Tensor) + ) + + class WrappedModel(torch.nn.Module): + """Wrap the model so that it takes flattened arguments.""" + + def __init__(self, m): + super().__init__() + self.model = m + + def forward(self, *_args): + # Take the non-Tensor arguments list as a starting point and + # replace the tensor_placeholder with the actual tensor arguments + # from _args. + reconstructed_flattened_args = non_tensor_args.copy() + _args_iter = iter(_args) + for i, arg in enumerate(reconstructed_flattened_args): + if arg is tensor_placeholder: + reconstructed_flattened_args[i] = next(_args_iter) + # Unflatten the arguments and kwargs to pass to the model. + unflattened_args, unflattened_kwargs = _pytree.tree_unflatten( + reconstructed_flattened_args, spec + ) + results = self.model(*unflattened_args, **unflattened_kwargs) + if not isinstance(results, tuple): + results = (results,) + flattened_results, _ = _pytree.tree_flatten(results) + if len(flattened_results) == 1: + return flattened_results[0] + return tuple(flattened_results) + + jit_model = torch.jit.trace( + WrappedModel(model), + example_inputs=tensor_args, + check_trace=False, + strict=False, + ) + if self._dump: + program_path = self._artifacts_dir / f"onnx_export_{self._timestamp}.pt" + try: + torch.jit.save(jit_model, program_path) + except Exception as e: + self._verbose_print( + f"Failed to save Torch Script model due to an error: {e}" + ) + else: + self._verbose_print( + f"Torch Script model has been saved to '{program_path}'." + ) + return _torchscript_converter.TS2EPConverter( + jit_model, flattened_args + ).convert() + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with Torch Script..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with Torch Script... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with Torch Script... ❌" + ) + + +class LegacyDynamoStrategy(CaptureStrategy): + """Strategy implemented by the ONNX team using internal dynamo APIs and custom fx passes.""" + + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + # NOTE: Import here to prevent circular dependency + from torch.onnx._internal.fx import diagnostics, passes + + graph_module, _ = torch._dynamo.export( + model, + tracing_mode="symbolic", + dynamic_shapes=dynamic_shapes, + )( + *args, + **kwargs, + ) + torch._dynamo.reset() + + diagnostic_context = diagnostics.DiagnosticContext( + "torch.onnx.export", + torch.__version__, + ) + + flattened_args, _ = _pytree.tree_flatten((args, kwargs)) + flattened_args = tuple(flattened_args) + + # ONNX does not support views and mutations. + # Functionalize to get a semantically equivalent graph without mutations. + graph_module = passes.Functionalize( + diagnostic_context, + graph_module, + enable_dynamic_axes=bool(dynamic_shapes), + ).run(*flattened_args) + + # Input mutations are detected and distilled after `Functionalize` pass. + # Remove them since ONNX inference does not need them. 
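+        # For example (hypothetical model; describing the intent, not the exact node
+        # pattern): an in-place update such as `x.add_(1)` on a user input surfaces,
+        # after functionalization, as an explicit copy back into that input; the pass
+        # below strips such residual input-mutation nodes from the graph.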
+ graph_module = passes.RemoveInputMutation(diagnostic_context, graph_module).run( + *flattened_args + ) + + # Use torch.export to recapture the GraphModule into an ExportedProgram. + return torch.export.export(graph_module, flattened_args) + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with internal Dynamo apis..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with internal Dynamo apis... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with internal Dynamo apis... ❌" + ) + + +CAPTURE_STRATEGIES = ( + TorchExportStrategy, + TorchExportNonStrictStrategy, + JitTraceConvertStrategy, + LegacyDynamoStrategy, +) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py new file mode 100644 index 00000000000..642f768d728 --- /dev/null +++ b/torch/onnx/_internal/exporter/_compat.py @@ -0,0 +1,225 @@ +"""Compatibility functions for the torch.onnx.export API.""" + +# mypy: allow-untyped-defs +# mypy: disable-error-code=attr-defined +from __future__ import annotations + +import inspect +import logging +from typing import Any, Mapping, Sequence, TYPE_CHECKING + +import onnx + +import torch +import torch.export +from torch.onnx._internal.exporter import _core, _onnx_program + + +if TYPE_CHECKING: + import os + +logger = logging.getLogger(__name__) + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") + + +def _from_dynamic_axes_to_dynamic_shapes( + model, + dynamic_axes=None, + input_names: Sequence[str] | None = None, +) -> dict[str, Any] | None: + """ + + dynamic_axes examples: + (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} + (2) dynamic_axes = {"x": [0], "y": [1]} + + these will be converted to dynamic_shapes respectively: + (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} + (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names + + """ + # https://github.com/pytorch/pytorch/pull/128371 + # 1. 
When dynamic_axes is None, the function does not need to provide dynamic_shapes to torch.export.export
+    if dynamic_axes is None:
+        return None
+
+    if input_names is None:
+        input_names = []
+
+    sig = _signature(model)
+    if len(input_names) > len(sig.parameters):
+        raise ValueError(
+            f"Number of input names ({len(input_names)}) should not be greater than "
+            f"the number of model inputs ({len(sig.parameters)})"
+        )
+    input_names_to_model_inputs = {}
+    for idx, param_name in enumerate(sig.parameters):
+        if idx < len(input_names):
+            input_names_to_model_inputs[input_names[idx]] = param_name
+        else:
+            input_names_to_model_inputs[param_name] = param_name
+
+    # NOTE: torch.export.export does not support input name assignment,
+    # so we need to map input names to model inputs to create dynamic_shapes
+    # for the exported program
+    dynamic_shapes_to_exported_program = {}
+    for input_name, axes in dynamic_axes.items():
+        # input_name can come either from input_names or from the model inputs
+        if input_name not in input_names_to_model_inputs:
+            raise ValueError(
+                f"dynamic axis: {input_name} is not found in the input names: {input_names}"
+            )
+        model_input_name = input_names_to_model_inputs[input_name]
+        if isinstance(axes, dict):
+            dynamic_shapes_to_exported_program[model_input_name] = {
+                k: torch.export.Dim(v) for k, v in axes.items()
+            }
+        elif isinstance(axes, list):
+            dynamic_shapes_to_exported_program[model_input_name] = {
+                k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes
+            }
+        else:
+            raise TypeError(
+                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
+            )
+    # torch.export.export requires static dims to be present in dynamic_shapes
+    # for all input tensors as well, so we add them with a None value
+    for input_name in sig.parameters:
+        if input_name not in dynamic_shapes_to_exported_program:
+            dynamic_shapes_to_exported_program[input_name] = None  # type: ignore[assignment]
+
+    return dynamic_shapes_to_exported_program
+
+
+def _get_torch_export_args(
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any] | None,
+) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
+    """Obtain the args and kwargs for torch.export.export from torch.onnx.export-style input arguments."""
+    if not kwargs and args and isinstance(args[-1], dict):
+        kwargs = args[-1]
+        args = args[:-1]
+    return args, kwargs
+
+
+def _convert_version(path: str | os.PathLike, opset_version: int) -> None:
+    """Convert the ONNX file to a specific opset version."""
+    model = onnx.load(path, load_external_data=False)
+    model = onnx.version_converter.convert_version(model, opset_version)
+    onnx.save(model, path)
+
+
+def export_compat(
+    model: torch.nn.Module
+    | torch.export.ExportedProgram
+    | torch.jit.ScriptModule
+    | torch.jit.ScriptFunction,
+    args: tuple[Any, ...],
+    f: str | os.PathLike | None = None,
+    *,
+    kwargs: dict[str, Any] | None = None,
+    export_params: bool = True,
+    verbose: bool | None = None,
+    input_names: Sequence[str] | None = None,
+    output_names: Sequence[str] | None = None,
+    opset_version: int | None = None,
+    dynamic_axes: Mapping[str, Mapping[int, str]]
+    | Mapping[str, Sequence[int]]
+    | None = None,
+    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
+    keep_initializers_as_inputs: bool = False,
+    external_data: bool = True,
+    report: bool = False,
+    verify: bool = False,
+    profile: bool = False,
+    dump_exported_program: bool = False,
+    artifacts_dir: str | os.PathLike = ".",
+    fallback: bool = False,
+    **_,
+) -> _onnx_program.ONNXProgram | None:
+    if isinstance(model, torch.export.ExportedProgram):
+        # When the model is already an ExportedProgram, the args, kwargs, and
+        # dynamic_shapes are not used
+        dynamic_shapes = dynamic_shapes or {}
+    else:
+        args, kwargs = _get_torch_export_args(args, kwargs)
+        if dynamic_shapes is None and dynamic_axes is not None:
+            dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
+                model, dynamic_axes, input_names
+            )
+
+    should_convert_version = False
+
+    try:
+        onnx_program = _core.export(
+            model,
+            args,
+            kwargs,
+            registry=None,
+            dynamic_shapes=dynamic_shapes,
+            input_names=input_names,
+            output_names=output_names,
+            profile=profile,
+            report=report,
+            verify=verify,
+            dump_exported_program=dump_exported_program,
+            artifacts_dir=artifacts_dir,
+            verbose=verbose,
+        )
+
+        if f is not None:
+            # Save the initializers as external data when requested to reduce the size of the ONNX file
+            onnx_program.save(
+                f,
+                include_initializers=export_params,
+                keep_initializers_as_inputs=keep_initializers_as_inputs,
+                external_data=external_data,
+            )
+            if (
+                opset_version is not None
+                and opset_version != onnx_program.model.opset_imports.get("")
+            ):
+                should_convert_version = True
+
+    except Exception as e:
+        if fallback:
+            if verbose is not False:
+                print(
+                    "[torch.onnx] Falling back to legacy torch.onnx.export due "
+                    f"to the following error: {e}",
+                )
+            torch.onnx.utils.export(
+                model,  # type: ignore[arg-type]
+                args,
+                f,  # type: ignore[arg-type]
+                kwargs=kwargs,
+                export_params=export_params,
+                input_names=input_names,
+                output_names=output_names,
+                opset_version=17,  # TODO(justinchuby): Hard coded to 17 for now
+                dynamic_axes=dynamic_axes,
+                keep_initializers_as_inputs=keep_initializers_as_inputs,
+            )
+            onnx_program = None
+            if opset_version is None:
+                opset_version = 18
+            if opset_version != 17:
+                should_convert_version = True
+        else:
+            raise
+
+    if f is not None and should_convert_version:
+        assert opset_version is not None
+        if verbose is not False:
+            print(
+                f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..."
+            )
+        _convert_version(f, opset_version)
+
+    return onnx_program
diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py
new file mode 100644
index 00000000000..3d28a0544a8
--- /dev/null
+++ b/torch/onnx/_internal/exporter/_core.py
@@ -0,0 +1,1344 @@
+# mypy: allow-untyped-defs
+# flake8: noqa: B950 We do not need flake8 as it complains about line length
+from __future__ import annotations
+
+import ctypes
+import datetime
+import inspect
+import itertools
+import logging
+import operator
+import pathlib
+import textwrap
+import traceback
+import typing
+from typing import Any, Callable, Literal, Sequence
+
+import onnx
+
+import onnxscript
+import onnxscript.evaluator
+import onnxscript.function_libs
+import onnxscript.function_libs.torch_lib
+import onnxscript.function_libs.torch_lib.registration
+from onnxscript import ir
+from onnxscript.ir import convenience as ir_convenience
+
+import torch
+import torch.fx
+from torch.export import graph_signature
+from torch.onnx._internal.exporter import (
+    _analysis,
+    _building,
+    _capture_strategies,
+    _dispatching,
+    _fx_passes,
+    _ir_passes,
+    _isolated,
+    _onnx_program,
+    _registration,
+    _reporting,
+    _tensors,
+    _verification,
+    errors,
+)
+
+
+if typing.TYPE_CHECKING:
+    import os
+
+    import numpy as np
+
+
+# Define utilities to convert PyTorch data types so users do not need to specify manually
+_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+    torch.bfloat16: ir.DataType.BFLOAT16,
+    torch.bool: ir.DataType.BOOL,
+    torch.complex128: ir.DataType.COMPLEX128,
+    torch.complex64: ir.DataType.COMPLEX64,
+    torch.float16: ir.DataType.FLOAT16,
+    torch.float32: ir.DataType.FLOAT,
+    torch.float64: ir.DataType.DOUBLE,
+    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
+    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
+    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
+    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
+    torch.int16: ir.DataType.INT16,
+    torch.int32: ir.DataType.INT32,
+    torch.int64: ir.DataType.INT64,
+    torch.int8: ir.DataType.INT8,
+    torch.uint8: ir.DataType.UINT8,
+}
+_BLUE = "\033[96m"
+_END = "\033[0m"
+
+_STEP_ONE_ERROR_MESSAGE = textwrap.dedent(
+    f"""\
+    Failed to export the model with torch.export. {_BLUE}This is step 1/2{_END} of exporting the model to ONNX. Next steps:
+    - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
+    - Debug `torch.export.export` and submit a PR to PyTorch.
+    - Create an issue in the PyTorch GitHub repository against the {_BLUE}*torch.export*{_END} component and attach the full error stack as well as reproduction scripts."""
+)
+
+_STEP_TWO_ERROR_MESSAGE = textwrap.dedent(
+    f"""\
+    Failed to convert the exported program to an ONNX model. {_BLUE}This is step 2/2{_END} of exporting the model to ONNX. Next steps:
+    - If there is a missing ONNX function, implement it and register it to the registry.
+    - If there is an internal error during ONNX conversion, debug the error and submit a PR to PyTorch.
+    - Save the ExportedProgram as a pt2 file and create an error report with `export(..., report=True)`. Create an issue in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component.
Attach the pt2 model and the error report.""" +) + +logger = logging.getLogger(__name__) + + +def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: + return _TORCH_DTYPE_TO_ONNX[dtype] + + +class TorchTensor(ir.Tensor): + def __init__(self, tensor: torch.Tensor, name: str | None = None): + # Pass the tensor as the raw data to ir.Tensor's constructor + super().__init__( + tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name + ) + + def __array__(self, dtype: Any = None) -> np.ndarray: + # numpy() calls __array__ in ir.Tensor + if self.dtype == ir.DataType.BFLOAT16: + return self.raw.view(torch.uint16).__array__(dtype) + if self.dtype in { + ir.DataType.FLOAT8E4M3FN, + ir.DataType.FLOAT8E4M3FNUZ, + ir.DataType.FLOAT8E5M2, + ir.DataType.FLOAT8E5M2FNUZ, + }: + # TODO: Use ml_dtypes + return self.raw.view(torch.uint8).__array__(dtype) + return self.raw.__array__(dtype) + + def tobytes(self) -> bytes: + # Implement tobytes to support native PyTorch types so we can use types like bloat16 + # Reading from memory directly is also more efficient because + # it avoids copying to a NumPy array + tensor = self.raw.detach().cpu().contiguous() + return bytes( + (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( + tensor.data_ptr() + ) + ) + + +# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L56C1-L62C19 +# class InputKind(Enum): +# USER_INPUT = auto() +# PARAMETER = auto() +# BUFFER = auto() +# CONSTANT_TENSOR = auto() +# CUSTOM_OBJ = auto() +# TOKEN = auto() + +# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L89C1-L96C19 +# class OutputKind(Enum): +# USER_OUTPUT = auto() +# LOSS_OUTPUT = auto() +# BUFFER_MUTATION = auto() +# GRADIENT_TO_PARAMETER = auto() +# GRADIENT_TO_USER_INPUT = auto() +# USER_INPUT_MUTATION = auto() +# TOKEN = auto() + + +def _set_shape_types( + values: Sequence[ir.Value], + meta_vals: Sequence[torch.Tensor], + complex_to_float: bool = True, +) -> None: + if not isinstance(meta_vals, Sequence): + logger.warning( + "Expected meta_vals to be a sequence, but got %s. There may be an internal error.", + meta_vals, + ) + meta_vals = (meta_vals,) + for value, meta_val in zip(values, meta_vals): + _set_shape_type(value, meta_val, complex_to_float=complex_to_float) + + +def _set_shape_type( + value: ir.Value, + meta_val: torch.Tensor | tuple[torch.Tensor], + complex_to_float: bool, +) -> None: + # TODO: Consider using meta["tensor_meta"] for this? Would it be faster? + if isinstance(meta_val, tuple): + logger.warning("Setting shape and type of tensors is not supported yet") + if isinstance(meta_val, torch.Tensor): + # FIXME: Consider shape for complex values + dims = [] + for dim in meta_val.shape: + if isinstance(dim, int): + dims.append(dim) + else: + dims.append(str(dim.node)) + value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype) + if complex_to_float: + if meta_val.dtype == torch.complex64: + value.dtype = ir.DataType.FLOAT + # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts + dims.append(2) + elif meta_val.dtype == torch.complex128: + value.dtype = ir.DataType.DOUBLE + # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts + dims.append(2) + + value.shape = ir.Shape(dims) + elif isinstance(meta_val, (int, torch.SymInt)): + # aten::sym_size output is a int, not a tensor, which stands + # for the size of one dim. We treat it as a scalar. 
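+        # Hypothetical example: a traced `x.shape[0]` produces a SymInt (e.g. s0)
+        # rather than a tensor, so the corresponding ONNX value is typed below as a
+        # scalar INT64 with an empty shape.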
+ value.dtype = ir.DataType.INT64 + value.shape = ir.Shape([]) + elif isinstance(meta_val, (bool, torch.SymBool)): + value.dtype = ir.DataType.BOOL + value.shape = ir.Shape([]) + elif isinstance(meta_val, (float, torch.SymFloat)): + value.dtype = ir.DataType.FLOAT + value.shape = ir.Shape([]) + else: + pass + + +def _get_qualified_module_name(cls: Any) -> str: + if isinstance(cls, str): + return cls + module = cls.__module__ + if module is None or module == str.__class__.__module__: + return cls.__name__ + return module + "." + cls.__name__ + + +def _get_node_namespace(node: torch.fx.Node) -> tuple[str, list[str], list[str]]: + """Get the namespace and scope of the node. + + Example:: + + { + 'L__self__': ('', ), + 'L__self___avgpool': ('avgpool', ) + } + + Will yield + + namespace: ": torchvision.models.resnet.ResNet/avgpool: torch.nn.modules.pooling.AdaptiveAvgPool2d/node_name: node_target" + class_hierarchy: ["torchvision.models.resnet.ResNet", "torch.nn.modules.pooling.AdaptiveAvgPool2d", ] + name_scopes: ["", "avgpool", ] + + Args: + node: The node to get the namespace and scope of. + + Returns: + (namespace, class_hierarchy, name_scope) + """ + nn_module_stack = node.meta.get("nn_module_stack") + logger.debug("%s", nn_module_stack) + if nn_module_stack is None: + logger.warning( + "nn_module_stack not found for node '%s'. Skip adding metadata...", + node.name, + ) + return f"{node.name}: {node.target}", [str(node.target)], [node.name] + namespaces = [] + class_hierarchy = [] + name_scopes = [] + for name, nn_module in nn_module_stack.values(): + name_scopes.append(name) + nn_module_name = _get_qualified_module_name(nn_module) + class_hierarchy.append(nn_module_name) + namespaces.append(f"{name}: {_get_qualified_module_name(nn_module)}") + namespaces.append(f"{node.name}: {node.target}") + class_hierarchy.append(str(node.target)) + name_scopes.append(node.name) + + return "/".join(namespaces), class_hierarchy, name_scopes + + +def _set_node_metadata(fx_node: torch.fx.Node, ir_node: ir.Node) -> None: + """Adds namespace and other node metadata to the ONNX node.""" + namespace, class_hierarchy, name_scopes = _get_node_namespace(fx_node) + ir_node.metadata_props["namespace"] = namespace + ir_node.metadata_props["pkg.torch.onnx.class_hierarchy"] = repr(class_hierarchy) + ir_node.metadata_props["pkg.torch.onnx.name_scopes"] = repr(name_scopes) + ir_node.metadata_props["pkg.torch.onnx.fx_node"] = str(fx_node.format_node()) + ir_node.metadata_props["pkg.torch.onnx.stack_trace"] = fx_node.meta.get( + "stack_trace", "" + ) + + +def _handle_getitem_node( + node: torch.fx.Node, node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] +) -> ir.Value: + """Handle a getitem node. + + Add the input value it is getting to the mapping, then return the value. + + There are two cases for this node: + 1. The output is a Sequence (traced), we can simply get the value from the sequence + 2. 
The output is produced by a SplitToSequence node, we need to get the value from the sequence value + This function only handles the first case + """ + assert len(node.all_input_nodes) == 1 + source = node.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + assert isinstance( + source_outputs, Sequence + ), f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}" + index = typing.cast(int, node.args[1]) + value = source_outputs[index] + # Save the getitem value to the values mapping to in case + # it is one of the graph outputs + node_name_to_values[node.name] = value + # Rename the name of value with the getitem name. + value.name = node.name + return value + + +def _handle_call_function_node( + graph: ir.Graph, + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], +) -> None: + """Handle a call_function node. + + Args: + graph: The ONNX graph at construction. + node: The FX node to translate. + node_name_to_values: A mapping of FX node names to their produced ir.Value. + """ + if node.target == operator.getitem: + _handle_getitem_node(node, node_name_to_values) + # Add op to the graph + op = str(node.target) + fx_inputs, attributes, input_names, output_names = _get_inputs_and_attributes(node) + inputs: list[ir.Value | None] = [] + for i, input_ in enumerate(fx_inputs): + if input_ is None: + inputs.append(None) + elif hasattr(input_, "name"): + if isinstance(input_, torch.fx.Node) and input_.target == operator.getitem: + actual_input = _handle_getitem_node(input_, node_name_to_values) + inputs.append(actual_input) + else: + value = node_name_to_values[input_.name] + assert not isinstance(value, Sequence) + inputs.append(value) + else: + attributes[f"arg_{i}"] = input_ + + outputs = [ir.Value(name=name) for name in output_names] + if len(outputs) > 1: + _set_shape_types(outputs, node.meta["val"], complex_to_float=False) + node_name_to_values[node.name] = outputs + else: + _set_shape_type(outputs[0], node.meta["val"], complex_to_float=False) + node_name_to_values[node.name] = outputs[0] + ir_node = ir.Node( + "pkg.torch.ops", + op, + inputs, + attributes=ir_convenience.convert_attributes(attributes), + outputs=outputs, + name=node.name, + ) + ir_node.meta["node"] = node + ir_node.metadata_props["pkg.torch.onnx.input_names"] = repr(input_names) + # Record the nn.Module stack for the node + _set_node_metadata(node, ir_node) + + graph.append(ir_node) + + +def _convert_fx_arg_to_onnx_arg( + arg, node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] +) -> Any: + """Convert an FX argument to an ONNX compatible argument. + + This function + - Converts a torch dtype to an integer + - Converts a torch device/memory_format/layout to a string + - Converts a torch.fx.Node to an ir.Value + - Converts a sequence of torch.fx.Node to a sequence of ir.Value + """ + if arg is None: + # None arguments are not modified because when the arg is an ONNX input + # we need to preserve the None value; when the arg is an ONNX attribute, + # we want to drop the value. 
+ # The actual dropping of a None attribute value is done by OpRecorder + return None + if hasattr(arg, "name"): + if isinstance(arg, torch.fx.Node) and arg.target == operator.getitem: + source = arg.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + if isinstance(source_outputs, Sequence): + # If the node is getting an input from another node, get the actual value the node is retrieving + return _handle_getitem_node(arg, node_name_to_values) + else: + # `source_outputs` is a sequence(tensor()) value and we need to + # use SequenceAt to get the value. This is handled by torchlib + pass + # If the input is a node, get the value from the mapping + return node_name_to_values[arg.name] + if isinstance(arg, (list, tuple)): + return [_convert_fx_arg_to_onnx_arg(elem, node_name_to_values) for elem in arg] + if isinstance(arg, (torch.device, torch.memory_format, torch.layout)): + return str(arg) + if isinstance(arg, torch.dtype): + return _torch_dtype_to_onnx_dtype(arg) + # Maybe a Python value + return arg + + +def _get_onnxscript_opset(opset_version: int) -> onnxscript.values.Opset: + return onnxscript.values.Opset("", opset_version) + + +def _handle_call_function_node_with_lowering( + model: ir.Model, + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + constant_farm: dict[Any, ir.Value], + registry: _registration.ONNXRegistry, + opset: onnxscript.values.Opset, +) -> None: + if node.target == operator.getitem: + source = node.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + if isinstance(source_outputs, Sequence): + _handle_getitem_node(node, node_name_to_values) + return + else: + # `source_outputs` is a sequence(tensor()) value and we need to + # use SequenceAt to get the value. This is handled by torchlib + pass + + # Find the matching ONNX overload for the node + # NOTE: Create different registries for different ONNX opset versions + # TODO: Log the message here to expose false positives + onnx_function, message = _dispatching.dispatch(node, registry) + + if onnx_function is None: + # TODO(justinchuby): Fall back to ATen op or do something else? + raise errors.DispatchError( + f"No ONNX function found for {node.target!r}. Failure message: {message}" + ) + + # Map FX inputs to ONNX inputs and fill optional inputs. + # torch_args and torch_kwargs are for op-level validation + fx_args = node.args + fx_kwargs = node.kwargs + + # Replace the input FX nodes with ONNX values + onnx_args = [ + _convert_fx_arg_to_onnx_arg(input_, node_name_to_values) for input_ in fx_args + ] + + onnx_kwargs = {} + for key, value in fx_kwargs.items(): + onnx_kwargs[key] = _convert_fx_arg_to_onnx_arg(value, node_name_to_values) + if key == "dtype" and onnx_kwargs[key] is None: + # Set dtype to -1 if it is None + onnx_kwargs[key] = -1 + + with onnxscript.evaluator.default_as( + tracer := _building.OpRecorder(opset, constant_farm) + ): + try: + outputs = onnx_function(*onnx_args, **onnx_kwargs) + except Exception as e: + raise errors.GraphConstructionError( + f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" + ) from e + + # NOTE: Instead of using the output names from node.target._schema, + # we always use the index if there are more than one outputs so the + # names can be programmatically reconstructed. This is useful for + # comparing values from the ONNX graph with those from the FX graph. 
+ # + # When there are multiple outputs, the output names will be + # node_name__0, node_name__1, etc. + if isinstance(outputs, Sequence): + _set_shape_types(outputs, node.meta["val"], complex_to_float=True) + node_name_to_values[node.name] = outputs + for i, output in enumerate(outputs): + output.name = f"{node.name}__{i}" + else: + _set_shape_type(outputs, node.meta["val"], complex_to_float=True) + node_name_to_values[node.name] = outputs + outputs.name = node.name + + for ir_node in tracer.nodes: + ir_node.meta["node"] = node + # Record the nn.Module stack for the node + _set_node_metadata(node, ir_node) + + # Add the traced nodes to the graph + model.graph.extend(tracer.nodes) + # Add the defined functions to the model + for identifier, onnxscript_function in tracer.functions.items(): + if identifier in model.functions: + continue + # TODO: Get IR function directly when onnxscript is updated + proto = onnxscript_function.to_function_proto() + ir_function = ir.serde.deserialize_function(proto) + model.functions[identifier] = ir_function + if ir_function.domain not in model.opset_imports: + # FIXME: Record the correct opset version of the function + model.opset_imports[ir_function.domain] = 1 + + +def _handle_placeholder_node( + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + *, + lower: str, + opset: onnxscript.values.Opset, +) -> None: + # Placeholder nodes are user inputs + # We need to create a new tensor for each user input + # and add it to the graph's inputs + name = node.name + input_ = _tensors.SymbolicTensor(opset, name=name) + input_.meta["node"] = node + _set_shape_type(input_, node.meta["val"], complex_to_float=lower != "none") + node_name_to_values[name] = input_ + # The inputs will be added to the graph later + + +def _add_nodes( + exported_program: torch.export.ExportedProgram, + model: ir.Model, + lower: Literal["at_conversion", "post_conversion", "none"], + registry: _registration.ONNXRegistry, +) -> dict[str, ir.Value | Sequence[ir.Value]]: + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] = {} + constant_farm: dict[Any, ir.Value] = {} + opset = _get_onnxscript_opset(registry.opset_version) + for node in exported_program.graph.nodes: + logger.debug( + "%s", (node.name, node.args, node.target, node.op, node.type, node.kwargs) + ) + try: + if node.op == "placeholder": + _handle_placeholder_node( + node, + node_name_to_values, + lower=lower, + opset=opset, + ) + elif node.op == "call_function": + if lower == "at_conversion": + _handle_call_function_node_with_lowering( + model, + node, + node_name_to_values, + constant_farm, + registry=registry, + opset=opset, + ) + else: + # No lowering + _handle_call_function_node(model.graph, node, node_name_to_values) + except Exception as e: + raise errors.OnnxConversionError( + f"Error when translating node {node.format_node()}. See the stack trace for more information." + ) from e + return node_name_to_values + + +def _torch_version_integer() -> int: + return int(torch.__version__.replace(".", "").split("dev")[0]) + + +def _get_inputs_and_attributes( + node: torch.fx.Node, +) -> tuple[list[torch.fx.Node | None], dict[str, Any], list[str], list[str]]: + """Find and Fill in the not provided kwargs with default values. 
+ + Returns: + (inputs, attributes, input_names, output_names) + """ + if inspect.isbuiltin(node.target) or isinstance(node.target, str): + inputs = list(node.args) + return inputs, {}, [], [node.name] # type: ignore[return-value] + + # The target should be an ATen operator now + assert hasattr( + node.target, "_schema" + ), f"The target should be an ATen operator now, but node target {node.target} has no schema" + node_schema: torch.FunctionSchema = node.target._schema + + # This function assumes the order of arguments in FX op is the + # same as the order of arguments in TorchScript op. + inputs: list[Any] = [] # type: ignore[no-redef] + input_names: list[str] = [] + attributes: dict[str, Any] = {} + + if inspect.isbuiltin(node.target): + inputs = list(node.args) + else: + for arg, schema_arg in zip(node.args, node_schema.arguments): + if arg is None or isinstance(arg, torch.fx.Node): + inputs.append(arg) + input_names.append(schema_arg.name) + elif isinstance(arg, Sequence) and all( + elem is None or isinstance(elem, torch.fx.Node) for elem in arg + ): + inputs.extend(arg) + input_names.extend([schema_arg.name] * len(arg)) + elif isinstance(arg, torch.device): + attributes[schema_arg.name] = str(arg) + elif isinstance(arg, torch.dtype): + attributes[schema_arg.name] = _torch_dtype_to_onnx_dtype(arg) + else: + attributes[schema_arg.name] = arg + for schema_arg in node_schema.arguments: + if schema_arg.name not in node.kwargs: + continue + kwarg = node.kwargs[schema_arg.name] + if schema_arg.name in { + "layout", + "device", + "requires_grad", + "memory_format", + "implicit", + } or isinstance(kwarg, torch.device): + attr = str(kwarg) + elif isinstance(kwarg, torch.dtype): + attr = _torch_dtype_to_onnx_dtype(kwarg) # type: ignore[assignment] + else: + attr = kwarg # type: ignore[assignment] + + attributes[schema_arg.name] = attr + + output_names = [f"{node.name}_{output.name}" for output in node_schema.returns] + + return inputs, attributes, input_names, output_names # type: ignore[return-value] + + +def _maybe_start_profiler(should_profile: bool) -> Any: + if should_profile: + import pyinstrument # type: ignore[import-not-found] + + profiler = pyinstrument.Profiler(async_mode="disabled") + profiler.start() + return profiler + return None + + +def _maybe_stop_profiler_and_get_result(profiler) -> str | None: + if profiler is None: + return None + profiler.stop() + return profiler.output_text(unicode=True) + + +def _format_exception(e: Exception) -> str: + """Format the full traceback as Python would show it.""" + return "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) + + +def _summarize_exception_stack(e: BaseException) -> str: + """Format the exception stack by showing the text of each exception.""" + causes = [e] + while e.__cause__ is not None: + causes.append(e.__cause__) + e = e.__cause__ + return ( + "\n\n## Exception summary\n\n" + + "⬆️\n".join([f"{type(e)}: {e}\n" for e in reversed(causes)]) + + "\n(Refer to the full stack trace above for more information.)" + ) + + +def _format_exceptions_for_all_strategies( + results: list[_capture_strategies.Result], +) -> str: + """Format all the exceptions from the capture strategies.""" + return "\n".join( + [ + f"# ⚠️ Errors from strategy '{result.strategy}': -----------------------\n\n" + f"{_format_exception(result.exception)}\n" + for result in results + if result.exception is not None + ] + ) + + +def exported_program_to_ir( + exported_program: torch.export.ExportedProgram, + *, + registry: _registration.ONNXRegistry 
| None = None,
+    lower: Literal["at_conversion", "post_conversion", "none"] = "at_conversion",
+) -> ir.Model:
+    """Convert an exported program to an ONNX IR model.
+
+    Reference:
+        - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html
+
+    Args:
+        exported_program: The exported program to convert.
+        lower: Whether to lower the graph to core ONNX operators.
+            at_conversion: Lower when translating the FX graph to ONNX IR.
+            post_conversion: Use an IR pass to lower the graph.
+            none: Do not lower the graph.
+        registry: The registry of all ONNX Script decompositions.
+    """
+    if registry is None:
+        # Trigger op registration
+        from onnxscript.function_libs.torch_lib import ops  # noqa: F401
+
+        del ops
+        registry = _registration.ONNXRegistry.from_torchlib(
+            onnxscript.function_libs.torch_lib.registration.default_registry  # type: ignore[arg-type]
+        )
+    if lower != "none":
+        exported_program = _prepare_exported_program_for_export(
+            exported_program, registry=registry
+        )
+    return _exported_program_to_onnx_program(
+        exported_program, registry=registry, lower=lower
+    ).model
+
+
+def _prepare_exported_program_for_export(
+    exported_program: torch.export.ExportedProgram,
+    *,
+    registry: _registration.ONNXRegistry,
+) -> torch.export.ExportedProgram:
+    """Decompose and apply pre-export transformations to the exported program."""
+    # Decompose the graph given the implemented torch ops in ONNX
+    exported_program = _fx_passes.decompose_with_registry(exported_program, registry)
+
+    graph_module = exported_program.graph_module
+    # Include explicit type promotion nodes
+    graph_module = _fx_passes.insert_type_promotion_nodes(graph_module)
+    graph_module = _fx_passes.remove_assertion_nodes(graph_module)
+    # TODO(justinchuby): Reassigning the graph module to save some runtime.
+    # If this does not work, we need to retrace the module with torch.export
+    exported_program._graph_module = graph_module
+    return exported_program
+
+
+def _exported_program_to_onnx_program(
+    exported_program: torch.export.ExportedProgram,
+    *,
+    registry: _registration.ONNXRegistry,
+    lower: Literal["at_conversion", "post_conversion", "none"] = "at_conversion",
+) -> _onnx_program.ONNXProgram:
+    """Convert an exported program to an ONNX Program.
+
+    The exported_program field in the returned ONNXProgram is the program after
+    decompositions have been applied.
+
+    Reference:
+        - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html
+
+    Args:
+        exported_program: The exported program to convert. The exported program
+            should be the one that is after decompositions have been applied.
+        lower: Whether to lower the graph to core ONNX operators.
+            at_conversion: Lower when translating the FX graph to ONNX IR.
+            post_conversion: Use an IR pass to lower the graph.
+            none: Do not lower the graph.
+        registry: The registry of all ONNX Script decompositions.
+ """ + model = ir.Model( + graph=ir.Graph( + [], + [], + nodes=[], + opset_imports={ + "": registry.opset_version, + }, + name="main_graph", + metadata_props={ + "pkg.torch.export.ExportedProgram.graph_signature": str( + exported_program.graph_signature + ), + "pkg.torch.export.ExportedProgram.range_constraints": str( + exported_program.range_constraints + ), + }, + ), + ir_version=9, + producer_name="torch", + producer_version=torch.__version__, + ) + + if lower == "none": + # Add the opset import for the torch ops + model.opset_imports["pkg.torch.ops"] = _torch_version_integer() + # NOTE: Function domains are added when translating nodes when lower="at_conversion" + + # 1. Add all nodes to the graph and create a dictionary of values + values = _add_nodes(exported_program, model, lower=lower, registry=registry) + + # 2. Add user inputs and all parameters/buffers to the graph. + # Since the node names and the tensor names are different, we need to rename + # the nodes to match the tensor names later. For now we will just use the node names. + user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind == graph_signature.InputKind.USER_INPUT + ] + non_user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind != graph_signature.InputKind.USER_INPUT + ] + + for spec in itertools.chain(user_inputs, non_user_inputs): + # Put the user inputs first and then the parameters/buffers + if isinstance(spec.arg, graph_signature.ConstantArgument): + logger.debug("Skipping constant argument %s", spec.arg) + continue + value_name = spec.arg.name + input_kind = spec.kind + persistent = spec.persistent + value = values[value_name] + + assert not isinstance( + value, Sequence + ), f"Input '{value_name}' should not be a sequence. This is unexpected." + + value.metadata_props[ + "pkg.torch.export.graph_signature.InputSpec.kind" + ] = input_kind.name + value.metadata_props[ + "pkg.torch.export.graph_signature.InputSpec.persistent" + ] = str(persistent) + + if input_kind == graph_signature.InputKind.USER_INPUT: + # Add only user inputs to the graph + # Subsequent passes can decide if they want to add initializers as inputs + model.graph.inputs.append(value) + else: + model.graph.initializers[value_name] = value + + # 3. Add user outputs to the graph and assign metadata to all outputs + user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind == graph_signature.OutputKind.USER_OUTPUT + ] + non_user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind != graph_signature.OutputKind.USER_OUTPUT + ] + for spec in itertools.chain(user_outputs, non_user_outputs): + if isinstance(spec.arg, graph_signature.ConstantArgument): + logger.warning("Skipping constant argument %s", spec.arg) + continue + value_name = spec.arg.name + output_kind = spec.kind + value = values[value_name] + + if not isinstance(value, (ir.Value, Sequence)): + raise TypeError( + f"Output '{value_name}' should be an ir.Value. Actual type is '{type(value)}': {value!r}. " + "This may be due to an incorrect implementation of the ONNX function that produced this output." + ) + + # The output value may be a sequence, meaning the operator has multiple outputs + _values = (value,) if not isinstance(value, Sequence) else value + + if len(_values) > 1: + logger.warning( + "Model output '%s' has multiple values: %s (output spec: %s). 
Please make sure this is expected.", + value_name, + _values, + spec, + ) + + for value in _values: + value.metadata_props[ + "pkg.torch.export.graph_signature.OutputSpec.kind" + ] = output_kind.name + if output_kind == graph_signature.OutputKind.USER_OUTPUT: + model.graph.outputs.append(value) + + # 4. Rename the initializers to match the tensor names + for name, param_name in itertools.chain( + exported_program.graph_signature.inputs_to_parameters.items(), + exported_program.graph_signature.inputs_to_buffers.items(), + exported_program.graph_signature.inputs_to_lifted_tensor_constants.items(), + ): + initializer = model.graph.initializers.pop(name) + initializer.name = param_name + # Record the original name so users can search the metadata and correspond + # with the FX graph + initializer.metadata_props["pkg.torch.onnx.original_node_name"] = name + model.graph.initializers[param_name] = initializer + + # 5. Add initializers to the graph + # ExportedProgram stores parameters and buffers in state_dict, + # but non_persistent_buffers and lifted_tensor_constants are not there + # so we need to get them from the name_* apis. + for name, torch_tensor in itertools.chain( + exported_program.named_parameters(), + exported_program.named_buffers(), + exported_program.constants.items(), + ): + initializer = model.graph.initializers.get(name) # type: ignore[assignment] + if initializer is None: + logger.warning("Tensor '%s' is not one of the initializers", name) + continue + if not isinstance(torch_tensor, torch.Tensor): + raise NotImplementedError( + f"Tensor '{name}' should be a torch.Tensor. Actual type is '{type(torch_tensor)}': {torch_tensor!r}. " + "This is unexpected and not yet supported." + ) + ir_tensor = TorchTensor(torch_tensor, name=name) + initializer.const_value = ir_tensor + _set_shape_type( + initializer, + torch_tensor, + complex_to_float=lower != "none", + ) + + # TODO: Decide if we should keep mutated buffers as inputs/outputs + + return _onnx_program.ONNXProgram(model, exported_program) + + +def _verbose_printer(verbose: bool | None) -> Callable[..., None]: + """Prints messages based on `verbose`.""" + if verbose is False: + return lambda *_, **__: None + return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) + + +def export( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.fx.GraphModule + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + *, + registry: _registration.ONNXRegistry | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + report: bool = False, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + verbose: bool | None = None, +) -> _onnx_program.ONNXProgram: + """Export a PyTorch model to ONNXProgram. + + Args: + model: The model to export. This can be a PyTorch nn.Module or an ExportedProgram. + args: The arguments to pass to the model. + kwargs: The keyword arguments to pass to the model. + registry: The registry of all ONNX decompositions. + dynamic_shapes: Dynamic shapes in the graph. + input_names: If provided, rename the inputs. + output_names: If provided, rename the outputs. + report: Whether to generate an error report if the export fails. + verify: Whether to verify the ONNX model after exporting. + profile: Whether to profile the export process. 
When report is True, + the profile result will be saved in the report. Otherwise, the profile + result will be printed. + dump_exported_program: Whether to save the exported program to a file. + artifacts_dir: The directory to save the exported program and error reports. + verbose: Whether to print verbose messages. If None (default), some messages will be printed. + + Returns: + The ONNXProgram with the exported IR graph. + + Raises: + TorchExportError: If the export process fails with torch.export. + OnnxConversionError: If the ExportedProgram to ONNX translation fails. + """ + # Set up the error reporting facilities + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f") + profiler = _maybe_start_profiler(profile) + + # Create the artifacts directory if it does not exist + artifacts_dir = pathlib.Path(artifacts_dir) + if report or profile or dump_exported_program: + artifacts_dir.mkdir(parents=True, exist_ok=True) + + verbose_print = _verbose_printer(verbose) + export_status = _reporting.ExportStatus() + failed_results: list[_capture_strategies.Result] = [] + + program: torch.export.ExportedProgram | None = None + # Step 1: Export the model with torch.export.export if the model is not already an ExportedProgram + if isinstance(model, torch.export.ExportedProgram): + program = model + export_status.torch_export = True + else: + # Convert an nn.Module to an ExportedProgram + # Try everything 🐰 (all paths for getting an ExportedProgram) + # When input is a JIT module, the last strategy will succeed so it is handled + result: _capture_strategies.Result | None = None + for strategy_class in _capture_strategies.CAPTURE_STRATEGIES: + strategy = strategy_class( # type: ignore[abstract] + verbose=verbose is not False, # Treat None as verbose + dump=dump_exported_program, + artifacts_dir=artifacts_dir, + timestamp=timestamp, + ) + result = strategy(model, args, kwargs, dynamic_shapes=dynamic_shapes) + + # Record the status + if strategy_class is _capture_strategies.TorchExportStrategy: + export_status.torch_export = result.success + elif strategy_class is _capture_strategies.TorchExportNonStrictStrategy: + export_status.torch_export_non_strict = result.success + elif strategy_class is _capture_strategies.JitTraceConvertStrategy: + export_status.torch_jit = result.success + + if result.exported_program is not None: + program = result.exported_program + break + else: + failed_results.append(result) + + assert result is not None + if result.exported_program is None: + # If all strategies fail, produce an error report and raise the first error + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + try: + _reporting.create_torch_export_error_report( + report_path, + _format_exceptions_for_all_strategies(failed_results), + export_status=export_status, + profile_result=profile_result, + ) + except Exception as e_report: + verbose_print( + f"Failed to save error report due to an error: {e_report}" + ) + else: + report_path = None + + first_error = failed_results[0].exception + assert first_error is not None + + # NOTE: We only throw the torch.export (first) exception because we want to + # focus on the torch.export.export error. Errors from other strategies like + # torch.jit.trace is due to the fallback and can be confusing to users. + # We save all errors in the error report. 
+ raise errors.TorchExportError( + _STEP_ONE_ERROR_MESSAGE + + ( + f"\nError report has been saved to '{report_path}'." + if report + else "" + ) + + _summarize_exception_stack(first_error) + ) from first_error + + assert program is not None + + if dump_exported_program: + verbose_print("Dumping ExportedProgram because `dump_exported_program=True`...") + program_path = artifacts_dir / f"onnx_export_{timestamp}.pt2" + try: + torch.export.save(program, program_path) + except Exception as e: + verbose_print(f"Failed to save ExportedProgram due to an error: {e}") + else: + verbose_print(f"ExportedProgram has been saved to '{program_path}'.") + + # Step 2: Convert the exported program to an ONNX model + verbose_print("Translate the graph into ONNX...") + + # Step 2a: Decompose the exported program and insert type promotion nodes + try: + # Build the ONNX function registry + if registry is None: + # Trigger op registration + from onnxscript.function_libs.torch_lib import ops + + del ops + registry = _registration.ONNXRegistry.from_torchlib( + onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] + ) + + # Process the exported program to run decompositions and type promotions etc. + decomposed_program = _prepare_exported_program_for_export( + program, registry=registry + ) + except Exception as e: + export_status.onnx_translation = False + verbose_print("Translate the graph into ONNX... ❌") + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + # Run the analysis to get the error report + try: + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + program, + export_status=export_status, + profile_result=profile_result, + registry=registry, + ) + except Exception: + logger.exception("Failed to save report due to an error.") + else: + report_path = None + + raise errors.OnnxConversionError( + _STEP_TWO_ERROR_MESSAGE + + (f"\nError report has been saved to '{report_path}'." if report else "") + + _summarize_exception_stack(e) + ) from e + + # Step 2b: Translate the decomposed program to ONNX and produce ONNXProgram + if report or profile: + pre_decomp_unique_ops, post_decomp_unique_ops = _analysis.compare_ops( + program, decomposed_program + ) + else: + pre_decomp_unique_ops = None + post_decomp_unique_ops = None + + try: + # Convert the exported program to an ONNX model + onnx_program = _exported_program_to_onnx_program( + decomposed_program, registry=registry + ) + + # Run the ONNX passes + if input_names: + _ir_passes.rename_inputs(onnx_program.model, input_names) + if output_names: + _ir_passes.rename_outputs(onnx_program.model, output_names) + + # TODO(justinchuby): Remove the hack + _ir_passes.add_torchlib_common_imports(onnx_program.model) + + export_status.onnx_translation = True + verbose_print("Translate the graph into ONNX... ✅") + except Exception as e: + export_status.onnx_translation = False + verbose_print("Translate the graph into ONNX... 
❌") + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + + # Run the analysis to get the error report + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + program, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + export_status=export_status, + profile_result=profile_result, + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + else: + report_path = None + + raise errors.OnnxConversionError( + _STEP_TWO_ERROR_MESSAGE + + (f"\nError report has been saved to '{report_path}'." if report else "") + + _summarize_exception_stack(e) + ) from e + + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if not verify: + # Return if verification is not requested + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + "No errors" + if not failed_results + else _format_exceptions_for_all_strategies(failed_results), + onnx_program.exported_program, + profile_result=profile_result, + export_status=export_status, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + elif profile and profile_result is not None: + verbose_print("Profile result:") + verbose_print(profile_result) + return onnx_program + + # Step 3: (verify=True) Check the ONNX model with ONNX checker + try: + verbose_print("Run `onnx.checker` on the ONNX model...") + + # TODO: Handle when model is >2GB + + model_proto = onnx_program.model_proto + byte_size = model_proto.ByteSize() + if byte_size < 2 * 1024 * 1024 * 1024: + # The checker may segfault so we need to run it in a separate process + _isolated.safe_call( + onnx.checker.check_model, onnx_program.model_proto, full_check=True # type: ignore[attr-defined] + ) + export_status.onnx_checker = True + verbose_print("Run `onnx.checker` on the ONNX model... ✅") + else: + verbose_print( + f"Run `onnx.checker` on the ONNX model... ⚠️ Skipped because model is too large ({byte_size})." + ) + except Exception as e: + export_status.onnx_checker = False + verbose_print("Run `onnx.checker` on the ONNX model... 
❌") + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + onnx_program.exported_program, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + export_status=export_status, + profile_result=profile_result, + model=onnx_program.model, + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + logger.warning( + "Conversion successful but the ONNX model fails ONNX checker. " # noqa: G004 + "Please create an issue " + f"in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component and " + "attach the full error stack as well as reproduction scripts. ", + exc_info=e, + ) + return onnx_program + + # Step 4: (verify=True) Execute the model with ONNX Runtime + try: + verbose_print("Execute the model with ONNX Runtime...") + verification_results = _verification.verify_onnx_program(onnx_program) + verbose_print("Execute the model with ONNX Runtime... ✅") + export_status.onnx_runtime = True + onnx_runtime_error_message = None + except Exception as e: + verbose_print("Execute the model with ONNX Runtime... ❌") + export_status.onnx_runtime = False + onnx_runtime_error_message = _format_exception(e) + verification_message = None + + else: + # Step 5: (verify=True) Validate the output values + verbose_print("Verify output accuracy...") + export_status.output_accuracy = True + for verification_result in verification_results: + # TODO(justinchuby): The threshold is arbitrary right now + if verification_result.absolute_difference >= 5e-3: + logger.warning( + "Output '%s' has a large absolute difference of %f. ", + verification_result.name, + verification_result.absolute_difference, + ) + export_status.output_accuracy = False + if verification_result.relative_difference >= 1e-1: + logger.warning( + "Output '%s' has a large relative difference of %f. ", + verification_result.name, + verification_result.relative_difference, + ) + export_status.output_accuracy = False + if export_status.output_accuracy: + verbose_print("Verify output accuracy... ✅") + else: + verbose_print("Verify output accuracy... 
❌") + verification_message = _reporting.format_verification_infos( + verification_results + ) + + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + + traceback_lines = [] + if failed_results: + traceback_lines.append( + _format_exceptions_for_all_strategies(failed_results) + ) + if onnx_runtime_error_message: + traceback_lines.append( + "# ⚠️ ONNX Runtime error -----------------------" + ) + traceback_lines.append(onnx_runtime_error_message) + if not traceback_lines: + traceback_lines.append("No errors") + + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + "\n\n".join(traceback_lines), + onnx_program.exported_program, + profile_result=profile_result, + export_status=export_status, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + model=onnx_program.model, + registry=registry, + verification_result=verification_message, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + + # Release the inference session created during verification + onnx_program.release() + return onnx_program diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py new file mode 100644 index 00000000000..3797a6d1fbd --- /dev/null +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -0,0 +1,74 @@ +"""Build decomp table from PyTorch.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING + +import torch +import torch._ops + + +if TYPE_CHECKING: + from torch.onnx._internal.exporter import _registration + + +def get_onnx_implemented_overloads( + registry: _registration.ONNXRegistry, +) -> list[torch._ops.OperatorBase]: + """ + Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. + + Args: + registry: The ONNX registry for PyTorch. + + Returns: + A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. + """ + registered_ops: list[torch._ops.OperatorBase] = [] + for op_namespace in (torch.ops.aten, torch.ops.prims): + op_names = dir(op_namespace) + for op_name in op_names: + op_overload_packet = getattr(op_namespace, op_name) + if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): + continue + + for overload_name in op_overload_packet.overloads(): + op_overload = getattr(op_overload_packet, overload_name) + if registry.is_registered(op_overload): + registered_ops.append(op_overload) + return registered_ops + + +def create_onnx_friendly_decomposition_table( + registry, +) -> dict[torch._ops.OperatorBase, Callable]: + """ + This function creates a dictionary of op overloads and their decomposition functions + for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, + its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's + built-in aten-to-aten decomposition. + + Args: + registry: The ONNX registry for PyTorch. + + Returns: + Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding + decomposition functions. 
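+
+    Example (illustrative only; assumes an already populated ONNX registry and an
+    ExportedProgram named `exported_program`)::
+
+        decomp_table = create_onnx_friendly_decomposition_table(registry)
+        # Only ops without ONNX implementations get decomposed.
+        exported_program = exported_program.run_decompositions(decomp_table)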
+ """ + decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} + onnx_registered_ops = set(get_onnx_implemented_overloads(registry)) + + # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single + # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your + # definitions in a single TORCH_LIBRARY block. + for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined] + # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX + # symbolic function. + # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is + # not exportable anyways. + if op_overload in onnx_registered_ops: + continue + decomposition_table[op_overload] = decomp_fn + + return decomposition_table diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py new file mode 100644 index 00000000000..b8aecfaa937 --- /dev/null +++ b/torch/onnx/_internal/exporter/_dispatching.py @@ -0,0 +1,345 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Sequence + +import onnxscript +from onnxscript import ir + +import torch +import torch.fx +from torch.onnx._internal.exporter import _registration, _schemas + + +logger = logging.getLogger(__name__) + +# Define utilities to convert PyTorch data types so users do not need to specify manually +_TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = { + torch.bfloat16: ir.DataType.BFLOAT16, + torch.bool: ir.DataType.BOOL, + torch.complex128: ir.DataType.DOUBLE, + torch.complex64: ir.DataType.FLOAT, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, + torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, + torch.float8_e5m2: ir.DataType.FLOAT8E5M2, + torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.int8: ir.DataType.INT8, + torch.uint8: ir.DataType.UINT8, +} + + +def _torch_dtype_to_onnx_compatible_dtype(dtype: torch.dtype) -> ir.DataType: + return _TORCH_DTYPE_TO_ONNX_COMPATIBLE[dtype] + + +def _attribute_type_compatible_with_arg( + attr: _schemas.AttributeParameter, + value: ir.Value | int | float | bool | Sequence[int] | Sequence[float] | None, +) -> bool: + """Check if the attribute type is compatible with the argument.""" + if isinstance(value, bool): + return attr.type is ir.AttributeType.INT + if isinstance(value, str): + return attr.type is ir.AttributeType.STRING + if isinstance(value, int): + return attr.type in {ir.AttributeType.INT, ir.AttributeType.FLOAT} + if isinstance(value, float): + return attr.type is ir.AttributeType.FLOAT + if isinstance(value, complex): + return False + if isinstance(value, Sequence): + if attr.type is ir.AttributeType.INTS: + return all(isinstance(i, int) for i in value) + if attr.type is ir.AttributeType.FLOATS: + return all(isinstance(i, (int, float)) for i in value) + if isinstance(value, torch.dtype): + return attr.type is ir.AttributeType.INT + if isinstance(value, (torch.device, torch.memory_format, torch.layout)): + return attr.type is ir.AttributeType.STRING + if value is None and not attr.required: + # An optional attribute is not supplied + return True + return False + + +def _param_type_compatible_with_arg( + param: _schemas.Parameter, + value: ir.TypeProtocol + | str + | 
int + | float + | complex + | Sequence[int] + | Sequence[float] + | None, + assigned_types: dict[str, ir.TypeProtocol], +) -> bool: + # Handle Python types first + if isinstance(value, bool): # noqa: SIM102 + if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}: + return True + if isinstance(value, int) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.INT4), + ir.TensorType(ir.DataType.INT8), + ir.TensorType(ir.DataType.INT16), + ir.TensorType(ir.DataType.INT32), + ir.TensorType(ir.DataType.INT64), + # Int inputs can be casted to a float too + ir.TensorType(ir.DataType.FLOAT8E4M3FN), + ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), + ir.TensorType(ir.DataType.FLOAT8E5M2), + ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), + ir.TensorType(ir.DataType.FLOAT16), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + }: + return True + if isinstance(value, float) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT8E4M3FN), + ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), + ir.TensorType(ir.DataType.FLOAT8E5M2), + ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), + ir.TensorType(ir.DataType.FLOAT16), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + }: + return True + if isinstance(value, complex) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.TensorType(ir.DataType.COMPLEX64), + ir.TensorType(ir.DataType.COMPLEX128), + }: + return True + if isinstance(value, str): # noqa: SIM102 + if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}: + return True + if isinstance(value, (list, tuple)): + if param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.INT32), + ir.TensorType(ir.DataType.INT64), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.SequenceType(ir.TensorType(ir.DataType.INT32)), + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), + } and all(isinstance(i, (int)) for i in value): + # We will just allow any fx node and trust that the overload handles it + return True + if param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), + } and all(isinstance(i, (int, float)) for i in value): + # We will just allow any fx node and trust that the overload handles it + return True + if value is None and not param.required: + # An optional parameter is not supplied + return True + + if not isinstance(value, ir.TypeProtocol): + return False + + # Then check tensor types + if param.type_constraint.name in assigned_types: + # If a typevar is already bound, check if the value has the same type + assigned_type = assigned_types[param.type_constraint.name] + return assigned_type == value + # If the typevar is not bound, bind it to the value type + if value in param.type_constraint.allowed_types: + # TODO: Maybe just check dtype? 
Being more strict here for now + assigned_types[param.type_constraint.name] = value + return True + return False + + +def _get_type_from_tensor( + tensor: torch.Tensor | Sequence[torch.Tensor], +) -> ir.TypeProtocol: + if isinstance(tensor, torch.Tensor): + return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype)) + first_tensor = next((item for item in tensor if item is not None), None) + if first_tensor is None: + return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED)) + return ir.SequenceType( + ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(first_tensor.dtype)) + ) + + +def _get_first_tensor_in_node_list( + nodes: Sequence[torch.fx.Node | None], +) -> torch.Tensor | None: + for node in nodes: + if ( + node is not None + and "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + ): + return node.meta["val"] + return None + + +def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]: + # FIXME: node.target may not have a schema + torch_schema: torch.FunctionSchema = node.target._schema # type: ignore[union-attr] + node_args = {} + for arg, schema_arg in zip(node.args, torch_schema.arguments): + node_args[schema_arg.name] = arg + + node_args.update(node.kwargs) + return node_args + + +def get_matching_overload( + node: torch.fx.Node, + overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction], +) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: + """Get the overload that matches the node's arguments. + + Args: + node: The node to match. + overloads: The overloads to match against. + + Returns: + A tuple containing the matched overload and a string describing the reason for failure or success. + """ + named_args = _get_named_fx_node_args(node) + # FIXME: node.target may and builtin and not have a schema + # FIXME: Handle when we don't know the names of the arguments + schema_args: dict[str, torch.Argument] = { + arg.name: arg + for arg in node.target._schema.arguments # type: ignore[union-attr] + } + failure_messages: list[str] = [] + for overload in overloads: + assigned_types: dict[str, ir.TypeProtocol] = {} + fail_reason = "" + if not hasattr(overload, "signature"): + # When an overload does not have a signature, we assume it is a custom op and should be matched + return ( + overload, + "The overload does not have a signature. Assuming it is a custom op and matching it.", + ) + for param in overload.signature: + if param.name not in schema_args and param.required: + # We don't need to handle variadic inputs as there is none. + # A required parameter is not supplied. 
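+ # i.e. the ONNX overload requires an input or attribute that the torch op schema does not define at all, so this overload cannot match.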
+ fail_reason = "Required parameter not supplied" + break + + # Get the argument + if param.name in named_args: + # Provided in Node args + arg = named_args[param.name] + elif ( + param.name in schema_args + and schema_args[param.name].has_default_value() + ): + # Provided in schema args + arg = schema_args[param.name].default_value + elif param.has_default(): + # Provided in the ONNX op definition + arg = param.default + else: + fail_reason = "Parameter not provided" + break + + if isinstance(param, _schemas.Parameter): + if isinstance(arg, torch.Tensor): + arg = _get_type_from_tensor(arg) # type: ignore[assignment] + if isinstance(arg, (list, tuple)) and any( + isinstance(t, torch.fx.Node) for t in arg + ): + first_tensor = _get_first_tensor_in_node_list(arg) + assert first_tensor is not None + # FIXME: Handle symfloat here + arg = ir.SequenceType(_get_type_from_tensor(first_tensor)) # type: ignore[assignment] + elif isinstance(arg, torch.fx.Node): + meta_val = arg.meta["val"] + arg = _get_type_from_tensor(meta_val) # type: ignore[assignment] + # TODO: Handle None attributes + # FIXME: Handle symfloat etc. + # Handle tensors and Python values + if not _param_type_compatible_with_arg(param, arg, assigned_types): # type: ignore[arg-type] + fail_reason = ( + f"Parameter type not compatible with argument: param=`{param}`, " + f"assigned_types=`{assigned_types}`, arg=`{arg}`" + ) + break + elif isinstance(param, _schemas.AttributeParameter): + if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type] + fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`" + break + if not fail_reason: + return overload, "Successfully matched overload" + else: + failure_messages.append( + f"- Failed to match overload `{overload}`: {fail_reason}" + ) + return ( + None, + f"All overloads did not match the node `{node.format_node()}`.\n" + + "\n".join(failure_messages), + ) + + +def _arg_has_complex_dtype(arg) -> bool: + """Check if the node has complex dtype recursively.""" + if ( + isinstance(arg, torch.fx.Node) + and "val" in arg.meta + and isinstance(arg.meta["val"], torch.Tensor) + and torch.is_complex(arg.meta["val"]) + ): + return True + elif isinstance(arg, list): + return any(_arg_has_complex_dtype(item) for item in arg) + return False + + +def dispatch( + node: torch.fx.Node, registry: _registration.ONNXRegistry +) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: + """Dispatch a node to an ONNX function based on the node's target and the ONNX registry. + + Args: + node: The node to dispatch. + registry: The ONNX registry to use for dispatching. + + Returns: + A tuple containing the matched ONNX function and a string describing the reason for failure or success. + """ + # TODO: Handle when node does not have a target + decomp_metas = registry.get_decomps(node.target) # type: ignore[arg-type] + # Determine if the node has complex inputs. 
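+ # Complex-valued inputs are represented as real tensors with an extra dimension during export (see _onnx_program._convert_complex_to_real_representation), so such nodes must dispatch to overloads registered with is_complex=True.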
+ is_complex = any(_arg_has_complex_dtype(arg) for arg in node.args) or any( + _arg_has_complex_dtype(arg) for arg in node.kwargs.values() + ) + if is_complex: + decomp_metas = [decomp for decomp in decomp_metas if decomp.is_complex] + if not decomp_metas: + return None, "No decompositions registered for the complex-valued input" + else: + decomp_metas = [decomp for decomp in decomp_metas if not decomp.is_complex] + if not decomp_metas: + return None, "No decompositions registered for the real-valued input" + + if len(decomp_metas) == 1: + return ( + decomp_metas[0].onnx_function, + "Fast path: Only one decomposition is defined", + ) + + overload, message = get_matching_overload( + node, [decomp.onnx_function for decomp in decomp_metas] + ) + return overload, message diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py new file mode 100644 index 00000000000..2feae57b5d7 --- /dev/null +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -0,0 +1,72 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import torch +import torch.export +import torch.fx +from torch.onnx._internal.exporter import _decomp, _registration +from torch.onnx._internal.fx import diagnostics, passes + + +_ATEN_ASSERTION_TARGETS = frozenset( + { + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.msg, + } +) + + +def decompose_with_registry( + exported_program: torch.export.ExportedProgram, registry: _registration.ONNXRegistry +) -> torch.export.ExportedProgram: + """Decompose the exported program with the given registry. + + This function is needed so it shows clearly on the profiler results. + """ + decomp_table = _decomp.create_onnx_friendly_decomposition_table(registry) + onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry)) + # Try to preserve some known CompositeImplicitAutograd ops + aten = torch.ops.aten + to_preserve = { + aten._upsample_bilinear2d_aa.default, + aten._upsample_nearest_exact1d.vec, + aten._upsample_nearest_exact2d.vec, + aten._upsample_nearest_exact3d.vec, + aten.group_norm.default, + aten.linear.default, + aten.upsample_bilinear2d.default, + aten.upsample_bilinear2d.vec, + aten.upsample_linear1d.default, + aten.upsample_linear1d.vec, + aten.upsample_nearest1d.default, + aten.upsample_nearest1d.vec, + aten.upsample_nearest2d.default, + aten.upsample_nearest2d.vec, + aten.upsample_nearest3d.default, + aten.upsample_nearest3d.vec, + aten.upsample_trilinear3d.default, + aten.upsample_trilinear3d.vec, + } + # We can only preserve implemented ops + can_preserve = tuple(to_preserve.intersection(onnx_registered_ops)) + return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve) + + +def insert_type_promotion_nodes( + graph_module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """Inplace pass to insert explicit type promotion nodes.""" + diagnostic_context = diagnostics.DiagnosticContext( + "torch.onnx.export", + torch.__version__, + ) + return passes.InsertTypePromotion(diagnostic_context, graph_module).run() + + +def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Remove all assertion and check nodes from the FX graph""" + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _ATEN_ASSERTION_TARGETS: + graph_module.graph.erase_node(node) + graph_module.recompile() + return graph_module diff --git a/torch/onnx/_internal/exporter/_ir_passes.py 
b/torch/onnx/_internal/exporter/_ir_passes.py new file mode 100644 index 00000000000..7e8748443e2 --- /dev/null +++ b/torch/onnx/_internal/exporter/_ir_passes.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Sequence + +from onnxscript import ir + + +logger = logging.getLogger(__name__) + + +def rename_inputs(model: ir.Model, new_names: Sequence[str]) -> None: + # TODO: Ensure the names do not have duplicates + for input, new_name in zip(model.graph.inputs, new_names): + input.metadata_props["pkg.torch.onnx.original_node_name"] = str(input.name) + input.name = new_name + + +def rename_outputs(model: ir.Model, new_names: Sequence[str]) -> None: + for output, new_name in zip(model.graph.outputs, new_names): + output.metadata_props["pkg.torch.onnx.original_node_name"] = str(output.name) + output.name = new_name + + +def add_torchlib_common_imports(model: ir.Model) -> None: + """Hack to add torchlib common imports to the model.""" + + try: + # TODO(justinchuby): Remove this hack and improved onnxscript + from onnxscript.function_libs.torch_lib.ops import common as common_ops + + model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 + rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) + is_scalar_func = ir.serde.deserialize_function( + common_ops.IsScalar.to_function_proto() + ) + model.functions[rank_func.identifier()] = rank_func + model.functions[is_scalar_func.identifier()] = is_scalar_func + except Exception: + logger.exception("Failed to add torchlib common imports to the model.") diff --git a/torch/onnx/_internal/exporter/_isolated.py b/torch/onnx/_internal/exporter/_isolated.py new file mode 100644 index 00000000000..4a5c5fcdf79 --- /dev/null +++ b/torch/onnx/_internal/exporter/_isolated.py @@ -0,0 +1,55 @@ +"""Isolated calls to methods that may segfault.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +import multiprocessing +import os +import warnings +from typing import Callable + + +_IS_WINDOWS = os.name == "nt" + + +def _call_function_and_return_exception(func, args, kwargs): + """Call function and return a exception if there is one.""" + + try: + return func(*args, **kwargs) + except Exception as e: + return e + + +def safe_call(func: Callable, *args, **kwargs): + """Call a function in a separate process. + + Args: + func: The function to call. + args: The positional arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + + Returns: + The return value of the function. + + Raises: + Exception: If the function raised an exception. + """ + if _IS_WINDOWS: + # On Windows, we cannot create a new process with fork. + warnings.warn( + f"A new process is not created for {func} on Windows.", stacklevel=1 + ) + return func(*args, **kwargs) + + with multiprocessing.get_context("fork").Pool(1) as pool: + # It is important to fork a process here to prevent the main logic from + # running again when the user does not place it under a `if __name__ == "__main__":` + # block. 
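+ # The result is fetched with a short timeout below so that a hung or crashed child process surfaces as an error instead of blocking the export indefinitely.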
+ result = pool.apply_async( + _call_function_and_return_exception, (func, args, kwargs) + ) + result = result.get(timeout=5) + if isinstance(result, Exception): + raise result + return result diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py new file mode 100644 index 00000000000..51e20207877 --- /dev/null +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -0,0 +1,288 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="attr-defined,name-defined" +from __future__ import annotations + + +__all__ = ["ONNXProgram"] + +import gc +import logging +import os +import pathlib +import tempfile +import textwrap +from typing import Callable, IO, Sequence, TYPE_CHECKING + +import torch +from torch.onnx._internal import _lazy_import +from torch.utils import _pytree as pytree + + +onnx = _lazy_import.onnx +ir = _lazy_import.onnxscript_ir + + +if TYPE_CHECKING: + import onnxruntime as ort + +logger = logging.getLogger(__name__) + + +def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: + """Initialize an ONNX Runtime inference session with the specified model.""" + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.log_severity_level = 3 # 3: Error + possible_providers = ( + "CUDAExecutionProvider", + "CPUExecutionProvider", + ) + available_providers = set(ort.get_available_providers()) + providers = [ + provider for provider in possible_providers if provider in available_providers + ] + return ort.InferenceSession( + model, providers=providers, sess_options=session_options + ) + + +class ONNXProgram: + """A substitute class for `torch.onnx.ONNXProgram`.""" + + def __init__(self, model: ir.Model, exported_program: torch.export.ExportedProgram): + self.model: ir.Model = model + self.exported_program = exported_program + self._inference_session: ort.InferenceSession | None = None + self._tempdir: tempfile.TemporaryDirectory | None = None + + def __repr__(self) -> str: + return f"""\ +ONNXProgram( + model= +{textwrap.indent(str(self.model), ' ' * 8)} + , + exported_program= +{textwrap.indent(str(self.exported_program), ' ' * 8)} +) +""" + + def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]: + """Run the ONNX model with the same arguments you would provide to the GraphModule.""" + import onnxruntime as ort + + flatten_args = _process_args(args, kwargs) + + if self._inference_session is None: + self.initialize_inference_session() + + assert self._inference_session is not None + + # We don't expect non-tensor as inputs + ort_input = { + k.name: v.numpy(force=True) + for k, v in zip(self.model.graph.inputs, flatten_args) + } + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + logger.debug("Running the inference session with %s arguments.", len(ort_input)) + outputs = self._inference_session.run(None, ort_input, run_options=run_options) + logger.debug("Inference session run completed.") + # TODO(justinchuby): Maybe output complex tensors as needed + return tuple(torch.from_numpy(output) for output in outputs) + + @property + def model_proto(self) -> onnx.ModelProto: + """Compatibility property for `torch.onnx.ONNXProgram.model_proto`.""" + return ir.serde.serialize_model(self.model) + + def save( + self, + destination: str | os.PathLike | IO[bytes], + *, + include_initializers: bool = True, + keep_initializers_as_inputs: bool = False, + external_data: bool | None = None, + **_, + ): + """Save the ONNX model to the specified destination. 
+ + When `external_data` is `True` or the model is larger than 2GB, + the weights are saved as external data in a separate file. + + Args: + destination: The path to save the ONNX model to. + include_initializers: Whether to include the initializers in the saved model. + keep_initializers_as_inputs: Whether to keep the initializers as inputs in the saved model. + If `True`, the initializers are added as inputs to the model which means they can be overwritten. + by providing the initializers as model inputs. + external_data: Whether to save the weights as external data in a separate file. + + Raises: + TypeError: If `external_data` is `True` and `destination` is not a file path. + """ + if not include_initializers: + self.model.graph.initializers.clear() + logger.warning( + "The initializers have been removed from the model. This is destructive. " + "Developers: Please implement ir.Model copy() and remove initializers on the copied model." + ) + if keep_initializers_as_inputs: + self.model.graph.inputs.extend(self.model.graph.initializers.values()) # type: ignore[arg-type] + logger.warning( + "The initializers have been added as inputs to the model. This is destructive. " + "Developers: Please implement ir.Model copy() and remove initializers on the copied model." + ) + proto = ir.serde.serialize_model(self.model) + byte_size = proto.ByteSize() + model_too_large = (byte_size) >= 1 << 31 + if external_data or model_too_large: + # TODO: Create an IR pass to handle external tensors conversion + if model_too_large: + logger.warning( + "The serialized ONNX model is larger than 2GB (%s). " + "Saving the weights as external data in a separate file.", + byte_size, + ) + if not isinstance(destination, (str, os.PathLike)): + raise TypeError( + "Saving the weights as external data is only supported when destination is a file path" + ) + destination_path = pathlib.Path(destination) + # Create the directory if it does not exist + data_path = f"{destination_path.name}.data" + onnx.save_model( + proto, + destination, + save_as_external_data=True, + location=data_path, + ) + else: + onnx.save_model(proto, destination) + + def initialize_inference_session( + self, + initializer: Callable[ + [str | bytes], ort.InferenceSession + ] = _ort_session_initializer, + ) -> None: + """Initialize the ONNX Runtime inference session. + + Args: + initializer: The function to initialize the ONNX Runtime inference + session with the specified model. By default, it uses the + :func:`_ort_session_initializer` function. + """ + # TODO(justinchuby): Allow different inference options + logger.debug("Initializing the inference session.") + proto = ir.serde.serialize_model(self.model) + byte_size = proto.ByteSize() + model_too_large = (byte_size) >= 1 << 31 + + if model_too_large: + logger.debug( + "The serialized ONNX model is larger than 2GB (%s).", byte_size + ) + # Save the model to a temporary file if too large + self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + model_path = os.path.join(self._tempdir.name, "model.onnx") + data_path = "model.onnx.data" + onnx.save_model( + proto, + model_path, + save_as_external_data=True, + location=data_path, + ) + model = model_path + else: + model = proto.SerializeToString() # type: ignore[assignment] + + self._inference_session = initializer(model) + logger.debug("Inference session initialized.") + + def release(self) -> None: + """Release the inference session. + + You may call this method to release the resources used by the inference session. 
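+ The inference session will be re-created automatically the next time the program is called.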
+ """ + # Release the inference session first so that the model file can be deleted + if self._inference_session is not None: + self._inference_session = None + gc.collect() + if self._tempdir is not None: + self._tempdir.cleanup() + self._tempdir = None + + +def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: + """Process input arguments for the ONNX model.""" + args = _flatten_inputs(args, kwargs) + args = _remove_none_from_inputs(args) + args = _remove_non_tensor(args) + args = _convert_complex_to_real_representation(args) + return args + + +def _flatten_inputs(model_args, model_kwargs): + flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs)) + return flattened_args + + +def _remove_none_from_inputs(model_args): + return tuple(arg for arg in model_args if arg is not None) + + +def _remove_non_tensor(model_args): + """Remove the non-tensor input arguments. + + Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). + + Specifically, it does put the input into graph with an empty node, but consumed by no ones. + The concrete value is embedded into the graph as a constant arg of a target node. Meta + suggests in this case that one should rewrite the model code to make it tensor if the + input value is supposed to change at runtime. We might need to further investigate + the feasibility of that suggestion. + + For example, + + def func(x, b=1.0): + y = x + b + z = y.relu() + return (y, z) + + x = torch.randn(1, 1, 2, dtype=torch.float32) + gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") + + # class GraphModule(torch.nn.Module): + # def forward(self, x, b): + # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) + # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b + # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None + + # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() + # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) + # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) + + Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as + it's ignored in ONNX graph. Thus, we delete the useless input here. + + """ + + return tuple( + arg for arg in model_args if not isinstance(arg, (int, float, bool, str)) + ) + + +def _convert_complex_to_real_representation(model_args): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + """ + return tuple( + torch.view_as_real(arg.resolve_conj()) + if isinstance(arg, torch.Tensor) and arg.is_complex() + else arg + for arg in model_args + ) diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py new file mode 100644 index 00000000000..b649188c264 --- /dev/null +++ b/torch/onnx/_internal/exporter/_registration.py @@ -0,0 +1,275 @@ +"""Module for handling ATen to ONNX functions registration. + +https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py +""" + +# NOTE: Why do we need a different registry than the one in torchlib? 
+# The registry in torchlib is used to register functions that are already implemented in +# torchlib, and is designed to be a static singleton. It does not take into account custom ops or different +# opsets etc. The registry implemented for the exporter is designed to be modifiable at +# export time by users, and is designed with dispatching in mind. + +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import logging +import math +import operator +import types +import typing +from typing import Callable, Literal, Mapping, Union +from typing_extensions import TypeAlias + +import torch +import torch._ops +from torch.onnx._internal.exporter import _schemas + + +if typing.TYPE_CHECKING: + import onnxscript + from onnxscript.function_libs.torch_lib import registration as torchlib_registration + +_DEFAULT_OPSET_VERSION = 18 + + +TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable] + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class OnnxDecompMeta: + """A wrapper of onnx-script function with additional metadata. + + onnx_function: The onnx-script function from torchlib. + fx_target: The PyTorch node callable target. + is_custom: Whether the function is a custom function. + is_complex: Whether the function handles complex-valued inputs. + device: The device the function is registered to. If None, it is registered to all devices. + """ + + onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction + fx_target: TorchOp + is_custom: bool = False + is_complex: bool = False + device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051 + + +def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None: + """Obtain the torch op from <namespace>::<op_name>[.<overload>]""" + # TODO(justinchuby): Handle arbitrary custom ops + namespace, opname_overload = qualified_name.split("::") + op_name, *maybe_overload = opname_overload.split(".", 1) + if namespace == "_operator": + # Builtin functions + return getattr(operator, op_name) + if namespace == "math": + return getattr(math, op_name) + if namespace == "torchvision": + try: + import torchvision.ops # type: ignore[import-untyped] + except ImportError: + logger.warning("torchvision is not installed. Skipping %s", qualified_name) + return None + try: + return getattr(torchvision.ops, op_name) + except AttributeError: + logger.warning("Failed to find torchvision op '%s'", qualified_name) + return None + except Exception: + logger.exception("Failed to find torchvision op '%s'", qualified_name) + try: + op_packet = getattr(getattr(torch.ops, namespace), op_name) + if maybe_overload: + overload = maybe_overload[0] + elif "default" in op_packet._overload_names or "" in op_packet._overload_names: + # Has a default overload + overload = "default" + else: + logger.warning( + "'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.", + qualified_name, + stacklevel=1, + ) + return None + + return getattr(op_packet, overload) # type: ignore[call-overload] + except AttributeError: + if qualified_name.endswith("getitem"): + # This is a special case where we registered the function incorrectly, + # but for BC reasons (pt<=2.4) we need to keep it. + return None + logger.info("'%s' is not found in this version of PyTorch.", qualified_name) + return None + except Exception: + logger.exception("Failed to find torch op '%s'", qualified_name) + return None + + +class ONNXRegistry: + """Registry for ONNX functions.
+ + The registry maintains a mapping from qualified names to symbolic functions under a + fixed opset version. It supports registering custom onnx-script functions and for + dispatcher to dispatch calls to the appropriate function. + + """ + + def __init__(self) -> None: + """Initializes the registry""" + + # TODO: Design multi-opset version support + self._opset_version = _DEFAULT_OPSET_VERSION + + self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {} + + @property + def opset_version(self) -> int: + """The ONNX opset version the exporter should target. + + Defaults to the latest supported ONNX opset version: 18. + The default version will increment over time as ONNX continues to evolve. + """ + + return self._opset_version + + @classmethod + def from_torchlib( + cls, + torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction] + | None = None, + ) -> ONNXRegistry: + """Populates the registry with ATen functions from torchlib. + + Args: + torchlib_registry: The torchlib registry to use for populating the registry. + """ + registry = cls() + if torchlib_registry is None: + from onnxscript.function_libs.torch_lib import ( + registration as torchlib_registration, + ) + + torchlib_registry = torchlib_registration.default_registry # type: ignore[assignment] + for qualified_name, aten_overloads_func in torchlib_registry.items(): # type: ignore[union-attr] + try: + # NOTE: This is heavily guarded with try-except because we don't want + # to fail the entire registry population if one function fails. + if qualified_name.startswith("internal::"): + # Skip the custom defined internal functions + continue + target = _get_overload(qualified_name) + if target is None: + continue + for overload_func in aten_overloads_func.overloads: + overload_func.signature = _schemas.OpSignature.from_function( + overload_func, + overload_func.function_ir.domain, + overload_func.name, + ) + onnx_decomposition = OnnxDecompMeta( + onnx_function=overload_func, + fx_target=target, + is_custom=False, + is_complex=False, + ) + registry._register(target, onnx_decomposition) + + for complex_func in aten_overloads_func.complex: + overload_func.signature = _schemas.OpSignature.from_function( + overload_func, + overload_func.function_ir.domain, + overload_func.name, + ) + onnx_decomposition = OnnxDecompMeta( + onnx_function=complex_func, + fx_target=target, + is_custom=False, + is_complex=True, + ) + registry._register(target, onnx_decomposition) + except Exception: + logger.exception("Failed to register '%s'. Skipped", qualified_name) + continue + return registry + + def _register( + self, + target: TorchOp, + onnx_decomposition: OnnxDecompMeta, + ) -> None: + """Registers a OnnxDecompMeta to an operator. + + Args: + target: The PyTorch node callable target. + onnx_decomposition: The OnnxDecompMeta to register. + """ + target_or_name: str | TorchOp + if isinstance(target, torch._ops.OpOverload): + # Get the qualified name of the aten op because torch._ops.OpOverload lookup in + # a dictionary is unreliable for some reason. + target_or_name = target.name() + else: + target_or_name = target + if onnx_decomposition.is_custom: + self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition) + else: + self.functions.setdefault(target_or_name, []).append(onnx_decomposition) + + def register_op( + self, + target: TorchOp, + function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + is_complex: bool = False, + ) -> None: + """Registers a custom operator: torch.ops.... 
+ + Args: + target: The PyTorch node callable target. + function: The onnx-script function to register. + is_complex: Whether the function is a function that handles complex valued inputs. + """ + onnx_decomposition = OnnxDecompMeta( + onnx_function=function, + fx_target=target, + is_custom=True, + is_complex=is_complex, + ) + self._register(target, onnx_decomposition) + + def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]: + """Returns a list of OnnxDecompMeta for the given op: torch.ops.... + + The list is ordered by the time of registration. The custom operators should come + first in the list. + + Args: + target: The PyTorch node callable target. + Returns: + A list of OnnxDecompMeta corresponding to the given name, or None if + the name is not in the registry. + """ + target_or_name: str | TorchOp + if isinstance(target, torch._ops.OpOverload): + # Get the qualified name of the aten op because torch._ops.OpOverload lookup in + # a dictionary is unreliable for some reason. + target_or_name = target.name() + else: + target_or_name = target + decomps = self.functions.get(target_or_name, []) + return sorted(decomps, key=lambda x: x.is_custom, reverse=True) + + def is_registered(self, target: TorchOp) -> bool: + """Returns whether the given op is registered: torch.ops.... + + Args: + target: The PyTorch node callable target. + + Returns: + True if the given op is registered, otherwise False. + """ + return bool(self.get_decomps(target)) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/torch/onnx/_internal/exporter/_reporting.py b/torch/onnx/_internal/exporter/_reporting.py new file mode 100644 index 00000000000..55a77a90ec4 --- /dev/null +++ b/torch/onnx/_internal/exporter/_reporting.py @@ -0,0 +1,193 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import re +from typing import TYPE_CHECKING + +from torch.onnx._internal.exporter import _analysis, _registration, _verification + + +if TYPE_CHECKING: + import os + + from onnxscript import ir + + import torch + + +@dataclasses.dataclass +class ExportStatus: + # Whether torch.export.export.export() succeeds + torch_export: bool | None = None + # Whether torch.export.export.export(..., strict=False) succeeds + torch_export_non_strict: bool | None = None + # Whether torch.jit.trace succeeds + torch_jit: bool | None = None + # Whether ONNX translation succeeds + onnx_translation: bool | None = None + # Whether ONNX model passes onnx.checker.check_model + onnx_checker: bool | None = None + # Whether ONNX model runs successfully with ONNX Runtime + onnx_runtime: bool | None = None + # Whether the output of the ONNX model is accurate + output_accuracy: bool | None = None + + +def _status_emoji(status: bool | None) -> str: + if status is None: + return "⚪" + return "✅" if status else "❌" + + +def _format_export_status(status: ExportStatus) -> str: + return ( + f"```\n" + f"{_status_emoji(status.torch_export)} Obtain model graph with `torch.export.export`\n" + f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n" + f"{_status_emoji(status.torch_jit)} Obtain model graph with `torch.jit.trace`\n" + f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n" + f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n" + f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n" + f"{_status_emoji(status.output_accuracy)} 
Validate model output accuracy\n" + f"```\n\n" + ) + + +def _strip_color_from_string(text: str) -> str: + # This regular expression matches ANSI escape codes + # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788 + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + + +def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str: + # Adapted from https://github.com/pytorch/pytorch/pull/128476 + # to remove colors + # Even though we can call graph_module.print_readable directly, since the + # colored option was added only recently, we can't guarantee that the + # version of PyTorch used by the user has this option. Therefore, we + # still call str(ExportedProgram) + text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n" + return text + + +def construct_report_file_name(timestamp: str, status: ExportStatus) -> str: + # Status could be None. So we need to check for False explicitly. + if not (status.torch_export or status.torch_export_non_strict or status.torch_jit): + # All strategies failed + postfix = "pt_export" + elif status.onnx_translation is False: + postfix = "conversion" + elif status.onnx_checker is False: + postfix = "checker" + elif status.onnx_runtime is False: + postfix = "runtime" + elif status.output_accuracy is False: + postfix = "accuracy" + elif status.torch_export is False or status.torch_export_non_strict is False: + # Some strategies failed + postfix = "strategies" + else: + postfix = "success" + return f"onnx_export_{timestamp}_{postfix}.md" + + +def format_decomp_comparison( + pre_decomp_unique_ops: set[str], + post_decomp_unique_ops: set[str], +) -> str: + """Format the decomposition comparison result. + + Args: + pre_decomp_unique_ops: The ops that exist only in the ExportedProgram before decomposition. + post_decomp_unique_ops: The ops that exist only in the ExportedProgram after decomposition. + + Returns: + The formatted comparison result. + """ + return ( + f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n" + f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n" + ) + + +def format_verification_infos( + verification_infos: list[_verification.VerificationInfo], +) -> str: + """Format the verification results. + + Args: + verification_infos: The verification results to format. + + Returns: + The formatted verification result.
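+ One line is produced per output, in the form `name`: `abs_diff=...`, `rel_diff=...`.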
+ """ + return "\n".join( + f"`{info.name}`: `abs_diff={info.absolute_difference:e}`, `rel_diff={info.relative_difference:e}`" + for info in verification_infos + ) + + +def create_torch_export_error_report( + filename: str | os.PathLike, + formatted_traceback: str, + *, + export_status: ExportStatus, + profile_result: str | None, +): + with open(filename, "w", encoding="utf-8") as f: + f.write("# PyTorch ONNX Conversion Error Report\n\n") + f.write(_format_export_status(export_status)) + f.write("Error message:\n\n") + f.write("```pytb\n") + f.write(formatted_traceback) + f.write("```\n\n") + if profile_result is not None: + f.write("## Profiling result\n\n") + f.write("```\n") + f.write(profile_result) + f.write("```\n") + + +def create_onnx_export_report( + filename: str | os.PathLike, + formatted_traceback: str, + program: torch.export.ExportedProgram, + *, + decomp_comparison: str | None = None, + export_status: ExportStatus, + profile_result: str | None, + model: ir.Model | None = None, + registry: _registration.ONNXRegistry | None = None, + verification_result: str | None = None, +): + with open(filename, "w", encoding="utf-8") as f: + f.write("# PyTorch ONNX Conversion Report\n\n") + f.write(_format_export_status(export_status)) + f.write("## Error messages\n\n") + f.write("```pytb\n") + f.write(formatted_traceback) + f.write("\n```\n\n") + f.write("## Exported program\n\n") + f.write(_format_exported_program(program)) + if model is not None: + f.write("## ONNX model\n\n") + f.write("```python\n") + f.write(str(model)) + f.write("\n```\n\n") + f.write("## Analysis\n\n") + _analysis.analyze(program, file=f, registry=registry) + if decomp_comparison is not None: + f.write("\n## Decomposition comparison\n\n") + f.write(decomp_comparison) + f.write("\n") + if verification_result is not None: + f.write("\n## Verification results\n\n") + f.write(verification_result) + f.write("\n") + if profile_result is not None: + f.write("\n## Profiling result\n\n") + f.write("```\n") + f.write(profile_result) + f.write("```\n") diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py new file mode 100644 index 00000000000..8ad10cd7a87 --- /dev/null +++ b/torch/onnx/_internal/exporter/_schemas.py @@ -0,0 +1,548 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections.abc +import dataclasses +import inspect +import logging +import types +import typing +from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union + +import onnx + +import onnxscript +from onnxscript import ir + + +logger = logging.getLogger(__name__) + + +# A special value to indicate that the default value is not specified +class _Empty: + def __repr__(self): + return "_EMPTY_DEFAULT" + + +_EMPTY_DEFAULT = _Empty() + +# Map from python type to corresponding ONNX AttributeProto type +_PY_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, + ir.Tensor: ir.AttributeType.TENSOR, + ir.TensorProtocol: ir.AttributeType.TENSOR, + ir.Graph: ir.AttributeType.GRAPH, + ir.GraphProtocol: ir.AttributeType.GRAPH, +} + +# Map from python type to corresponding ONNX AttributeProto type, +# for repeated (i.e., list of) values +_LIST_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, + ir.Tensor: ir.AttributeType.TENSORS, + ir.TensorProtocol: ir.AttributeType.TENSORS, + 
ir.Graph: ir.AttributeType.GRAPHS, + ir.GraphProtocol: ir.AttributeType.GRAPHS, +} + +_ALL_VALUE_TYPES = ( + {ir.TensorType(dtype) for dtype in ir.DataType} + | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} + | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} +) + +# TypeAnnotationValue represents the (value of) valid type-annotations recognized +# by ONNX Script. Currently, it supports +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached +TypeAnnotationValue = Any + + +@dataclasses.dataclass(frozen=True) +class TypeConstraintParam: + """Type constraint for a parameter. + + Attributes: + name: Name of the parameter. E.g. "TFloat" + allowed_types: Allowed types for the parameter. + """ + + name: str + allowed_types: set[ir.TypeProtocol] + description: str = "" + + def __hash__(self) -> int: + return hash((self.name, tuple(self.allowed_types))) + + def __str__(self) -> str: + allowed_types_str = " | ".join(str(t) for t in self.allowed_types) + return f"{self.name}={allowed_types_str}" + + @classmethod + def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) + + @classmethod + def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type] + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """A formal parameter of an operator.""" + + name: str + type_constraint: TypeConstraintParam + required: bool + variadic: bool + default: Any = _EMPTY_DEFAULT + # TODO: Add other properties too + + def __str__(self) -> str: + type_str = self.type_constraint.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not _EMPTY_DEFAULT + + +@dataclasses.dataclass(frozen=True) +class AttributeParameter: + name: str + type: ir.AttributeType + required: bool + default: ir.Attr | None = None + + def __str__(self) -> str: + type_str = self.type.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not None + + +def _get_type_from_str( + type_str: str, +) -> ir.TensorType | ir.SequenceType | ir.OptionalType: + """Convert a type_str from ONNX Opschema to ir.TypeProtocol. + + A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". + """ + + # TODO: Upstream this to IR + + # Split the type_str into its nested type parts and the dtype + # 1. Remove the ending ")" + striped = type_str.rstrip(")") + # 2.
Split the type_str by "(" + type_parts = striped.split("(") + + # Convert the dtype to ir.DataType + dtype = ir.DataType[type_parts[-1].upper()] + + # Create a place holder type first + type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) + + # Construct the type + for type_part in reversed(type_parts[:-1]): + if type_part == "tensor": + type_ = ir.TensorType(dtype) + elif type_part == "seq": + type_ = ir.SequenceType(type_) + elif type_part == "optional": + type_ = ir.OptionalType(type_) + else: + raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") + return type_ # type: ignore[return-value] + + +def _convert_formal_parameter( + param: onnx.defs.OpSchema.FormalParameter, + type_constraints: Mapping[str, TypeConstraintParam], +) -> Parameter: + """Convert a formal parameter from ONNX Opschema to Parameter.""" + if param.type_str in type_constraints: + type_constraint = type_constraints[param.type_str] + else: + # param.type_str can be a plain type like 'int64'. + type_constraint = TypeConstraintParam( + name=param.name, + allowed_types={_get_type_from_str(param.type_str)}, + ) + return Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, + variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + ) + + +def _is_optional(type_: type) -> bool: + """Returns whether a type_ is an Optional.""" + origin_type = typing.get_origin(type_) + if origin_type is Union and type(None) in typing.get_args(type_): + # Python < 3.10 + return True + if origin_type is Optional: + # Python >= 3.10 + return True + if ( + hasattr(types, "UnionType") + and origin_type is types.UnionType + and type(None) in typing.get_args(type_) + ): + # Python >= 3.10 + return True + return False + + +def _get_attr_type(type_: type) -> ir.AttributeType: + """Obtain the type of the attribute from a Python class.""" + try: + if type_ in _PY_TYPE_TO_ATTR_TYPE: + return _PY_TYPE_TO_ATTR_TYPE[type_] + origin_type = typing.get_origin(type_) + if origin_type is None: + return ir.AttributeType.UNDEFINED + if origin_type in ( + collections.abc.Sequence, + Sequence, + typing.List, + list, + typing.Tuple, + tuple, + ): + inner_type = typing.get_args(type_)[0] + if inner_type in _LIST_TYPE_TO_ATTR_TYPE: + return _LIST_TYPE_TO_ATTR_TYPE[inner_type] + except TypeError: + logger.warning("TypeError when checking %s.", type_, exc_info=True) + return ir.AttributeType.UNDEFINED + + +def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: + """Returns the name of the type constraint for a given type annotation. + + Args: + type_: A Python type. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. 
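+ For example, a TypeVar named TTensor annotated as Sequence[TTensor] yields "Sequence_TTensor".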
+    """
+    if isinstance(type_, TypeVar):
+        return type_.__name__
+    if _is_optional(type_):
+        subtypes = typing.get_args(type_)
+        for subtype in subtypes:
+            if subtype is type(None):
+                continue
+            type_param_name = _get_type_constraint_name(subtype)
+            return type_param_name if type_param_name else None
+    origin_type = typing.get_origin(type_)
+    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
+        subtypes = typing.get_args(type_)
+        type_param_name = _get_type_constraint_name(subtypes[0])
+        return f"Sequence_{type_param_name}" if type_param_name else None
+    return None
+
+
+def _get_allowed_types_from_type_annotation(
+    type_: TypeAnnotationValue,
+) -> set[ir.TypeProtocol]:
+    """Obtain the allowed types from a type annotation."""
+    if type_ is onnxscript.onnx_types.TensorType:
+        # Any tensor type
+        return {ir.TensorType(dtype) for dtype in ir.DataType}
+
+    allowed_types: set[ir.TypeProtocol]
+
+    if isinstance(type_, TypeVar):
+        allowed_types = set()
+        if constraints := type_.__constraints__:
+            for constraint in constraints:
+                allowed_types.update(
+                    _get_allowed_types_from_type_annotation(constraint)
+                )
+        else:
+            bound = type_.__bound__
+            if bound is None:
+                allowed_types = _ALL_VALUE_TYPES  # type: ignore[assignment]
+            else:
+                allowed_types.update(_get_allowed_types_from_type_annotation(bound))
+        return allowed_types
+    if hasattr(type_, "dtype"):
+        # A single tensor type like INT64, FLOAT, etc.
+        return {ir.TensorType(ir.DataType(type_.dtype))}
+    if _is_optional(type_):
+        allowed_types = set()
+        subtypes = typing.get_args(type_)
+        for subtype in subtypes:
+            if subtype is type(None):
+                continue
+            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
+        # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
+        return allowed_types
+
+    origin_type = typing.get_origin(type_)
+    if origin_type is Union:
+        allowed_types = set()
+        subtypes = typing.get_args(type_)
+        for subtype in subtypes:
+            assert subtype is not type(
+                None
+            ), "Union should not contain None type because it is handled by _is_optional."
+            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
+        return allowed_types
+
+    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
+        subtypes = typing.get_args(type_)
+        return {
+            ir.SequenceType(t)
+            for t in _get_allowed_types_from_type_annotation(subtypes[0])
+        }
+
+    # Allow everything by default
+    return _ALL_VALUE_TYPES  # type: ignore[return-value]
+
+
+@dataclasses.dataclass
+class OpSignature:
+    """Schema for an operator.
+
+    Attributes:
+        domain: Domain of the operator. E.g. "".
+        name: Name of the operator. E.g. "Add".
+        overload: Overload name of the operator.
+        params: Input parameters. When the op is an ONNX function definition,
+            the order is according to the function signature. This means we can
+            interleave ONNX inputs and ONNX attributes in the list.
+        outputs: Output parameters.
+    """
+
+    domain: str
+    name: str
+    overload: str
+    params: Sequence[Parameter | AttributeParameter]
+    outputs: Sequence[Parameter]
+    params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
+        init=False, repr=False
+    )
+
+    def __post_init__(self):
+        self.params_map = {param.name: param for param in self.params}
+
+    def get(self, name: str) -> Parameter | AttributeParameter:
+        return self.params_map[name]
+
+    def __contains__(self, name: str) -> bool:
+        return name in self.params_map
+
+    def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
+        return iter(self.params)
+
+    def __str__(self) -> str:
+        domain = self.domain or "''"
+        # TODO: Double check the separator for overload
+        overload = f"::{self.overload}" if self.overload else ""
+        params = ", ".join(str(param) for param in self.params)
+        outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
+        type_constraints = {}
+        for param in self.params:
+            if isinstance(param, Parameter):
+                type_constraints[param.type_constraint.name] = param.type_constraint
+        for param in self.outputs:
+            type_constraints[param.type_constraint.name] = param.type_constraint
+        type_constraints_str = ", ".join(
+            str(type_constraint) for type_constraint in type_constraints.values()
+        )
+        return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"
+
+    @classmethod
+    def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
+        """Produce an OpSignature from an ONNX OpSchema."""
+        type_constraints = {
+            constraint.type_param_str: TypeConstraintParam(
+                name=constraint.type_param_str,
+                allowed_types={
+                    _get_type_from_str(type_str)
+                    for type_str in constraint.allowed_type_strs
+                },
+                description=constraint.description,
+            )
+            for constraint in opschema.type_constraints
+        }
+
+        params = [
+            _convert_formal_parameter(param, type_constraints)
+            for param in opschema.inputs
+        ]
+
+        for param in opschema.attributes.values():
+            default_attr = (
+                ir.serde.deserialize_attribute(param.default_value)
+                if param.default_value is not None
+                else None
+            )
+            if default_attr is not None:
+                # Set the name of the default attribute because it may have a different name from the parameter
+                default_attr.name = param.name
+            params.append(
+                AttributeParameter(
+                    name=param.name,
+                    type=ir.AttributeType(param.type),  # type: ignore[arg-type]
+                    required=param.required,
+                    default=default_attr,  # type: ignore[arg-type]
+                )
+            )
+
+        outputs = [
+            _convert_formal_parameter(param, type_constraints)
+            for param in opschema.outputs
+        ]
+
+        return cls(
+            domain=opschema.domain,
+            name=opschema.name,
+            overload="",
+            params=params,
+            outputs=outputs,
+        )
+
+    @classmethod
+    def from_function(
+        cls, func, domain: str, name: str | None = None, overload: str = ""
+    ) -> OpSignature:
+        """Produce an OpSignature from a function using type annotations."""
+
+        py_signature = inspect.signature(func)
+        # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
+        # https://github.com/python/cpython/issues/102405
+        type_hints = typing.get_type_hints(func)
+
+        params = []
+        # Create a mapping from type to a unique name
+        type_constraints: dict[str, TypeConstraintParam] = {}
+
+        for param in py_signature.parameters.values():
+            if param.name not in type_hints:
+                logger.warning(
+                    "Missing annotation for parameter '%s' from %s. Treating as an Input.",
+                    param.name,
+                    py_signature,
+                )
+                type_constraints[param.name] = TypeConstraintParam.any_value(
+                    f"T_{param.name}"
+                )
+            else:
+                type_ = type_hints[param.name]
+                if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
+                    # Construct the default attribute
+                    if param.default is not inspect.Parameter.empty:
+                        # TODO: Use ir_convenience instead to handle int as float
+                        default = ir.Attr(param.name, attr_type, param.default)
+                    else:
+                        default = None
+                    params.append(
+                        AttributeParameter(
+                            name=param.name,
+                            type=attr_type,
+                            required=param.default is inspect.Parameter.empty,
+                            default=default,
+                        )
+                    )
+                else:
+                    # Obtain the type constraint from the type annotation
+
+                    # 1. Get a type constraint name from the type annotation
+                    # If the type annotation is a TypeVar or Optional[TypeVar], get its name
+                    # Otherwise, name it T_{param.name}
+                    type_constraint_name = _get_type_constraint_name(type_)
+                    if type_constraint_name is None:
+                        type_constraint_name = f"T_{param.name}"
+
+                    # 2. If the type constraint param is already initialized, use it
+                    if type_constraint_name in type_constraints:
+                        type_constraint = type_constraints[type_constraint_name]
+                    else:
+                        # 3. Otherwise, create a new TypeConstraintParam
+                        type_constraint = TypeConstraintParam(
+                            name=type_constraint_name,
+                            allowed_types=_get_allowed_types_from_type_annotation(
+                                type_
+                            ),
+                        )
+                        type_constraints[type_constraint_name] = type_constraint
+                    # 4. Create Parameter
+                    params.append(
+                        Parameter(  # type: ignore[arg-type]
+                            name=param.name,
+                            type_constraint=type_constraint,
+                            required=param.default is inspect.Parameter.empty,
+                            # TODO: Handle variadic
+                            variadic=False,
+                            default=param.default
+                            if param.default is not inspect.Parameter.empty
+                            else _EMPTY_DEFAULT,
+                        )
+                    )
+
+        return_type = type_hints.get("return")
+
+        outputs = []
+        if return_type is None:
+            # No returns
+            pass
+        else:
+            if typing.get_origin(return_type) is tuple:
+                # Multiple returns
+                return_types = typing.get_args(return_type)
+            else:
+                return_types = [return_type]  # type: ignore[assignment]
+
+            for i, return_type_i in enumerate(return_types):
+                if (
+                    return_param_name := _get_type_constraint_name(return_type_i)
+                ) in type_constraints:
+                    type_constraint = type_constraints[return_param_name]
+                else:
+                    return_param_name = f"TReturn{i}"
+                    type_constraint = TypeConstraintParam(
+                        name=return_param_name,
+                        allowed_types=_get_allowed_types_from_type_annotation(
+                            return_type_i
+                        ),
+                    )
+                    type_constraints[return_param_name] = type_constraint
+                outputs.append(
+                    Parameter(
+                        name=return_param_name,
+                        type_constraint=type_constraint,
+                        required=True,
+                        variadic=False,
+                        default=_EMPTY_DEFAULT,
+                    )
+                )
+
+        return cls(
+            domain=domain,
+            name=name or func.__name__,
+            overload=overload,
+            params=params,
+            outputs=outputs,
+        )
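Editor's note: a minimal sketch of how the annotation-driven signature extraction above is meant to be used. It assumes the hunk above belongs to torch/onnx/_internal/exporter/_schemas.py and that onnxscript exposes FLOAT and DOUBLE as tensor types; the function, names, and expected output are illustrative, not taken from the patch.

    # Hypothetical example, not part of the patch.
    from typing import Optional, TypeVar

    from onnxscript import DOUBLE, FLOAT

    from torch.onnx._internal.exporter import _schemas

    TFloat = TypeVar("TFloat", FLOAT, DOUBLE)


    def fake_op(x: TFloat, scale: float = 1.0, bias: Optional[TFloat] = None) -> TFloat:
        """A made-up op used only to exercise from_function."""


    sig = _schemas.OpSignature.from_function(fake_op, domain="custom", name="FakeOp")
    # Expected, approximately:
    # - "x" and "bias" share the type constraint "TFloat" (tensor(float), tensor(double));
    #   "bias" is optional because of Optional[TFloat].
    # - "scale" becomes an AttributeParameter of type FLOAT with default 1.0.
    # - The single output reuses the "TFloat" constraint.
    print(sig)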
diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py
new file mode 100644
index 00000000000..cfe8f7dc2a6
--- /dev/null
+++ b/torch/onnx/_internal/exporter/_tensors.py
@@ -0,0 +1,98 @@
+"""Subclass of ir.Value that supports Python operators."""
+
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import onnxscript
+from onnxscript import ir
+
+
+class SymbolicTensor(ir.Value):
+    """A subclass of ir.Value that supports Python operators."""
+
+    def __init__(
+        self,
+        opset: onnxscript.values.Opset,
+        name: str | None = None,
+        shape: ir.Shape | None = None,
+        type: ir.TypeProtocol | None = None,
+        doc_string: str | None = None,
+        const_value: ir.TensorProtocol | None = None,
+    ):
+        super().__init__(
+            name=name,
+            shape=shape,
+            type=type,
+            doc_string=doc_string,
+            const_value=const_value,
+        )
+        self._opset = opset
+
+    @property
+    def rank(self) -> int | None:
+        if self.shape is None:
+            return None
+        return len(self.shape)
+
+    # TODO: Implement indexing
+
+    def __mod__(self, other):
+        if self.dtype in {
+            ir.DataType.FLOAT,
+            ir.DataType.DOUBLE,
+            ir.DataType.FLOAT16,
+            ir.DataType.BFLOAT16,
+        }:
+            return self._opset.Mod(self, other, fmod=1)
+        return self._opset.Mod(self, other)
+
+    def __ne__(self, other):
+        return self._opset.Not(self._opset.Equal(self, other))
+
+    def __neg__(self):
+        return self._opset.Neg(self)
+
+    def __add__(self, other):
+        return self._opset.Add(self, other)
+
+    def __radd__(self, other):
+        return self._opset.Add(other, self)
+
+    def __rand__(self, other):
+        return self._opset.And(other, self)
+
+    def __mul__(self, other):
+        return self._opset.Mul(self, other)
+
+    def __rmul__(self, other):
+        return self._opset.Mul(other, self)
+
+    def __matmul__(self, other):
+        return self._opset.MatMul(self, other)
+
+    def __pow__(self, other):
+        return self._opset.Pow(self, other)
+
+    def __sub__(self, other):
+        return self._opset.Sub(self, other)
+
+    def __rsub__(self, other):
+        return self._opset.Sub(other, self)
+
+    def __truediv__(self, other):
+        return self._opset.Div(self, other)
+
+    def __lt__(self, other):
+        return self._opset.Less(self, other)
+
+    def __le__(self, other):
+        return self._opset.LessOrEqual(self, other)
+
+    def __eq__(self, other):
+        return self._opset.Equal(self, other)
+
+    def __ge__(self, other):
+        return self._opset.GreaterOrEqual(self, other)
+
+    def __gt__(self, other):
+        return self._opset.Greater(self, other)
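Editor's note: a self-contained sketch of the Python-operator dispatch that SymbolicTensor adds. The RecordingOpset below is a made-up stand-in for the graph-building Opset the exporter actually passes in (it only echoes the ONNX op name), and the sketch assumes ir.Value exposes a .dtype property, which the __mod__ overload above also relies on.

    # Illustration only; RecordingOpset is not exporter code.
    from onnxscript import ir

    from torch.onnx._internal.exporter import _tensors


    class RecordingOpset:
        """Reports which ONNX operator each Python operator dispatches to."""

        def __getattr__(self, op_name):
            return lambda *args, **kwargs: (op_name, kwargs)


    op = RecordingOpset()
    x = _tensors.SymbolicTensor(op, name="x", type=ir.TensorType(ir.DataType.FLOAT))
    y = _tensors.SymbolicTensor(op, name="y", type=ir.TensorType(ir.DataType.FLOAT))

    print(x + y)   # ('Add', {})
    print(x @ y)   # ('MatMul', {})
    print(x % y)   # ('Mod', {'fmod': 1})  float inputs take the fmod path
    print(x <= y)  # ('LessOrEqual', {})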
diff --git a/torch/onnx/_internal/exporter/_verification.py b/torch/onnx/_internal/exporter/_verification.py
new file mode 100644
index 00000000000..00822ca8991
--- /dev/null
+++ b/torch/onnx/_internal/exporter/_verification.py
@@ -0,0 +1,79 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import dataclasses
+from typing import Any, TYPE_CHECKING
+
+import torch
+from torch.utils import _pytree as pytree
+
+
+if TYPE_CHECKING:
+    from torch.onnx._internal.exporter import _onnx_program
+
+
+@dataclasses.dataclass
+class VerificationInfo:
+    name: str
+    absolute_difference: float
+    relative_difference: float
+    expected_dtype: torch.dtype
+    actual_dtype: torch.dtype
+    # NOTE: We don't need to include shape because the expected shape is already known
+    # and checked by the runtime
+
+
+def _compare_tensors(
+    expected: torch.Tensor,
+    actual: torch.Tensor,
+) -> tuple[float, float]:
+    # Move tensors to the same device
+    expected = expected.detach().cpu()
+    actual = actual.detach().cpu()
+    absolute_difference = torch.abs(expected - actual).max().item()
+    eps = 1e-7
+    relative_difference = (
+        (torch.abs(expected - actual) / (torch.abs(expected) + eps)).max().item()
+    )
+    return absolute_difference, relative_difference
+
+
+def verify_onnx_program(
+    onnx_program: _onnx_program.ONNXProgram,
+    args: tuple[Any, ...] | None = None,
+    kwargs: dict[str, Any] | None = None,
+) -> list[VerificationInfo]:
+    exported_program = onnx_program.exported_program
+    if args is None and kwargs is None:
+        # User did not provide example inputs, use the default example inputs
+        if exported_program.example_inputs is None:
+            raise ValueError(
+                "No example inputs provided and the exported_program does not contain example inputs. "
+                "Please provide arguments to verify the ONNX program."
+            )
+        args, kwargs = exported_program.example_inputs
+    if args is None:
+        args = ()
+    if kwargs is None:
+        kwargs = {}
+    torch_module = exported_program.module()
+    torch_outputs, _ = pytree.tree_flatten(torch_module(*args, **kwargs))
+    onnx_outputs = onnx_program(*args, **kwargs)
+    results = []
+    for torch_output, onnx_output, output_val in zip(
+        torch_outputs, onnx_outputs, onnx_program.model.graph.outputs
+    ):
+        name = output_val.name
+        absolute_difference, relative_difference = _compare_tensors(
+            torch_output, onnx_output
+        )
+        results.append(
+            VerificationInfo(
+                name=str(name),
+                absolute_difference=absolute_difference,
+                relative_difference=relative_difference,
+                expected_dtype=torch_output.dtype,
+                actual_dtype=onnx_output.dtype,
+            )
+        )
+    return results
diff --git a/torch/onnx/_internal/exporter/errors.py b/torch/onnx/_internal/exporter/errors.py
new file mode 100644
index 00000000000..a70eccf3a56
--- /dev/null
+++ b/torch/onnx/_internal/exporter/errors.py
@@ -0,0 +1,30 @@
+class ExporterError(RuntimeError):
+    """Error during export."""
+
+
+class TorchExportError(ExporterError):
+    """Error during torch.export.export."""
+
+
+class OnnxConversionError(ExporterError):
+    """Error during ONNX conversion."""
+
+
+class DispatchError(OnnxConversionError):
+    """Error during ONNX Function dispatching."""
+
+
+class GraphConstructionError(OnnxConversionError):
+    """Error during graph construction."""
+
+
+class OnnxCheckerError(ExporterError):
+    """Error during ONNX model checking."""
+
+
+class OnnxRuntimeError(ExporterError):
+    """Error during ONNX Runtime execution."""
+
+
+class OnnxValidationError(ExporterError):
+    """Output value mismatch."""
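Editor's note: a hedged usage sketch for verify_onnx_program. It assumes an ONNXProgram produced by the new exporter (obtained here via torch.onnx.export(..., dynamo=True), in builds where that flag is available), onnxruntime installed so the program can be executed, and a made-up toy module; none of this is part of the patch.

    # Illustrative only.
    import torch

    from torch.onnx._internal.exporter import _verification


    class MLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)


    onnx_program = torch.onnx.export(MLP(), (torch.randn(1, 4),), dynamo=True)
    # With no explicit args/kwargs, the example inputs recorded on the
    # ExportedProgram are reused for the comparison.
    for info in _verification.verify_onnx_program(onnx_program):
        print(info.name, info.absolute_difference, info.relative_difference)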
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6a6526b18ee..c70a90e1e01 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -172,10 +172,7 @@ def _get_torch_export_args(
 
 
 def export(
-    model: torch.nn.Module
-    | torch.jit.ScriptModule
-    | torch.jit.ScriptFunction
-    | torch.export.ExportedProgram,
+    model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
     args: tuple[Any, ...] | torch.Tensor,
     f: str | None = None,
     *,
@@ -191,13 +188,11 @@ def export(
     dynamic_axes: Mapping[str, Mapping[int, str]]
     | Mapping[str, Sequence[int]]
     | None = None,
-    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
     keep_initializers_as_inputs: bool | None = None,
     custom_opsets: Mapping[str, int] | None = None,
     export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
-    autograd_inlining: bool | None = True,
-    dynamo: bool = False,
-) -> torch.onnx.ONNXProgram | None:
+    autograd_inlining: bool = True,
+) -> None:
     r"""Exports a model into ONNX format.
 
     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
@@ -491,8 +486,6 @@ def export(
         autograd_inlining: Flag used to control whether to inline autograd functions.
             Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
 
-        dynamo: Whether to export the model with Dynamo instead of TorchScript.
-
     Raises:
         :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
         :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@@ -515,65 +508,29 @@ def export(
        )
 
    args = (args,) if isinstance(args, torch.Tensor) else args
+    if kwargs is not None:
+        args = args + (kwargs,)
 
-    if dynamo:
-        if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
-            raise TypeError(
-                "Dynamo export does not support ScriptModule or ScriptFunction."
-            )
-        # TODO(justinchuby): Remove the warning once logic migration is done
-        warnings.warn(
-            "export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, "
-            "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
-            "autograd_inlining are not supported for dynamo export at the moment."
-        )
-        args, kwargs = _get_torch_export_args(args, kwargs)
-        if isinstance(model, torch.export.ExportedProgram):
-            exported_program = model
-        else:
-            if dynamic_shapes is None and dynamic_axes is not None:
-                dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
-                    model, dynamic_axes, input_names
-                )
-            exported_program = torch.export.export(
-                model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes  # type: ignore[arg-type]
-            )
-        if kwargs is None:
-            # TODO(justinchuby): dynamo_export requires kwargs to be unpacked. Once migration is done
-            # we can pass kwargs as None
-            kwargs = {}
-        onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
-        if f is not None:
-            onnx_program.save(f)
-        return onnx_program
+    _export(
+        model,
+        args,
+        f,
+        export_params,
+        verbose,
+        training,
+        input_names,
+        output_names,
+        operator_export_type=operator_export_type,
+        opset_version=opset_version,
+        do_constant_folding=do_constant_folding,
+        dynamic_axes=dynamic_axes,
+        keep_initializers_as_inputs=keep_initializers_as_inputs,
+        custom_opsets=custom_opsets,
+        export_modules_as_functions=export_modules_as_functions,
+        autograd_inlining=autograd_inlining,
+    )
 
-    else:
-        # Torch Script export path
-        if f is None:
-            raise ValueError("Export destination must be specified when dynamo=False.")
-        if kwargs is not None:
-            args = args + (kwargs,)
-
-        _export(
-            model,
-            args,
-            f,
-            export_params,
-            verbose,
-            training,
-            input_names,
-            output_names,
-            operator_export_type=operator_export_type,
-            opset_version=opset_version,
-            do_constant_folding=do_constant_folding,
-            dynamic_axes=dynamic_axes,
-            keep_initializers_as_inputs=keep_initializers_as_inputs,
-            custom_opsets=custom_opsets,
-            export_modules_as_functions=export_modules_as_functions,
-            autograd_inlining=autograd_inlining,
-        )
-
-        return None
+    return None
 
 
 def _is_constant_tensor_list(node):
@@ -1531,7 +1488,7 @@ def _export(
     custom_opsets=None,
     add_node_names=True,
     onnx_shape_inference=True,
-    export_modules_as_functions=False,
+    export_modules_as_functions: Any = False,
     autograd_inlining=True,
 ):
     assert GLOBALS.in_onnx_export is False
@@ -1560,9 +1517,7 @@ def _export(
             f"Exporting to ONNX opset version {opset_version} is not supported. "
             f"by 'torch.onnx.export()'. "
             f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
-            f"To use a newer opset version, consider 'torch.onnx.dynamo_export()'. "
-            f"Note that dynamo_export() is in preview. Please report errors with "
-            f"dynamo_export() as Github issues to https://github.com/pytorch/pytorch/issues.",
+            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
             category=errors.OnnxExporterWarning,
         )
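Editor's note: the kwargs handling added to the TorchScript path above folds keyword arguments into the args tuple as a trailing dict. A rough sketch of the two call shapes this makes equivalent, assuming export() accepts a kwargs= argument as the new code implies; the module and file name are made up.

    # Illustrative only.
    import torch


    class M(torch.nn.Module):
        def forward(self, x, bias=None):
            return x if bias is None else x + bias


    m = M()
    x = torch.randn(2, 3)
    bias = torch.randn(2, 3)

    # Keyword arguments passed via kwargs= ...
    torch.onnx.export(m, (x,), "m.onnx", kwargs={"bias": bias})
    # ... are folded into args as a trailing dict, so this is the same call shape.
    torch.onnx.export(m, (x, {"bias": bias}), "m.onnx")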