mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[profiler] support cuLaunchKernel (for triton kernel launches) & update kineto submodule (#99571)
**Background**: Prior to this PR, traces for PT2 w/ inductor don't contain connections between CUDA kernels and the CPU launch site. This PR adds those connections. **Details**: Triton kernels launched by inductor use cuLaunchKernel instead of cudaLaunchKernel. cuLaunchKernel is part of the driver API, while cudaLaunchKernel is part of the runtime API. In order to support cuLaunchKernel, we added support in kineto (pytorch/kineto#752) to also start listening to driver events; hence why we need to update the kineto submodule. After the change in kineto, we just need to turn this on in the PyTorch repo by adding CUDA_DRIVER activity type to the CPU and CUDA activity type lists; then **Testing**: Added test/inductor/test_profiler.py to check for `cuLaunchKernel` in json trace files. Also, I ran this test: ```python import torch x = torch.rand((2, 2), device='cuda') def fn(x): return x.relu() fn_c = torch.compile(fn) fn_c(x) with torch.profiler.profile(with_stack=True) as prof: fn_c(x) prof.export_chrome_trace("relu_profile.json") ``` which generated this chrometrace: <img width="930" alt="Screenshot 2023-04-18 at 2 58 25 PM" src="https://user-images.githubusercontent.com/5067123/232966895-b65f9daf-7645-44f8-9e2b-f8c11c86ef0a.png"> in which you can see flows between a `cuLaunchKernel` on the CPU side, and the triton kernel on the GPU. **Kineto Updates**: To get the kineto-side changes required for cupti driver events, this PR updates the kineto pin. In that updated kineto submodule, we also have: * JSON string sanitizing for event names (likely fix for #99572) * cuda initialization fixes for multiprocessing * cuKernelLaunch events (i.e. for this PR) * DISABLE_CUPTI_LAZY_REINIT (from @aaronenyeshi) Pull Request resolved: https://github.com/pytorch/pytorch/pull/99571 Approved by: https://github.com/ngimel, https://github.com/aaronenyeshi
This commit is contained in:
committed by
PyTorch MergeBot
parent
5315317b7b
commit
c19d19f6ff
51
test/inductor/test_profiler.py
Normal file
51
test/inductor/test_profiler.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._inductor.utils
|
||||
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
|
||||
HAS_TRITON = torch._inductor.utils.has_triton()
|
||||
|
||||
|
||||
class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
|
||||
def test_inductor_profiling_triton_launch(self):
|
||||
# Verify that we get some sort of CPU-side indication of triton kernel launches
|
||||
# in the profile traces. Currently, those appear as `cuLaunchKernel`. If this
|
||||
# detail changes, the test can be updated or removed.
|
||||
@torch.compile
|
||||
def fn(x, y):
|
||||
return (x + y).sin().cos()
|
||||
|
||||
x, y = [torch.rand((4, 4), device="cuda") for _ in range(2)]
|
||||
|
||||
with torch.profiler.profile() as prof:
|
||||
fn(x, y)
|
||||
|
||||
with TemporaryFileName(mode="w+") as fname:
|
||||
prof.export_chrome_trace(fname)
|
||||
with open(fname, "r") as f:
|
||||
trace_json = json.load(f)
|
||||
|
||||
self.assertTrue("traceEvents" in trace_json)
|
||||
events = trace_json["traceEvents"]
|
||||
|
||||
def nameMatchesLaunchKernel(event_name):
|
||||
return "cuLaunchKernel" in event_name
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
("name" in event and "cuLaunchKernel" == event["name"])
|
||||
for event in events
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
2
third_party/kineto
vendored
2
third_party/kineto
vendored
Submodule third_party/kineto updated: e64df4dc31...21beef3787
@@ -24,6 +24,7 @@ const std::set<libkineto::ActivityType> kCpuTypes{
|
||||
libkineto::ActivityType::EXTERNAL_CORRELATION,
|
||||
libkineto::ActivityType::XPU_RUNTIME,
|
||||
libkineto::ActivityType::CUDA_RUNTIME,
|
||||
libkineto::ActivityType::CUDA_DRIVER,
|
||||
libkineto::ActivityType::PYTHON_FUNCTION,
|
||||
};
|
||||
|
||||
@@ -33,6 +34,7 @@ const std::set<libkineto::ActivityType> kCudaTypes = {
|
||||
libkineto::ActivityType::CONCURRENT_KERNEL,
|
||||
// CUDA_RUNTIME appears in both kCpuTypes and kCudaTypes.
|
||||
libkineto::ActivityType::CUDA_RUNTIME,
|
||||
libkineto::ActivityType::CUDA_DRIVER,
|
||||
};
|
||||
const std::set<libkineto::ActivityType> kXpuTypes = {
|
||||
libkineto::ActivityType::GPU_MEMCPY,
|
||||
@@ -321,6 +323,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) {
|
||||
case libkineto::ActivityType::CPU_INSTANT_EVENT:
|
||||
case libkineto::ActivityType::GLOW_RUNTIME:
|
||||
case libkineto::ActivityType::PYTHON_FUNCTION:
|
||||
case libkineto::ActivityType::CUDA_DRIVER:
|
||||
return c10::DeviceType::CPU;
|
||||
default: {
|
||||
TORCH_WARN(
|
||||
|
||||
Reference in New Issue
Block a user