From 3ca216ae172e35adde34a319a1a01faaf218e7c5 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Sat, 1 Nov 2025 21:24:17 -0700
Subject: [PATCH] Add claude skills for uint support and AT_DISPATCH_V2
 (#166814)

Signed-off-by: Edward Z. Yang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166814
Approved by: https://github.com/Skylion007, https://github.com/malfet
ghstack dependencies: #166813
---
 .claude/skills/add-uint-support/SKILL.md | 319 +++++++++++++++++++++++
 .claude/skills/at-dispatch-v2/SKILL.md   | 305 ++++++++++++++++++++++
 .gitignore                               |   1 +
 3 files changed, 625 insertions(+)
 create mode 100644 .claude/skills/add-uint-support/SKILL.md
 create mode 100644 .claude/skills/at-dispatch-v2/SKILL.md

diff --git a/.claude/skills/add-uint-support/SKILL.md b/.claude/skills/add-uint-support/SKILL.md
new file mode 100644
index 00000000000..a4859fdeae5
--- /dev/null
+++ b/.claude/skills/add-uint-support/SKILL.md
@@ -0,0 +1,319 @@
+---
+name: add-uint-support
+description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
+---
+
+# Add Unsigned Integer (uint) Support to Operators
+
+This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
+
+## When to use this skill
+
+Use this skill when:
+- Adding uint16, uint32, or uint64 support to an operator
+- User mentions "unsigned types", "uint support", "barebones unsigned types"
+- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
+- Working with operator implementations that need expanded type coverage
+
+## Quick reference
+
+**Add unsigned types to existing dispatch:**
+```cpp
+// Before
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES));
+
+// After (method 1: add unsigned types explicitly)
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+
+// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
+```
+
+## Type group reference
+
+**Unsigned type groups:**
+- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
+- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
+
+**Relationship:**
+```cpp
+AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
+AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
+AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
+```
+
+## Instructions
+
+### Step 1: Determine if conversion to V2 is needed
+
+Check if the file uses AT_DISPATCH_V2:
+
+**If using old AT_DISPATCH:**
+- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
+- Then proceed with adding uint support
+
+**If already using AT_DISPATCH_V2:**
+- Proceed directly to Step 2
+
+### Step 2: Analyze the current dispatch macro
+
+Identify what type groups are currently in use:
+
+```cpp
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  // body
+}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
+    ^^^^^^^^^^^^^^^^^^^^^^^^^
+    Current type coverage
+```
+
+Common patterns:
+- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
+- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only
+- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types
+
+### Step 3: Choose the uint addition method
+
+Two approaches:
+
+**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
+- Use when: You want to be explicit about adding uint support
+- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list
+
+**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
+- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
+- More concise: replaces one type group with its superset
+- Only applicable if AT_INTEGRAL_TYPES is present
+
+### Step 4: Apply the transformation
+
+**Method 1 example:**
+```cpp
+// Before
+AT_DISPATCH_V2(
+  dtype,
+  "min_values_cuda",
+  AT_WRAP([&]() {
+    kernel_impl(iter);
+  }),
+  AT_EXPAND(AT_ALL_TYPES),
+  kBFloat16, kHalf, kBool
+);
+
+// After (add unsigned types)
+AT_DISPATCH_V2(
+  dtype,
+  "min_values_cuda",
+  AT_WRAP([&]() {
+    kernel_impl(iter);
+  }),
+  AT_EXPAND(AT_ALL_TYPES),
+  AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+  kBFloat16, kHalf, kBool
+);
+```
+
+**Method 2 example:**
+```cpp
+// Before
+AT_DISPATCH_V2(
+  dtype,
+  "integral_op",
+  AT_WRAP([&]() {
+    kernel();
+  }),
+  AT_EXPAND(AT_INTEGRAL_TYPES)
+);
+
+// After (substitute with V2)
+AT_DISPATCH_V2(
+  dtype,
+  "integral_op",
+  AT_WRAP([&]() {
+    kernel();
+  }),
+  AT_EXPAND(AT_INTEGRAL_TYPES_V2)
+);
+```
+
+### Step 5: Handle AT_ALL_TYPES vs individual type groups
+
+If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
+- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
+- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list
+
+If the dispatch separately lists INTEGRAL and FLOATING:
+```cpp
+// Before
+AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
+
+// After (Method 2 preferred)
+AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
+```
+
+### Step 6: Verify all dispatch sites
+
+Check the file for ALL dispatch macros that need uint support:
+- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
+- Apply the transformation consistently across all sites
+- Ensure each gets the same type coverage updates
+
+### Step 7: Validate the changes
+
+Check that:
+- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
+- [ ] Unsigned types are added via one of the two methods
+- [ ] All relevant dispatch sites in the file are updated
+- [ ] Type groups use `AT_EXPAND()`
+- [ ] Arguments are properly formatted and comma-separated
+
+## Common patterns
+
+### Pattern 1: AT_ALL_TYPES + extras
+
+```cpp
+// Before
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
+
+// After
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
+```
+
+### Pattern 2: Separate INTEGRAL + FLOATING
+
+```cpp
+// Before
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
+
+// After
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
+```
+
+### Pattern 3: Old dispatch needs conversion first
+
+```cpp
+// Before (needs v2 conversion first)
+AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
+  kernel();
+});
+
+// After v2 conversion
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
+
+// After adding uint support
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
+```
+
+## Multiple dispatch sites example
+
+For a file with multiple functions:
+
+```cpp
+void min_values_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
+    impl(iter);
+  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
+  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  //                           Added uint support
+}
+
+void min_launch_kernel(TensorIterator &iter) {
+  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
+    gpu_reduce_kernel(iter);
+  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
+  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  //                           Added uint support here too
+}
+```
+
+## Decision tree
+
+Use this decision tree to determine the approach:
+
+```
+Is the file using AT_DISPATCH_V2?
+├─ No → Use at-dispatch-v2 skill first, then continue
+└─ Yes
+   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
+      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
+      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
+```
+
+## Edge cases
+
+### Case 1: Dispatch with only floating types
+
+If the operator only supports floating point types, don't add uint support:
+
+```cpp
+// Leave as-is - floating point only operator
+AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
+```
+
+### Case 2: Complex types present
+
+Unsigned types work alongside complex types:
+
+```cpp
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES),
+    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+    AT_EXPAND(AT_COMPLEX_TYPES),
+    kHalf, kBFloat16);
+```
+
+### Case 3: Already has uint support
+
+Check if uint types are already present:
+- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
+- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
+- Skip the file if uint support is already present
+
+## Workflow
+
+When asked to add uint support:
+
+1. Read the target file
+2. Check if using AT_DISPATCH_V2:
+   - If not → use at-dispatch-v2 skill first
+3. Identify all dispatch macro sites
+4. For each dispatch:
+   - Analyze current type groups
+   - Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
+   - Apply transformation with Edit tool
+5. Show the user the changes
+6. Explain what was modified
+
+## Important notes
+
+- Always check if v2 conversion is needed first
+- Apply changes consistently across all dispatch sites in the file
+- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
+- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
+- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
+- Some operators may not semantically support unsigned types - use judgment
+
+## Testing
+
+After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.

diff --git a/.claude/skills/at-dispatch-v2/SKILL.md b/.claude/skills/at-dispatch-v2/SKILL.md
new file mode 100644
index 00000000000..eb9946c1d03
--- /dev/null
+++ b/.claude/skills/at-dispatch-v2/SKILL.md
@@ -0,0 +1,305 @@
+---
+name: at-dispatch-v2
+description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
+---
+
+# AT_DISPATCH to AT_DISPATCH_V2 Converter
+
+This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.
+
+## When to use this skill
+
+Use this skill when:
+- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
+- Porting ATen kernels to use the new dispatch API
+- Working with files in `aten/src/ATen/native/` that use dispatch macros
+- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
+
+## Quick reference
+
+**Old format:**
+```cpp
+AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
+  // lambda body
+});
+```
+
+**New format:**
+```cpp
+AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
+  // lambda body
+}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
+```
+
+## Key transformations
+
+1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
+2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
+3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
+4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
+5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
+
+## Instructions
+
+### Step 1: Add the Dispatch_v2.h include
+
+Add the v2 header near the existing `#include <ATen/Dispatch.h>`:
+
+```cpp
+#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
+```
+
+Keep the old Dispatch.h include for now (other code may still need it).
+
+### Step 2: Identify the old dispatch pattern
+
+Common patterns to convert:
+
+- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
+- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
+- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
+- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
+
+### Step 3: Map the old macro to type groups
+
+Identify which type group macro corresponds to the base types:
+
+| Old macro base | AT_DISPATCH_V2 type group |
+|----------------|---------------------------|
+| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
+| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
+| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
+| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
+| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
+
+For combined patterns, use multiple `AT_EXPAND()` entries:
+```cpp
+// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
+// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
+```
+
+### Step 4: Extract the individual types
+
+From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).
+
+These become the trailing arguments after the type group:
+```cpp
+AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
+                                             ^^^^^^^^^^^^^^^^^^^^^^^
+                                             Individual types from AND3
+```
+
+### Step 5: Transform to AT_DISPATCH_V2
+
+Apply the transformation:
+
+**Pattern:**
+```cpp
+AT_DISPATCH_V2(
+  scalar_type,       // 1st: The dtype expression
+  "name",            // 2nd: The debug string
+  AT_WRAP(lambda),   // 3rd: The lambda wrapped in AT_WRAP
+  type_groups,       // 4th+: Type groups with AT_EXPAND()
+  individual_types   // Last: Individual types
+)
+```
+
+**Example transformation:**
+```cpp
+// BEFORE
+AT_DISPATCH_ALL_TYPES_AND3(
+  kBFloat16, kHalf, kBool,
+  iter.dtype(),
+  "min_values_cuda",
+  [&]() {
+    min_values_kernel_cuda_impl<scalar_t>(iter);
+  }
+);
+
+// AFTER
+AT_DISPATCH_V2(
+  iter.dtype(),
+  "min_values_cuda",
+  AT_WRAP([&]() {
+    min_values_kernel_cuda_impl<scalar_t>(iter);
+  }),
+  AT_EXPAND(AT_ALL_TYPES),
+  kBFloat16, kHalf, kBool
+);
+```
+
+### Step 6: Handle multi-line lambdas
+
+For lambdas with internal commas or complex expressions, AT_WRAP is essential:
+
+```cpp
+AT_DISPATCH_V2(
+  dtype,
+  "complex_kernel",
+  AT_WRAP([&]() {
+    gpu_reduce_kernel<scalar_t, scalar_t>(
+      iter,
+      MinOps<scalar_t>{},
+      thrust::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), 0)  // Commas inside!
+ ); + }), + AT_EXPAND(AT_ALL_TYPES) +); +``` + +### Step 7: Verify the conversion + +Check that: +- [ ] `AT_WRAP()` wraps the entire lambda +- [ ] Type groups use `AT_EXPAND()` +- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`) +- [ ] Argument order is: scalar_type, name, lambda, types +- [ ] Include added: `#include ` + +## Type group reference + +Available type group macros (use with `AT_EXPAND()`): + +```cpp +AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort +AT_FLOATING_TYPES // kDouble, kFloat +AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat +AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32 +AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES +AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES +AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types +AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64 +AT_FLOAT8_TYPES // Float8 variants +``` + +## Common patterns + +### Pattern: AT_DISPATCH_ALL_TYPES_AND2 + +```cpp +// Before +AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() { + kernel(data); +}); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(data); +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16); +``` + +### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3 + +```cpp +// Before +AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn, + tensor.scalar_type(), "float_op", [&] { + process(tensor); +}); + +// After +AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] { + process(tensor); +}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn); +``` + +### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2 + +```cpp +// Before +AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kComplexHalf, kHalf, + self.scalar_type(), + "complex_op", + [&] { + result = compute(self); + } +); + +// After +AT_DISPATCH_V2( + self.scalar_type(), + "complex_op", + AT_WRAP([&] { + result = compute(self); + }), + AT_EXPAND(AT_ALL_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + kComplexHalf, + kHalf 
+);
+```
+
+## Edge cases
+
+### Case 1: No extra types (rare)
+
+```cpp
+// Before
+AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel(); });
+
+// After
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_ALL_TYPES));
+```
+
+### Case 2: Many individual types (AND4, AND5, etc.)
+
+```cpp
+// Before
+AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
+    dtype, "float8_op", [&]() { kernel(); });
+
+// After
+AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
+  kernel();
+}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
+```
+
+### Case 3: Lambda with no captures
+
+```cpp
+// Before
+AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
+  static_kernel();
+});
+
+// After
+AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
+  static_kernel();
+}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
+```
+
+## Benefits of AT_DISPATCH_V2
+
+1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
+2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
+3. **Extensible**: Easy to add more types without hitting macro limits
+4. **Clearer**: Type groups are explicit, not implicit in macro name
+
+## Important notes
+
+- Keep `#include <ATen/Dispatch.h>` - other code may need it
+- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda
+- Type groups need `AT_EXPAND()`, individual types don't
+- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
+- See the header file for the Python script to regenerate the macro implementation
+
+## Workflow
+
+When asked to convert AT_DISPATCH macros:
+
+1. Read the file to identify all AT_DISPATCH uses
+2. Add `#include <ATen/Dispatch_v2.h>` if not present
+3. For each dispatch macro:
+   - Identify the pattern and extract components
+   - Map the base type group
+   - Extract individual types
+   - Construct the AT_DISPATCH_V2 call
+   - Apply with Edit tool
+4. Show the user the complete converted file
+5. Explain what was changed
+
+Do NOT compile or test the code - focus on accurate conversion only.
diff --git a/.gitignore b/.gitignore
index e13973e86c2..d1b3b17445d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -398,3 +398,4 @@ CLAUDE.local.md
 /test_*.py
 /debug_*.py
 CLAUDE_CONTEXT/
+/.claude/settings.local.json