diff --git a/tools/experimental/torchfuzz/codegen.py b/tools/experimental/torchfuzz/codegen.py
index 3913e34b88c..da41392623b 100644
--- a/tools/experimental/torchfuzz/codegen.py
+++ b/tools/experimental/torchfuzz/codegen.py
@@ -501,6 +501,133 @@ class UnbackedFuzzTemplate(FuzzTemplate):
         return []
 
 
+class DTensorFuzzPlacementsTemplate(DTensorFuzzTemplate):
+    """DTensor template with randomized placements (Replicate, Shard, Partial).
+
+    Extends DTensorFuzzTemplate to randomize placement strategies instead of
+    using fixed (Replicate(), Replicate()) for all tensors.
+    """
+
+    def fuzz_spec_custom(self):
+        """Generate tensor specs with minimum 1 dimension for proper DTensor sharding."""
+        import random
+
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        # Get random dtype
+        dtype = random.choice(self.supported_dtypes())
+
+        # Generate tensor size with minimum 1 dimension (avoid 0-dim scalars)
+        # Prefer 2D-3D tensors for interesting sharding patterns
+        ndim = random.choices([1, 2, 3, 4], weights=[0.1, 0.5, 0.3, 0.1])[0]
+        size = tuple(random.randint(2, 32) for _ in range(ndim))
+        stride = fuzz_valid_stride(size)
+
+        from torchfuzz.tensor_fuzzer import TensorSpec
+
+        return TensorSpec(size=size, stride=stride, dtype=dtype)
+
+    def imports_codegen(self):
+        """Add Partial to imports."""
+        base_imports = super().imports_codegen()
+        # Update the placement imports to include Partial
+        for i, imp in enumerate(base_imports):
+            if "placement_types import" in imp:
+                base_imports[i] = (
+                    "from torch.distributed.tensor.placement_types import Replicate, Shard, Partial"
+                )
+                break
+        base_imports.append("import torch.distributed.tensor as dist_tensor")
+        return base_imports
+
+    def _generate_random_placement(self, tensor_size):
+        """Generate random placement tuple (Replicate, Shard, or Partial)."""
+        import random
+
+        placements = []
+        for _ in range(2):  # 2D mesh
+            placement_type = random.randint(0, 2)
+            if placement_type == 0:
+                placements.append("Replicate()")
+            elif placement_type == 1 and len(tensor_size) > 0:
+                shard_dim = random.randint(0, len(tensor_size) - 1)
+                placements.append(f"Shard({shard_dim})")
+            else:
+                placements.append("Partial()" if placement_type == 2 else "Replicate()")
+        return f"({', '.join(placements)})"
+
+    def args_codegen(self, arg_operations, constant_operations=None):
+        """Generate args with randomized placements using dist_tensor API."""
+
+        code_lines = []
+
+        # DTensor setup (same as parent)
+        code_lines.extend(
+            [
+                "world_size = 1024",
+                "fake_store = FakeStore()",
+                "torch.distributed.init_process_group(",
+                '    "fake", store=fake_store, rank=0, world_size=world_size',
+                ")",
+                "",
+                "mesh = torch.distributed.device_mesh.init_device_mesh(",
+                '    "cuda", (2, 8), mesh_dim_names=("dim1", "dim2")',
+                ")",
+                "",
+            ]
+        )
+
+        # Sentinel with random placement
+        sentinel_placements = self._generate_random_placement((1,))
+        code_lines.extend(
+            [
+                f"sentinel = dist_tensor.ones((1,), device_mesh=mesh, placements={sentinel_placements}, dtype=torch.float32, requires_grad=True)",
+                "",
+            ]
+        )
+
+        # Args with random placements using dist_tensor API
+        if arg_operations:
+            for i, (node_id, spec) in enumerate(arg_operations):
+                if isinstance(spec, TensorSpec):
+                    size_str = str(spec.size)
+                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
+                    placements = self._generate_random_placement(spec.size)
+
+                    if spec.dtype in [
+                        torch.int32,
+                        torch.int64,
+                        torch.int8,
+                        torch.int16,
+                    ]:
+                        code_lines.append(
+                            f"arg_{i} = dist_tensor.ones({size_str}, device_mesh=mesh, placements={placements}, dtype={dtype_str}) * 5"
+                        )
+                    elif spec.dtype == torch.bool:
+                        code_lines.append(
+                            f"arg_{i} = dist_tensor.ones({size_str}, device_mesh=mesh, placements={placements}, dtype=torch.int8).bool()"
+                        )
+                    else:
+                        code_lines.append(
+                            f"arg_{i} = dist_tensor.randn({size_str}, device_mesh=mesh, placements={placements}, dtype={dtype_str}, requires_grad=True)"
+                        )
+
+        # Constants (if any) - use same dist_tensor approach
+        if constant_operations:
+            for node_id, var_name, spec in constant_operations:
+                if isinstance(spec, TensorSpec):
+                    size_str = str(spec.size)
+                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
+                    placements = self._generate_random_placement(spec.size)
+                    # Use dist_tensor.full with a simple fill value
+                    code_lines.append(
+                        f"{var_name} = dist_tensor.full({size_str}, 1.0, device_mesh=mesh, placements={placements}, dtype={dtype_str})"
+                    )
+
+        code_lines.append("")
+        return code_lines
+
+
 def convert_graph_to_python_code(
     operation_graph: OperationGraph,
     seed: int | None = None,
@@ -525,6 +652,8 @@ def convert_graph_to_python_code(
     # Instantiate template
     if template == "dtensor":
         fuzz_template = DTensorFuzzTemplate()
+    elif template == "dtensor_placements":
+        fuzz_template = DTensorFuzzPlacementsTemplate()
     elif template == "unbacked":
         fuzz_template = UnbackedFuzzTemplate()
     else:
@@ -543,12 +672,15 @@ def convert_graph_to_python_code(
     # Get topological order - this ensures dependencies are processed before dependents
     topo_order = operation_graph.get_topological_order()
 
-    # Track generated variables and arg operations
+    # Track generated variables, arg operations, and constant operations
     generated_code_lines = []
     node_variables: dict[str, tuple[str, Spec]] = {}  # Maps node_id to (var_name, spec)
     arg_operations: list[
         tuple[str, Spec]
     ] = []  # List of (node_id, spec) for arg operations
+    constant_operations: list[
+        tuple[str, str, Spec]
+    ] = []  # List of (node_id, var_name, spec) for constant operations (DTensor templates only)
 
     # Process nodes in topological order
     for node_id in topo_order:
@@ -579,6 +711,13 @@ def convert_graph_to_python_code(
             # Add tensor descriptor comment for arg operations too
             descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
             operation_lines = [f"{output_var_name} = {arg_name} " + descriptor_comment]
+        elif op_name == "constant" and template == "dtensor_placements":
+            # For DTensor placements template, track constants to create them outside the function
+            constant_operations.append((node_id, output_var_name, output_spec))
+            descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
+            operation_lines = [
+                f"{output_var_name} = {output_var_name} " + descriptor_comment
+            ]
         else:
             # Generate operation execution code
             operation_lines = generate_simple_operation_code(
@@ -598,12 +737,15 @@ def convert_graph_to_python_code(
     final_var_name, _ = node_variables[root_node_id]
 
-    # Generate function signature based on discovered arg operations
+    # Generate function signature based on discovered arg and constant operations
+    param_names = []
     if arg_operations:
-        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
-        function_signature = f"def fuzzed_program({', '.join(arg_names)}, sentinel)"
-    else:
-        function_signature = "def fuzzed_program(sentinel)"
+        param_names.extend([f"arg_{i}" for i in range(len(arg_operations))])
+    if template == "dtensor_placements" and constant_operations:
+        param_names.extend([var_name for _, var_name, _ in constant_operations])
+    param_names.append("sentinel")
+
+    function_signature = f"def fuzzed_program({', '.join(param_names)})"
 
     # Build the complete code - all imports at the top
     code_lines = []
 
@@ -627,7 +769,7 @@ def convert_graph_to_python_code(
 
     # Add return statement with sentinel multiplication to ensure gradient computation
     # Handle complex tensors appropriately based on template
-    if template == "dtensor":
+    if template in ["dtensor", "dtensor_placements"]:
        # For DTensor, avoid .real operation which doesn't work with sharding
        # Instead use abs() for complex tensors to get a real result
        code_lines.extend(
@@ -653,23 +795,31 @@ def convert_graph_to_python_code(
     )
 
     # Generate argument creation code using template
-    arg_code_lines = fuzz_template.args_codegen(arg_operations)
-    code_lines.extend(arg_code_lines)
+    if template == "dtensor_placements" and hasattr(fuzz_template, "args_codegen"):
+        # For dtensor_placements, pass constants to args_codegen which handles both
+        arg_code_lines = fuzz_template.args_codegen(arg_operations, constant_operations)
+        code_lines.extend(arg_code_lines)
+    else:
+        arg_code_lines = fuzz_template.args_codegen(arg_operations)
+        code_lines.extend(arg_code_lines)
 
     # Generate the final execution with both normal and compiled versions
+    param_values = []
     if arg_operations:
-        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
-        if len(arg_names) == 1:
-            args_tuple = (
-                f"({arg_names[0]},)"  # Single element tuple needs trailing comma
-            )
-        else:
-            args_tuple = f"({', '.join(arg_names)})"
+        param_values.extend([f"arg_{i}" for i in range(len(arg_operations))])
+    if template == "dtensor_placements" and constant_operations:
+        param_values.extend([var_name for _, var_name, _ in constant_operations])
+    param_values.append("sentinel")
+
+    if len(param_values) == 1:
+        args_tuple = (
+            f"({param_values[0]},)"  # Single element tuple needs trailing comma
+        )
     else:
-        args_tuple = "()"
+        args_tuple = f"({', '.join(param_values)})"
 
     # Generate execution code using template check
-    check_lines = fuzz_template.check.codegen(f"{args_tuple} + (sentinel,)")
+    check_lines = fuzz_template.check.codegen(args_tuple)
     code_lines.extend([""] + check_lines)
 
     # Add template epilogue
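For orientation, here is a sketch (not part of the patch) of the argument-setup code that DTensorFuzzPlacementsTemplate.args_codegen above would emit for one float32 argument. The shape and the sampled placements are invented for illustration and change on every run; the import lines are assumed to come from imports_codegen (this diff only extends its placement import line), and the FakeStore import path shown is an assumption.

    import torch
    import torch.distributed.tensor as dist_tensor
    from torch.distributed.tensor.placement_types import Replicate, Shard
    from torch.testing._internal.distributed.fake_pg import FakeStore  # assumed import path

    world_size = 1024
    fake_store = FakeStore()
    torch.distributed.init_process_group(
        "fake", store=fake_store, rank=0, world_size=world_size
    )

    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda", (2, 8), mesh_dim_names=("dim1", "dim2")
    )

    # One possible random draw; Partial() can also be sampled for either mesh dim.
    sentinel = dist_tensor.ones((1,), device_mesh=mesh, placements=(Replicate(), Replicate()), dtype=torch.float32, requires_grad=True)
    arg_0 = dist_tensor.randn((8, 16), device_mesh=mesh, placements=(Shard(1), Replicate()), dtype=torch.float32, requires_grad=True)
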
diff --git a/tools/experimental/torchfuzz/fuzzer.py b/tools/experimental/torchfuzz/fuzzer.py
index 50a00853f0a..e1e1be477ea 100644
--- a/tools/experimental/torchfuzz/fuzzer.py
+++ b/tools/experimental/torchfuzz/fuzzer.py
@@ -262,7 +262,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--template",
-        choices=["default", "dtensor", "unbacked"],
+        choices=["default", "dtensor", "dtensor_placements", "unbacked"],
         default="default",
         help="Template to use for code generation (default: default)",
     )
diff --git a/tools/experimental/torchfuzz/operators/constant.py b/tools/experimental/torchfuzz/operators/constant.py
index 67419672c2a..18a35911463 100644
--- a/tools/experimental/torchfuzz/operators/constant.py
+++ b/tools/experimental/torchfuzz/operators/constant.py
@@ -116,12 +116,17 @@ class ConstantOperator(Operator):
                 f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
             )
 
-        # For DTensor template, convert to DTensor
-        if self.template == "dtensor":
-            return (
-                f"{output_name}_local = {tensor_creation}.to('cuda')\n"
-                f"    {output_name} = DTensor.from_local({output_name}_local, mesh, placements)"
-            )
+        # For DTensor templates, constants are created outside the function
+        if self.template in ["dtensor", "dtensor_placements"]:
+            # For dtensor_placements, constants are handled in args_codegen
+            # For dtensor, use the global placements variable
+            if self.template == "dtensor_placements":
+                return f"# {output_name} is created globally"
+            else:
+                return (
+                    f"{output_name}_local = {tensor_creation}.to('cuda')\n"
+                    f"{output_name} = DTensor.from_local({output_name}_local, mesh, placements)"
+                )
         else:
             return f"{output_name} = {tensor_creation}"
 
diff --git a/tools/experimental/torchfuzz/ops_fuzzer.py b/tools/experimental/torchfuzz/ops_fuzzer.py
index dda3dc6efcf..9bea882bc51 100644
--- a/tools/experimental/torchfuzz/ops_fuzzer.py
+++ b/tools/experimental/torchfuzz/ops_fuzzer.py
@@ -43,6 +43,10 @@ def _get_template_filtered_operators(
         from torchfuzz.codegen import DTensorFuzzTemplate
 
         fuzz_template = DTensorFuzzTemplate()
+    elif template == "dtensor_placements":
+        from torchfuzz.codegen import DTensorFuzzPlacementsTemplate
+
+        fuzz_template = DTensorFuzzPlacementsTemplate()
     elif template == "unbacked":
         from torchfuzz.codegen import UnbackedFuzzTemplate
 
@@ -240,6 +244,10 @@ def fuzz_spec(template: str = "default") -> Spec:
         from torchfuzz.codegen import DTensorFuzzTemplate
 
         fuzz_template = DTensorFuzzTemplate()
+    elif template == "dtensor_placements":
+        from torchfuzz.codegen import DTensorFuzzPlacementsTemplate
+
+        fuzz_template = DTensorFuzzPlacementsTemplate()
     elif template == "unbacked":
         from torchfuzz.codegen import UnbackedFuzzTemplate
 
diff --git a/tools/experimental/torchfuzz/tensor_fuzzer.py b/tools/experimental/torchfuzz/tensor_fuzzer.py
index 3ff71a03c2c..96d2990745f 100644
--- a/tools/experimental/torchfuzz/tensor_fuzzer.py
+++ b/tools/experimental/torchfuzz/tensor_fuzzer.py
@@ -52,6 +52,12 @@ def fuzz_torch_tensor_type(template: str = "default") -> torch.dtype:
 
         fuzz_template = DTensorFuzzTemplate()
         tensor_dtypes = fuzz_template.supported_dtypes()
+    elif template == "dtensor_placements":
+        # Import here to avoid circular imports
+        from torchfuzz.codegen import DTensorFuzzPlacementsTemplate
+
+        fuzz_template = DTensorFuzzPlacementsTemplate()
+        tensor_dtypes = fuzz_template.supported_dtypes()
     elif template == "unbacked":
         # Import here to avoid circular imports
         from torchfuzz.codegen import UnbackedFuzzTemplate
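Usage note (not from the patch itself): the new template is exposed through the existing --template flag extended in fuzzer.py above, so a minimal invocation looks like:

    python tools/experimental/torchfuzz/fuzzer.py --template dtensor_placements

All other fuzzer.py flags are untouched by this diff.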