Revert "Remove old NVTX interface (#167637)"

This reverts commit 99117c1238.

Reverted https://github.com/pytorch/pytorch/pull/167637 on behalf of https://github.com/yangw-dev due to breaks internal build with torch/csrc/profiler/stubs/cuda.cpp:4:10: fatal error: 'nvtx3/nvtx3.hpp' file not found 4 | #include <nvtx3/nvtx3.hpp>, please find a meta fella to resolve this issue and try again, diff:[D87229660] ([comment](https://github.com/pytorch/pytorch/pull/167637#issuecomment-3543984021))
This commit is contained in:
PyTorch MergeBot
2025-11-17 21:51:04 +00:00
parent 4e1b772103
commit 661fb53449
6 changed files with 30 additions and 2 deletions

View File

@@ -1643,6 +1643,8 @@ if(USE_CUDA)
target_link_libraries(torch_cuda PUBLIC c10_cuda)
if(TARGET torch::nvtx3)
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
else()
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
endif()
target_include_directories(
@@ -1739,6 +1741,9 @@ if(BUILD_SHARED_LIBS)
if(USE_CUDA)
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
target_link_libraries(torch_global_deps torch::cudart)
if(TARGET torch::nvtoolsext)
target_link_libraries(torch_global_deps torch::nvtoolsext)
endif()
endif()
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

View File

@@ -968,8 +968,11 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(nvtx3_FOUND)
add_library(torch::nvtx3 INTERFACE IMPORTED)
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
else()
message(FATAL_ERROR "Cannot find NVTX3!")
message(WARNING "Cannot find NVTX3, find old NVTX instead")
add_library(torch::nvtoolsext INTERFACE IMPORTED)
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
endif()

View File

@@ -132,6 +132,9 @@ if(@USE_CUDA@)
else()
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
endif()
if(@BUILD_SHARED_LIBS@)
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")

View File

@@ -150,6 +150,10 @@ if(USE_CUDA)
if(TARGET torch::nvtx3)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
else()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtoolsext)
endif()
endif()
endif()

View File

@@ -2,13 +2,18 @@
#include <wchar.h> // _wgetenv for nvtx
#endif
#include <cuda_runtime.h>
#ifndef ROCM_ON_WINDOWS
#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvtx3/nvtx3.hpp>
#else // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvToolsExt.h>
#endif // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::cuda::shared {
@@ -50,7 +55,11 @@ static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
void initNvtxBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
#ifdef TORCH_CUDA_USE_NVTX3
auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings");
#else
auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
#endif
nvtx.def("rangePushA", nvtxRangePushA);
nvtx.def("rangePop", nvtxRangePop);
nvtx.def("rangeStartA", nvtxRangeStartA);

View File

@@ -1,7 +1,11 @@
#include <sstream>
#ifndef ROCM_ON_WINDOWS
#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvtx3/nvtx3.hpp>
#else
#include <nvToolsExt.h>
#endif
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS