[ONNX] Run ONNX tests as part of standard run_test script (#99215)

<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at dcbf7e2</samp>

### Summary
📝🧹🚩

<!--
1.  📝 for simplifying the `./scripts/onnx/test.sh` script
2.  🧹 for refactoring the `test/onnx/dynamo/test_exporter_api.py` file
3.  🚩 for adding the `--onnx` flag to `test/run_test.py` and updating the `TESTS` list
-->
This pull request improves PyTorch's ONNX testing infrastructure by running the ONNX tests through the standard `run_test.py` runner: it refactors the test code, normalizes scope names, adds an `--onnx` flag to run only the ONNX tests, and simplifies `./scripts/onnx/test.sh`.

> _To export PyTorch models to ONNX_
> _We refactored some scripts and contexts_
> _We used `common_utils`_
> _And normalized the scopes_
> _And added a flag to run the tests_

### Walkthrough
*  Simplify `./scripts/onnx/test.sh` to call `run_test.py` with the `--onnx` flag instead of invoking `pytest` directly ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-0017f5b22ae1329acb0f54af8d9811c9b6180a72dac70d7a5b89d7c23c958198L44-R46))
*  Remove `onnx` from the exclusion patterns passed to `discover_tests` in `test/run_test.py`, so ONNX tests are discovered into the `TESTS` list, and exclude `onnx_caffe2` instead ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7L127-R127))
*  Add `onnx/test_pytorch_onnx_onnxruntime_cuda` and `onnx/test_models` tests to `blocklisted_tests` list in `test/run_test.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R154-R155))
*  Add `ONNX_SERIAL_LIST` list to `test/run_test.py` to specify ONNX tests that must run serially ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R296-R301))
*  Add `ONNX_TESTS` list to `test/run_test.py` to store all ONNX tests ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R370))
*  Add `--onnx` flag to `parse_args` function in `test/run_test.py` to run only ONNX tests ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R920-R928))
*  Include `ONNX_SERIAL_LIST` in the `must_serial` function in `test/run_test.py` so that high-memory ONNX tests run serially while the rest run in parallel ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1120))
*  Filter the selected tests based on the `--onnx` flag in the `get_selected_tests` function in `test/run_test.py`: with the flag, only ONNX tests run; without it, all ONNX tests are excluded ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1158-R1165)). A condensed sketch of this selection logic follows the list.
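
For illustration only, the sketch below condenses the selection behavior using stand-in test names; it is not the verbatim `run_test.py` code (the real `TESTS` list is produced by `discover_tests` and exclusion is applied later via `options.exclude`):

```python
# Condensed sketch of the --onnx selection logic added to test/run_test.py.
# TESTS below is a stand-in; the real list comes from discover_tests().
TESTS = ["test_autograd", "onnx/test_utility_funs", "onnx/test_models", "test_nn"]
ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]


def select_tests(run_onnx: bool, exclude: list) -> list:
    selected = list(TESTS)
    onnx_tests = [t for t in selected if t in ONNX_TESTS]
    if run_onnx:
        # --onnx was passed: run only the ONNX tests.
        selected = onnx_tests
    else:
        # Otherwise exclude every ONNX test from the default run.
        exclude.extend(onnx_tests)
        selected = [t for t in selected if t not in exclude]
    return selected


print(select_tests(run_onnx=True, exclude=[]))   # ['onnx/test_utility_funs', 'onnx/test_models']
print(select_tests(run_onnx=False, exclude=[]))  # ['test_autograd', 'test_nn']
```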

### Other minor changes to accommodate the above
*  Derive the test classes in `test/onnx/dynamo/test_exporter_api.py` from `common_utils.TestCase` instead of `unittest.TestCase` and drop the `unittest` import ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L4), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L29-R28), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L71-R70), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L147-R146))
*  Import the `common_utils` module instead of importing `TemporaryFileName` directly in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L19-R18))
*  Use `common_utils.TemporaryFileName` instead of `TemporaryFileName` in `TestDynamoExportAPI` class in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L92-R91), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L110-R109), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L129-R128))
*  Use `common_utils.run_tests` instead of `unittest.main` in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L155-R154))
*  Import the `re` module in `test/onnx/test_utility_funs.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7R6))
*  Add a `_remove_test_environment_prefix_from_scope_name` helper to `test/onnx/test_utility_funs.py` to normalize the scope names of ONNX nodes ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7R32-R58)). An illustrative sketch of this helper follows the list.
*  Use `_remove_test_environment_prefix_from_scope_name` function to compare scope names of ONNX nodes in `TestUtilityFuns` class in `test/onnx/test_utility_funs.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1099-R1133), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1119-R1152), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1170-R1188), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1181-R1199), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1220-R1239), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1235-R1258))
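
For illustration, the scope-name normalization behaves as sketched below. The helper mirrors the one added in the diff; the assertions are illustrative usage, not part of the change:

```python
import re


def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
    # Different test environments (e.g. pytest vs. running the file directly)
    # prepend different prefixes to module scope names; strip them so expected
    # scope names can be compared independently of the environment.
    prefixes_to_remove = ["test_utility_funs", "__main__"]
    for prefix in prefixes_to_remove:
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name


assert _remove_test_environment_prefix_from_scope_name("__main__.M::") == "M::"
assert (
    _remove_test_environment_prefix_from_scope_name(
        "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::"
    )
    == "M::"
)
```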

Fixes #98626

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99215
Approved by: https://github.com/huydhn, https://github.com/titaiwangms
Commit: d41aa448b8 (parent: 8e69879209)
Author: BowenBao
Date: 2023-04-17 17:49:08 -07:00
Committed by: PyTorch MergeBot
4 changed files with 96 additions and 84 deletions

File: scripts/onnx/test.sh

@@ -41,46 +41,9 @@ args+=("--cov-report")
args+=("xml:test/coverage.xml")
args+=("--cov-append")
args_parallel=()
if [[ $PARALLEL == 1 ]]; then
args_parallel+=("-n")
args_parallel+=("auto")
fi
# onnxruntime only support py3
# "Python.h" not found in py2, needed by TorchScript custom op compilation.
if [[ "${SHARD_NUMBER}" == "1" ]]; then
# These exclusions are for tests that take a long time / a lot of GPU
# memory to run; they should be passing (and you will test them if you
# run them locally
pytest "${args[@]}" "${args_parallel[@]}" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
--ignore "$top_dir/test/onnx/test_models.py" \
--ignore "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
"${test_paths[@]}"
# Heavy memory usage tests that cannot run in parallel.
pytest "${args[@]}" \
"$top_dir/test/onnx/test_custom_ops.py" \
"$top_dir/test/onnx/test_utility_funs.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "not TestModelsONNXRuntime"
fi
if [[ "${SHARD_NUMBER}" == "2" ]]; then
# Heavy memory usage tests that cannot run in parallel.
# TODO(#79802): Parameterize test_models.py
pytest "${args[@]}" \
"$top_dir/test/onnx/test_models.py" \
"$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "TestModelsONNXRuntime"
pytest "${args[@]}" "${args_parallel[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
time python "${top_dir}/test/run_test.py" --onnx --shard "$SHARD_NUMBER" 2 --verbose
if [[ "$SHARD_NUMBER" == "2" ]]; then
# xdoctests on onnx
xdoctest torch.onnx --style=google --options="+IGNORE_WHITESPACE"
fi

File: test/onnx/dynamo/test_exporter_api.py

@@ -1,7 +1,6 @@
# Owner(s): ["module: onnx"]
import io
import logging
import unittest
import onnx
@@ -16,7 +15,7 @@ from torch.onnx._internal.exporter import (
ResolvedExportOptions,
)
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal import common_utils
class SampleModel(torch.nn.Module):
@@ -26,7 +25,7 @@ class SampleModel(torch.nn.Module):
return (y, z)
class TestExportOptionsAPI(unittest.TestCase):
class TestExportOptionsAPI(common_utils.TestCase):
def test_opset_version_default(self):
options = ResolvedExportOptions(None)
self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)
@@ -68,7 +67,7 @@ class TestExportOptionsAPI(unittest.TestCase):
self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))
class TestDynamoExportAPI(unittest.TestCase):
class TestDynamoExportAPI(common_utils.TestCase):
def test_default_export(self):
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
self.assertIsInstance(output, ExportOutput)
@@ -89,7 +88,7 @@ class TestDynamoExportAPI(unittest.TestCase):
)
def test_save_to_file_default_serializer(self):
with TemporaryFileName() as path:
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
onnx.load(path)
@@ -107,7 +106,7 @@ class TestDynamoExportAPI(unittest.TestCase):
) -> None:
destination.write(expected_buffer.encode())
with TemporaryFileName() as path:
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
@@ -126,7 +125,7 @@ class TestDynamoExportAPI(unittest.TestCase):
) -> None:
destination.write(expected_buffer.encode())
with TemporaryFileName() as path:
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
@@ -144,7 +143,7 @@ class TestDynamoExportAPI(unittest.TestCase):
export_output.model_proto
class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
class TestProtobufExportOutputSerializerAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
with self.assertRaises(roar.BeartypeException):
serializer = ProtobufExportOutputSerializer()
@@ -152,4 +151,4 @@ class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
common_utils.run_tests()

File: test/onnx/test_utility_funs.py

@@ -3,6 +3,7 @@
import copy
import functools
import io
import re
import warnings
from typing import Callable
@@ -28,6 +29,33 @@ from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
from verify import verify
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
"""Remove test environment prefix added to module.
Remove prefix to normalize scope names, since different test environments add
prefixes with slight differences.
Example:
>>> _remove_test_environment_prefix_from_scope_name(
>>> "test_utility_funs.M"
>>> )
"M"
>>> _remove_test_environment_prefix_from_scope_name(
>>> "test_utility_funs.test_abc.<locals>.M"
>>> )
"M"
>>> _remove_test_environment_prefix_from_scope_name(
>>> "__main__.M"
>>> )
"M"
"""
prefixes_to_remove = ["test_utility_funs", "__main__"]
for prefix in prefixes_to_remove:
scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
return scope_name
class _BaseTestCase(pytorch_test_common.ExportTestCase):
def _model_to_graph(
self,
@@ -1096,43 +1124,32 @@ class TestUtilityFuns(_BaseTestCase):
model = M(3)
expected_scope_names = {
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.activation.GELU::gelu1",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.activation.GELU::gelu2",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.0",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.1",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.2",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
"torch.nn.modules.activation.ReLU::relu",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
"M::/torch.nn.modules.activation.GELU::gelu1",
"M::/torch.nn.modules.activation.GELU::gelu2",
"M::/torch.nn.modules.normalization.LayerNorm::lns.0",
"M::/torch.nn.modules.normalization.LayerNorm::lns.1",
"M::/torch.nn.modules.normalization.LayerNorm::lns.2",
"M::/N::relu/torch.nn.modules.activation.ReLU::relu",
"M::",
}
graph, _, _ = self._model_to_graph(
model, (x, y, z), input_names=[], dynamic_axes={}
)
for node in graph.nodes():
self.assertIn(node.scopeName(), expected_scope_names)
expected_torch_script_scope_names = {
"test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu1",
"test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu2",
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.0",
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.1",
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.2",
"test_utility_funs.M::/test_utility_funs.N::relu/torch.nn.modules.activation.ReLU::relu",
"test_utility_funs.M::",
}
self.assertIn(
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
expected_scope_names,
)
graph, _, _ = self._model_to_graph(
torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
)
for node in graph.nodes():
self.assertIn(node.scopeName(), expected_torch_script_scope_names)
self.assertIn(
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
expected_scope_names,
)
def test_scope_of_constants_when_combined_by_cse_pass(self):
layer_num = 3
@@ -1167,9 +1184,8 @@ class TestUtilityFuns(_BaseTestCase):
# so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
# If CSE in exporter is improved later, this test needs to be updated.
# It should expect 1 constant, with same scope as root.
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
expected_root_scope_name = f"{scope_prefix}.N::"
expected_layer_scope_name = f"{scope_prefix}.M::layers"
expected_root_scope_name = "N::"
expected_layer_scope_name = "M::layers"
expected_constant_scope_name = [
f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
for i in range(layer_num)
@@ -1178,7 +1194,9 @@ class TestUtilityFuns(_BaseTestCase):
constant_scope_names = []
for node in graph.nodes():
if node.kind() == "onnx::Constant":
constant_scope_names.append(node.scopeName())
constant_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
self.assertEqual(constant_scope_names, expected_constant_scope_name)
def test_scope_of_nodes_when_combined_by_cse_pass(self):
@@ -1217,9 +1235,8 @@ class TestUtilityFuns(_BaseTestCase):
graph, _, _ = self._model_to_graph(
N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
)
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
expected_root_scope_name = f"{scope_prefix}.N::"
expected_layer_scope_name = f"{scope_prefix}.M::layers"
expected_root_scope_name = "N::"
expected_layer_scope_name = "M::layers"
expected_add_scope_names = [
f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
]
@@ -1232,9 +1249,13 @@ class TestUtilityFuns(_BaseTestCase):
mul_scope_names = []
for node in graph.nodes():
if node.kind() == "onnx::Add":
add_scope_names.append(node.scopeName())
add_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
elif node.kind() == "onnx::Mul":
mul_scope_names.append(node.scopeName())
mul_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
self.assertEqual(add_scope_names, expected_add_scope_names)
self.assertEqual(mul_scope_names, expected_mul_scope_names)

File: test/run_test.py

@@ -124,7 +124,7 @@ TESTS = discover_tests(
"fx", # executed by test_fx.py
"jit", # executed by test_jit.py
"mobile",
"onnx",
"onnx_caffe2",
"package", # executed by test_package.py
"quantization", # executed by test_quantization.py
"autograd", # executed by test_autograd.py
@@ -151,6 +151,8 @@ TESTS = discover_tests(
"distributed/test_c10d_spawn",
"distributions/test_transforms",
"distributions/test_utils",
"onnx/test_pytorch_onnx_onnxruntime_cuda",
"onnx/test_models",
],
extra_tests=[
"test_cpp_extensions_aot_ninja",
@@ -292,6 +294,14 @@ CI_SERIAL_LIST = [
"test_native_mha", # OOM
"test_module_hooks", # OOM
]
# A subset of onnx tests that cannot run in parallel due to high memory usage.
ONNX_SERIAL_LIST = [
"onnx/test_models",
"onnx/test_models_quantized_onnxruntime",
"onnx/test_models_onnxruntime",
"onnx/test_custom_ops",
"onnx/test_utility_funs",
]
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
CORE_TEST_LIST = [
@@ -360,6 +370,7 @@ JIT_EXECUTOR_TESTS = [
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith("distributed")]
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
TESTS_REQUIRING_LAPACK = [
"distributions/test_constraints",
@@ -908,6 +919,15 @@ def parse_args():
help="Only run core tests, or tests that validate PyTorch's ops, modules,"
"and autograd. They are defined by CORE_TEST_LIST.",
)
parser.add_argument(
"--onnx",
"--onnx",
action="store_true",
help=(
"Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
"If this flag is not present, we will exclude ONNX tests."
),
)
parser.add_argument(
"-pt",
"--pytest",
@@ -1100,6 +1120,7 @@ def must_serial(file: str) -> bool:
or file in RUN_PARALLEL_BLOCKLIST
or file in CI_SERIAL_LIST
or file in JIT_EXECUTOR_TESTS
or file in ONNX_SERIAL_LIST
)
@@ -1137,6 +1158,14 @@ def get_selected_tests(options):
# Exclude all mps tests otherwise
options.exclude.extend(["test_mps", "test_metal"])
# Filter to only run onnx tests when --onnx option is specified
onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
if options.onnx:
selected_tests = onnx_tests
else:
# Exclude all onnx tests otherwise
options.exclude.extend(onnx_tests)
# process reordering
if options.bring_to_front:
to_front = set(options.bring_to_front)