mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[ONNX] Run ONNX tests as part of standard run_test script (#99215)
<!-- copilot:all --> ### <samp>🤖 Generated by Copilot at dcbf7e2</samp> ### Summary 📝🧹🚩 <!-- 1. 📝 for simplifying the `./scripts/onnx/test.sh` script 2. 🧹 for refactoring the `test/onnx/dynamo/test_exporter_api.py` file 3. 🚩 for adding the `--onnx` flag to `test/run_test.py` and updating the `TESTS` list --> This pull request improves the ONNX testing infrastructure in PyTorch by refactoring the test code, normalizing the scope names, adding a flag to run only the ONNX tests, and simplifying the test script. > _To export PyTorch models to ONNX_ > _We refactored some scripts and contexts_ > _We used `common_utils`_ > _And normalized the scopes_ > _And added a flag to run the tests_ ### Walkthrough * Simplify `./scripts/onnx/test.sh` to use `run_test.py` with `--onnx` flag instead of `pytest` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-0017f5b22ae1329acb0f54af8d9811c9b6180a72dac70d7a5b89d7c23c958198L44-R46)) * Remove `onnx` test from `TESTS` list in `test/run_test.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7L127-R127)). Replace with `onnx_caffe2`. * Add `onnx/test_pytorch_onnx_onnxruntime_cuda` and `onnx/test_models` tests to `blocklisted_tests` list in `test/run_test.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R154-R155)) * Add `ONNX_SERIAL_LIST` list to `test/run_test.py` to specify ONNX tests that must run serially ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R296-R301)) * Add `ONNX_TESTS` list to `test/run_test.py` to store all ONNX tests ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R370)) * Add `--onnx` flag to `parse_args` function in `test/run_test.py` to run only ONNX tests ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R920-R928)) * Include `ONNX_SERIAL_LIST` in `must_serial` function in `test/run_test.py` to run ONNX tests serially or parallelly based on memory usage ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1120)) * Filter selected tests based on `--onnx` flag in `get_selected_tests` function in `test/run_test.py` to exclude non-ONNX tests ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-e72503c9e3e8766e2d1bacf3fad7b88aa166e0e90a7e103e7df99357a35df8d7R1158-R1165)) ### Other minor changes to accommodate this change * Replace `unittest` module with `common_utils.TestCase` in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L4), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L29-R28), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L71-R70), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L147-R146)) * Import `TemporaryFileName` class from `common_utils` in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L19-R18)) * Use `common_utils.TemporaryFileName` instead of `TemporaryFileName` in `TestDynamoExportAPI` class in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L92-R91), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L110-R109), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L129-R128)) * Use `common_utils.run_tests` instead of `unittest.main` in `test/onnx/dynamo/test_exporter_api.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L155-R154)) * Add `re` module to `test/onnx/test_utility_funs.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7R6)) * Add `_remove_test_environment_prefix_from_scope_name` function to `test/onnx/test_utility_funs.py` to normalize scope names of ONNX nodes ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7R32-R58)) * Use `_remove_test_environment_prefix_from_scope_name` function to compare scope names of ONNX nodes in `TestUtilityFuns` class in `test/onnx/test_utility_funs.py` ([link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1099-R1133), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1119-R1152), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1170-R1188), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1181-R1199), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1220-R1239), [link](https://github.com/pytorch/pytorch/pull/99215/files?diff=unified&w=0#diff-da71d2c81c9dc7ac0c47ff086fded82e4edcb67ba0cd3d8b5c983d7467343bc7L1235-R1258)) Fixes #98626 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99215 Approved by: https://github.com/huydhn, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
8e69879209
commit
d41aa448b8
@@ -41,46 +41,9 @@ args+=("--cov-report")
|
||||
args+=("xml:test/coverage.xml")
|
||||
args+=("--cov-append")
|
||||
|
||||
args_parallel=()
|
||||
if [[ $PARALLEL == 1 ]]; then
|
||||
args_parallel+=("-n")
|
||||
args_parallel+=("auto")
|
||||
fi
|
||||
|
||||
# onnxruntime only support py3
|
||||
# "Python.h" not found in py2, needed by TorchScript custom op compilation.
|
||||
if [[ "${SHARD_NUMBER}" == "1" ]]; then
|
||||
# These exclusions are for tests that take a long time / a lot of GPU
|
||||
# memory to run; they should be passing (and you will test them if you
|
||||
# run them locally
|
||||
pytest "${args[@]}" "${args_parallel[@]}" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
|
||||
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
|
||||
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
|
||||
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
|
||||
--ignore "$top_dir/test/onnx/test_models.py" \
|
||||
--ignore "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
|
||||
"${test_paths[@]}"
|
||||
|
||||
# Heavy memory usage tests that cannot run in parallel.
|
||||
pytest "${args[@]}" \
|
||||
"$top_dir/test/onnx/test_custom_ops.py" \
|
||||
"$top_dir/test/onnx/test_utility_funs.py" \
|
||||
"$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "not TestModelsONNXRuntime"
|
||||
fi
|
||||
|
||||
if [[ "${SHARD_NUMBER}" == "2" ]]; then
|
||||
# Heavy memory usage tests that cannot run in parallel.
|
||||
# TODO(#79802): Parameterize test_models.py
|
||||
pytest "${args[@]}" \
|
||||
"$top_dir/test/onnx/test_models.py" \
|
||||
"$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
|
||||
"$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "TestModelsONNXRuntime"
|
||||
|
||||
pytest "${args[@]}" "${args_parallel[@]}" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
|
||||
time python "${top_dir}/test/run_test.py" --onnx --shard "$SHARD_NUMBER" 2 --verbose
|
||||
|
||||
if [[ "$SHARD_NUMBER" == "2" ]]; then
|
||||
# xdoctests on onnx
|
||||
xdoctest torch.onnx --style=google --options="+IGNORE_WHITESPACE"
|
||||
fi
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
import io
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
|
||||
@@ -16,7 +15,7 @@ from torch.onnx._internal.exporter import (
|
||||
ResolvedExportOptions,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
class SampleModel(torch.nn.Module):
|
||||
@@ -26,7 +25,7 @@ class SampleModel(torch.nn.Module):
|
||||
return (y, z)
|
||||
|
||||
|
||||
class TestExportOptionsAPI(unittest.TestCase):
|
||||
class TestExportOptionsAPI(common_utils.TestCase):
|
||||
def test_opset_version_default(self):
|
||||
options = ResolvedExportOptions(None)
|
||||
self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)
|
||||
@@ -68,7 +67,7 @@ class TestExportOptionsAPI(unittest.TestCase):
|
||||
self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))
|
||||
|
||||
|
||||
class TestDynamoExportAPI(unittest.TestCase):
|
||||
class TestDynamoExportAPI(common_utils.TestCase):
|
||||
def test_default_export(self):
|
||||
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
|
||||
self.assertIsInstance(output, ExportOutput)
|
||||
@@ -89,7 +88,7 @@ class TestDynamoExportAPI(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_save_to_file_default_serializer(self):
|
||||
with TemporaryFileName() as path:
|
||||
with common_utils.TemporaryFileName() as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
|
||||
onnx.load(path)
|
||||
|
||||
@@ -107,7 +106,7 @@ class TestDynamoExportAPI(unittest.TestCase):
|
||||
) -> None:
|
||||
destination.write(expected_buffer.encode())
|
||||
|
||||
with TemporaryFileName() as path:
|
||||
with common_utils.TemporaryFileName() as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
|
||||
path, serializer=CustomSerializer()
|
||||
)
|
||||
@@ -126,7 +125,7 @@ class TestDynamoExportAPI(unittest.TestCase):
|
||||
) -> None:
|
||||
destination.write(expected_buffer.encode())
|
||||
|
||||
with TemporaryFileName() as path:
|
||||
with common_utils.TemporaryFileName() as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
|
||||
path, serializer=CustomSerializer()
|
||||
)
|
||||
@@ -144,7 +143,7 @@ class TestDynamoExportAPI(unittest.TestCase):
|
||||
export_output.model_proto
|
||||
|
||||
|
||||
class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
|
||||
class TestProtobufExportOutputSerializerAPI(common_utils.TestCase):
|
||||
def test_raise_on_invalid_argument_type(self):
|
||||
with self.assertRaises(roar.BeartypeException):
|
||||
serializer = ProtobufExportOutputSerializer()
|
||||
@@ -152,4 +151,4 @@ class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
common_utils.run_tests()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import copy
|
||||
import functools
|
||||
import io
|
||||
import re
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
@@ -28,6 +29,33 @@ from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
|
||||
from verify import verify
|
||||
|
||||
|
||||
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
|
||||
"""Remove test environment prefix added to module.
|
||||
|
||||
Remove prefix to normalize scope names, since different test environments add
|
||||
prefixes with slight differences.
|
||||
|
||||
Example:
|
||||
|
||||
>>> _remove_test_environment_prefix_from_scope_name(
|
||||
>>> "test_utility_funs.M"
|
||||
>>> )
|
||||
"M"
|
||||
>>> _remove_test_environment_prefix_from_scope_name(
|
||||
>>> "test_utility_funs.test_abc.<locals>.M"
|
||||
>>> )
|
||||
"M"
|
||||
>>> _remove_test_environment_prefix_from_scope_name(
|
||||
>>> "__main__.M"
|
||||
>>> )
|
||||
"M"
|
||||
"""
|
||||
prefixes_to_remove = ["test_utility_funs", "__main__"]
|
||||
for prefix in prefixes_to_remove:
|
||||
scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
|
||||
return scope_name
|
||||
|
||||
|
||||
class _BaseTestCase(pytorch_test_common.ExportTestCase):
|
||||
def _model_to_graph(
|
||||
self,
|
||||
@@ -1096,43 +1124,32 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
|
||||
model = M(3)
|
||||
expected_scope_names = {
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"torch.nn.modules.activation.GELU::gelu1",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"torch.nn.modules.activation.GELU::gelu2",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"torch.nn.modules.normalization.LayerNorm::lns.0",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"torch.nn.modules.normalization.LayerNorm::lns.1",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"torch.nn.modules.normalization.LayerNorm::lns.2",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
|
||||
"torch.nn.modules.activation.ReLU::relu",
|
||||
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
|
||||
"M::/torch.nn.modules.activation.GELU::gelu1",
|
||||
"M::/torch.nn.modules.activation.GELU::gelu2",
|
||||
"M::/torch.nn.modules.normalization.LayerNorm::lns.0",
|
||||
"M::/torch.nn.modules.normalization.LayerNorm::lns.1",
|
||||
"M::/torch.nn.modules.normalization.LayerNorm::lns.2",
|
||||
"M::/N::relu/torch.nn.modules.activation.ReLU::relu",
|
||||
"M::",
|
||||
}
|
||||
|
||||
graph, _, _ = self._model_to_graph(
|
||||
model, (x, y, z), input_names=[], dynamic_axes={}
|
||||
)
|
||||
for node in graph.nodes():
|
||||
self.assertIn(node.scopeName(), expected_scope_names)
|
||||
|
||||
expected_torch_script_scope_names = {
|
||||
"test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu1",
|
||||
"test_utility_funs.M::/torch.nn.modules.activation.GELU::gelu2",
|
||||
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.0",
|
||||
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.1",
|
||||
"test_utility_funs.M::/torch.nn.modules.normalization.LayerNorm::lns.2",
|
||||
"test_utility_funs.M::/test_utility_funs.N::relu/torch.nn.modules.activation.ReLU::relu",
|
||||
"test_utility_funs.M::",
|
||||
}
|
||||
self.assertIn(
|
||||
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
|
||||
expected_scope_names,
|
||||
)
|
||||
|
||||
graph, _, _ = self._model_to_graph(
|
||||
torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
|
||||
)
|
||||
for node in graph.nodes():
|
||||
self.assertIn(node.scopeName(), expected_torch_script_scope_names)
|
||||
self.assertIn(
|
||||
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
|
||||
expected_scope_names,
|
||||
)
|
||||
|
||||
def test_scope_of_constants_when_combined_by_cse_pass(self):
|
||||
layer_num = 3
|
||||
@@ -1167,9 +1184,8 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
# so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
|
||||
# If CSE in exporter is improved later, this test needs to be updated.
|
||||
# It should expect 1 constant, with same scope as root.
|
||||
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
|
||||
expected_root_scope_name = f"{scope_prefix}.N::"
|
||||
expected_layer_scope_name = f"{scope_prefix}.M::layers"
|
||||
expected_root_scope_name = "N::"
|
||||
expected_layer_scope_name = "M::layers"
|
||||
expected_constant_scope_name = [
|
||||
f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
|
||||
for i in range(layer_num)
|
||||
@@ -1178,7 +1194,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
constant_scope_names = []
|
||||
for node in graph.nodes():
|
||||
if node.kind() == "onnx::Constant":
|
||||
constant_scope_names.append(node.scopeName())
|
||||
constant_scope_names.append(
|
||||
_remove_test_environment_prefix_from_scope_name(node.scopeName())
|
||||
)
|
||||
self.assertEqual(constant_scope_names, expected_constant_scope_name)
|
||||
|
||||
def test_scope_of_nodes_when_combined_by_cse_pass(self):
|
||||
@@ -1217,9 +1235,8 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
graph, _, _ = self._model_to_graph(
|
||||
N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
|
||||
)
|
||||
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
|
||||
expected_root_scope_name = f"{scope_prefix}.N::"
|
||||
expected_layer_scope_name = f"{scope_prefix}.M::layers"
|
||||
expected_root_scope_name = "N::"
|
||||
expected_layer_scope_name = "M::layers"
|
||||
expected_add_scope_names = [
|
||||
f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
|
||||
]
|
||||
@@ -1232,9 +1249,13 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
mul_scope_names = []
|
||||
for node in graph.nodes():
|
||||
if node.kind() == "onnx::Add":
|
||||
add_scope_names.append(node.scopeName())
|
||||
add_scope_names.append(
|
||||
_remove_test_environment_prefix_from_scope_name(node.scopeName())
|
||||
)
|
||||
elif node.kind() == "onnx::Mul":
|
||||
mul_scope_names.append(node.scopeName())
|
||||
mul_scope_names.append(
|
||||
_remove_test_environment_prefix_from_scope_name(node.scopeName())
|
||||
)
|
||||
self.assertEqual(add_scope_names, expected_add_scope_names)
|
||||
self.assertEqual(mul_scope_names, expected_mul_scope_names)
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ TESTS = discover_tests(
|
||||
"fx", # executed by test_fx.py
|
||||
"jit", # executed by test_jit.py
|
||||
"mobile",
|
||||
"onnx",
|
||||
"onnx_caffe2",
|
||||
"package", # executed by test_package.py
|
||||
"quantization", # executed by test_quantization.py
|
||||
"autograd", # executed by test_autograd.py
|
||||
@@ -151,6 +151,8 @@ TESTS = discover_tests(
|
||||
"distributed/test_c10d_spawn",
|
||||
"distributions/test_transforms",
|
||||
"distributions/test_utils",
|
||||
"onnx/test_pytorch_onnx_onnxruntime_cuda",
|
||||
"onnx/test_models",
|
||||
],
|
||||
extra_tests=[
|
||||
"test_cpp_extensions_aot_ninja",
|
||||
@@ -292,6 +294,14 @@ CI_SERIAL_LIST = [
|
||||
"test_native_mha", # OOM
|
||||
"test_module_hooks", # OOM
|
||||
]
|
||||
# A subset of onnx tests that cannot run in parallel due to high memory usage.
|
||||
ONNX_SERIAL_LIST = [
|
||||
"onnx/test_models",
|
||||
"onnx/test_models_quantized_onnxruntime",
|
||||
"onnx/test_models_onnxruntime",
|
||||
"onnx/test_custom_ops",
|
||||
"onnx/test_utility_funs",
|
||||
]
|
||||
|
||||
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
|
||||
CORE_TEST_LIST = [
|
||||
@@ -360,6 +370,7 @@ JIT_EXECUTOR_TESTS = [
|
||||
|
||||
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith("distributed")]
|
||||
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
|
||||
ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
|
||||
|
||||
TESTS_REQUIRING_LAPACK = [
|
||||
"distributions/test_constraints",
|
||||
@@ -908,6 +919,15 @@ def parse_args():
|
||||
help="Only run core tests, or tests that validate PyTorch's ops, modules,"
|
||||
"and autograd. They are defined by CORE_TEST_LIST.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
"--onnx",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
|
||||
"If this flag is not present, we will exclude ONNX tests."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pt",
|
||||
"--pytest",
|
||||
@@ -1100,6 +1120,7 @@ def must_serial(file: str) -> bool:
|
||||
or file in RUN_PARALLEL_BLOCKLIST
|
||||
or file in CI_SERIAL_LIST
|
||||
or file in JIT_EXECUTOR_TESTS
|
||||
or file in ONNX_SERIAL_LIST
|
||||
)
|
||||
|
||||
|
||||
@@ -1137,6 +1158,14 @@ def get_selected_tests(options):
|
||||
# Exclude all mps tests otherwise
|
||||
options.exclude.extend(["test_mps", "test_metal"])
|
||||
|
||||
# Filter to only run onnx tests when --onnx option is specified
|
||||
onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
|
||||
if options.onnx:
|
||||
selected_tests = onnx_tests
|
||||
else:
|
||||
# Exclude all onnx tests otherwise
|
||||
options.exclude.extend(onnx_tests)
|
||||
|
||||
# process reordering
|
||||
if options.bring_to_front:
|
||||
to_front = set(options.bring_to_front)
|
||||
|
||||
Reference in New Issue
Block a user