diff --git a/test/cpp_extensions/open_registration_extension/custom_device/csrc/extension.cpp b/test/cpp_extensions/open_registration_extension.cpp similarity index 100% rename from test/cpp_extensions/open_registration_extension/custom_device/csrc/extension.cpp rename to test/cpp_extensions/open_registration_extension.cpp diff --git a/test/cpp_extensions/open_registration_extension/setup.py b/test/cpp_extensions/open_registration_extension/setup.py index bf1de5b9a8d..fa8c1308c6c 100644 --- a/test/cpp_extensions/open_registration_extension/setup.py +++ b/test/cpp_extensions/open_registration_extension/setup.py @@ -1,5 +1,6 @@ import distutils.command.clean import os +import platform import shutil import sys from pathlib import Path @@ -40,8 +41,11 @@ if __name__ == "__main__": CXX_FLAGS = ["/sdl"] else: CXX_FLAGS = ["/sdl", "/permissive-"] - else: + elif platform.machine() == "s390x": + # no -Werror on s390x due to newer compiler CXX_FLAGS = {"cxx": ["-g", "-Wall"]} + else: + CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]} sources = list(CSRS_DIR.glob("*.cpp")) diff --git a/test/run_test.py b/test/run_test.py index baf915942a5..f2cc18d5139 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1236,6 +1236,7 @@ CUSTOM_HANDLERS = { "test_ci_sanity_check_fail": run_ci_sanity_check, "test_autoload_enable": test_autoload_enable, "test_autoload_disable": test_autoload_disable, + "test_cpp_extensions_open_device_registration": run_test_with_openreg, "test_openreg": run_test_with_openreg, "test_transformers_privateuse1": run_test_with_openreg, } diff --git a/test/cpp_extensions/open_registration_extension/test/test_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py similarity index 98% rename from test/cpp_extensions/open_registration_extension/test/test_open_device_registration.py rename to test/test_cpp_extensions_open_device_registration.py index c17fb91c559..4ec01bf6fce 100644 --- a/test/cpp_extensions/open_registration_extension/test/test_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -1,16 +1,15 @@ # Owner(s): ["module: cpp-extensions"] import _codecs -import importlib import io import os import sys import tempfile import unittest -from pathlib import Path from unittest.mock import patch import numpy as np +import pytorch_openreg # noqa: F401 import torch import torch.testing._internal.common_utils as common @@ -61,21 +60,18 @@ class TestCppExtensionOpenRgistration(common.TestCase): @classmethod def setUpClass(cls): - # load custom device extension - extension_root = Path(__file__).parent.parent + torch.testing._internal.common_utils.remove_cpp_extensions_build_root() + cls.module = torch.utils.cpp_extension.load( name="custom_device_extension", sources=[ - f"{extension_root}/custom_device/csrc/extension.cpp", + "cpp_extensions/open_registration_extension.cpp", ], - extra_include_paths=[], + extra_include_paths=["cpp_extensions"], extra_cflags=["-g"], verbose=True, ) - # install / load pytorch_openreg extension - common.install_cpp_extension(extension_root=extension_root) - globals()["pytorch_openreg"] = importlib.import_module("pytorch_openreg") torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) def test_base_device_registration(self):