diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py
new file mode 100644
index 00000000000..14afb7b7bfa
--- /dev/null
+++ b/test/inductor/test_profiler.py
@@ -0,0 +1,51 @@
+# Owner(s): ["module: inductor"]
+import json
+import unittest
+
+import torch
+import torch._dynamo.test_case
+import torch._inductor.utils
+
+from torch.testing._internal.common_utils import TemporaryFileName
+
+HAS_TRITON = torch._inductor.utils.has_triton()
+
+
+class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
+    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
+    def test_inductor_profiling_triton_launch(self):
+        # Verify that we get some sort of CPU-side indication of triton kernel launches
+        # in the profile traces. Currently, those appear as `cuLaunchKernel`. If this
+        # detail changes, the test can be updated or removed.
+        @torch.compile
+        def fn(x, y):
+            return (x + y).sin().cos()
+
+        x, y = [torch.rand((4, 4), device="cuda") for _ in range(2)]
+
+        with torch.profiler.profile() as prof:
+            fn(x, y)
+
+        with TemporaryFileName(mode="w+") as fname:
+            prof.export_chrome_trace(fname)
+            with open(fname, "r") as f:
+                trace_json = json.load(f)
+
+        self.assertTrue("traceEvents" in trace_json)
+        events = trace_json["traceEvents"]
+
+        def nameMatchesLaunchKernel(event_name):
+            return "cuLaunchKernel" in event_name
+
+        self.assertTrue(
+            any(
+                ("name" in event and nameMatchesLaunchKernel(event["name"]))
+                for event in events
+            )
+        )
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/third_party/kineto b/third_party/kineto
index e64df4dc312..21beef3787b 160000
--- a/third_party/kineto
+++ b/third_party/kineto
@@ -1 +1 @@
-Subproject commit e64df4dc31285a6129a74d26d67365cedf7aa6d1
+Subproject commit 21beef3787b4134c43584f6c2443341921c41f69
diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp
index 658fe80254a..1990280498b 100644
--- a/torch/csrc/profiler/kineto_shim.cpp
+++ b/torch/csrc/profiler/kineto_shim.cpp
@@ -24,6 +24,7 @@ const std::set<libkineto::ActivityType> kCpuTypes{
     libkineto::ActivityType::EXTERNAL_CORRELATION,
     libkineto::ActivityType::XPU_RUNTIME,
     libkineto::ActivityType::CUDA_RUNTIME,
+    libkineto::ActivityType::CUDA_DRIVER,
     libkineto::ActivityType::PYTHON_FUNCTION,
 };
 
@@ -33,6 +34,7 @@ const std::set<libkineto::ActivityType> kCudaTypes = {
     libkineto::ActivityType::CONCURRENT_KERNEL,
     // CUDA_RUNTIME appears in both kCpuTypes and kCudaTypes.
     libkineto::ActivityType::CUDA_RUNTIME,
+    libkineto::ActivityType::CUDA_DRIVER,
 };
 const std::set<libkineto::ActivityType> kXpuTypes = {
     libkineto::ActivityType::GPU_MEMCPY,
@@ -321,6 +323,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) {
     case libkineto::ActivityType::CPU_INSTANT_EVENT:
     case libkineto::ActivityType::GLOW_RUNTIME:
     case libkineto::ActivityType::PYTHON_FUNCTION:
+    case libkineto::ActivityType::CUDA_DRIVER:
      return c10::DeviceType::CPU;
    default: {
      TORCH_WARN(
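
For illustration only, and not part of the patch above: a minimal sketch of how the newly collected CUDA_DRIVER activities surface through the Python profiler API, assuming a CUDA build with Triton available. The event name `cuLaunchKernel` is a CUDA driver implementation detail and may change across driver versions; filtering `prof.events()` is just one of several ways to inspect the captured activities (the test above checks the exported chrome trace instead).

```python
# Sketch: observe driver-level kernel launch events after this change.
# Assumes a CUDA-enabled build; "cuLaunchKernel" is a driver detail and
# may differ between CUDA versions.
import torch
from torch.profiler import ProfilerActivity, profile


@torch.compile
def fn(x, y):
    return (x + y).sin().cos()


if torch.cuda.is_available():
    x = torch.rand((4, 4), device="cuda")
    y = torch.rand((4, 4), device="cuda")

    # CUDA_DRIVER is grouped with the CPU-side activity types in
    # kineto_shim.cpp, so the driver launch events are recorded alongside
    # the usual CPU ops and mapped to the CPU device type.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(x, y)

    launches = [evt for evt in prof.events() if "cuLaunchKernel" in evt.name]
    print(f"captured {len(launches)} driver-level kernel launch event(s)")
```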