diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 9af0305778d..d5c585c1e1f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1643,6 +1643,8 @@ if(USE_CUDA) target_link_libraries(torch_cuda PUBLIC c10_cuda) if(TARGET torch::nvtx3) target_link_libraries(torch_cuda PRIVATE torch::nvtx3) + else() + target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext) endif() target_include_directories( @@ -1739,6 +1741,9 @@ if(BUILD_SHARED_LIBS) if(USE_CUDA) target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) target_link_libraries(torch_global_deps torch::cudart) + if(TARGET torch::nvtoolsext) + target_link_libraries(torch_global_deps torch::nvtoolsext) + endif() endif() install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 444a7590a8a..733183ef50b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -968,8 +968,11 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir) if(nvtx3_FOUND) add_library(torch::nvtx3 INTERFACE IMPORTED) target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}") + target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3) else() - message(FATAL_ERROR "Cannot find NVTX3!") + message(WARNING "Cannot find NVTX3, find old NVTX instead") + add_library(torch::nvtoolsext INTERFACE IMPORTED) + set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt) endif() diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 8a5587cad27..0b32ffa99ce 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -132,6 +132,9 @@ if(@USE_CUDA@) else() set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB}) endif() + if(TARGET torch::nvtoolsext) + list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext) + endif() if(@BUILD_SHARED_LIBS@) find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4e657201806..d92b9e19a76 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -150,6 +150,10 @@ if(USE_CUDA) if(TARGET torch::nvtx3) list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3) + else() + if(TARGET torch::nvtoolsext) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtoolsext) + endif() endif() endif() diff --git a/torch/csrc/cuda/shared/nvtx.cpp b/torch/csrc/cuda/shared/nvtx.cpp index 8faf319071c..f4b3c8824b8 100644 --- a/torch/csrc/cuda/shared/nvtx.cpp +++ b/torch/csrc/cuda/shared/nvtx.cpp @@ -2,13 +2,18 @@ #include // _wgetenv for nvtx #endif +#include + #ifndef ROCM_ON_WINDOWS +#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3) #include +#else // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3) +#include +#endif // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3) #else // ROCM_ON_WINDOWS #include #endif // ROCM_ON_WINDOWS #include -#include #include namespace torch::cuda::shared { @@ -50,7 +55,11 @@ static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) { void initNvtxBindings(PyObject* module) { auto m = py::handle(module).cast(); +#ifdef TORCH_CUDA_USE_NVTX3 auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings"); +#else + auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings"); +#endif nvtx.def("rangePushA", nvtxRangePushA); nvtx.def("rangePop", nvtxRangePop); nvtx.def("rangeStartA", nvtxRangeStartA); diff --git a/torch/csrc/profiler/stubs/cuda.cpp b/torch/csrc/profiler/stubs/cuda.cpp index b590b2d985d..2b634b0303c 100644 --- a/torch/csrc/profiler/stubs/cuda.cpp +++ b/torch/csrc/profiler/stubs/cuda.cpp @@ -1,7 +1,11 @@ #include #ifndef ROCM_ON_WINDOWS +#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3) #include +#else +#include +#endif #else // ROCM_ON_WINDOWS #include #endif // ROCM_ON_WINDOWS