diff --git a/torchgen/gen.py b/torchgen/gen.py index 27ff0c48caa..e5870a24fc6 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -59,6 +59,7 @@ from torchgen.model import ( is_cuda_dispatch_key, is_generic_dispatch_key, is_ufunc_dispatch_key, + is_xpu_dispatch_key, Location, NativeFunction, NativeFunctionsGroup, @@ -184,7 +185,7 @@ def parse_native_yaml_struct( use_out_as_primary=True, external=False, # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), + device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k), index=v, ) return ParsedYaml(rs, indices) diff --git a/torchgen/model.py b/torchgen/model.py index 7459587e31d..95694934310 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -323,6 +323,18 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: } +# XPU specific dispatcy keys +def is_xpu_dispatch_key(dk: DispatchKey) -> bool: + return dk in { + DispatchKey.XPU, + DispatchKey.QuantizedXPU, + DispatchKey.SparseXPU, + DispatchKey.SparseCsrXPU, + DispatchKey.NestedTensorXPU, + DispatchKey.AutogradXPU, + } + + # Structured kernel generation is only supported for certain key types; # otherwise use old-style def is_structured_dispatch_key(dk: DispatchKey) -> bool: