diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 80bfc3517bf..763df10ba51 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -163,6 +163,18 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
 #-------------------------------------------------------------------------------
 # Directory setup
 
+diff --ruN a/stablehlo/docs/status.md b/stablehlo/docs/status.md
+--- stablehlo/docs/status.md
++++ stablehlo/docs/status.md
+@@ -61,7 +61,7 @@
+ | ceil | yes | yes | yes | yes | yes |
+ | cholesky | yes | yes | yes | yes | revisit |
+ | clamp | yes | revisit | yes | yes | yes |
+-| collective_broadcast | yes | revisit | yes | no | no |
++| collective_broadcast | yes | revisit | yes | no | yes |
+ | collective_permute | yes | revisit | yes | no | yes |
+ | compare | yes | yes | yes | yes | yes |
+ | complex | yes | yes | yes | yes | yes |
 diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt
 --- stablehlo/stablehlo/CMakeLists.txt
 +++ stablehlo/stablehlo/CMakeLists.txt
@@ -2548,4 +2560,316 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
 +} // namespace experimental
 +} // namespace stablehlo
 +} // namespace mlir
+diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/reference/Ops.cpp
+--- stablehlo/stablehlo/reference/Ops.cpp
++++ stablehlo/stablehlo/reference/Ops.cpp
+@@ -328,6 +328,25 @@
+     auto operand = scope.findTensor(clzOp.getOperand());
+     auto result = evalClzOp(operand, clzOp.getType());
+     scope.add(clzOp.getResult(), result);
++  } else if (auto collectiveBroadcastOp =
++                 dyn_cast<CollectiveBroadcastOp>(op)) {
++    auto operand = scope.findTensor(collectiveBroadcastOp.getOperand());
++
++    auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups();
++    auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
++    SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
++    auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
++    for (auto &replicaGroup : replicaGroups)
++      for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
++        replicaGroup.push_back(*replicaGroupsIt);
++
++    ChannelId channelId = 0;
++    if (auto channelHandle = collectiveBroadcastOp.getChannelHandle())
++      channelId = channelHandle->getHandle();
++
++    auto result =
++        evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process);
++    scope.add(collectiveBroadcastOp.getResult(), result);
+   } else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) {
+     auto operand = scope.findTensor(collectivePermuteOp.getOperand());
+ 
+@@ -1074,6 +1093,28 @@
+   return result;
+ }
+ 
++Tensor evalCollectiveBroadcastOp(
++    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
++    ChannelId channelId, Process *process) {
++  if (!process)
++    llvm::report_fatal_error(
++        "collective_broadcast is only supported when run via "
++        "interpreter.run_parallel");
++
++  ProcessGroups processGroups;
++  if (channelId <= 0) processGroups = process->crossReplica(replicaGroups);
++  if (channelId > 0) processGroups = process->crossPartition(replicaGroups);
++
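++  // The first id listed in the matching group is the broadcast source;
++  // rendezvous exchanges every member's operand and each member returns
++  // the source's tensor.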
++  auto processGroup = processGroups.findGroup(process->getId());
++  if (processGroup)
++    return process->rendezvous(*processGroup, channelId, operand)
++        .lookup((*processGroup)[0]);
++
++  return evalBroadcastInDimOp(
++      makeScalar(convert(operand.getElementType(), 0.0)), {},
++      operand.getType());
++}
++
+ Tensor evalCollectivePermuteOp(
+     const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
+     ChannelId channelId, Process *process) {
+diff --ruN a/stablehlo/stablehlo/reference/Ops.h b/stablehlo/stablehlo/reference/Ops.h
+--- stablehlo/stablehlo/reference/Ops.h
++++ stablehlo/stablehlo/reference/Ops.h
+@@ -62,6 +62,9 @@
+ Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max,
+                    ShapedType resultType);
+ Tensor evalClzOp(const Tensor &operand, ShapedType resultType);
++Tensor evalCollectiveBroadcastOp(
++    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
++    ChannelId channelId, Process *process);
+ Tensor evalCollectivePermuteOp(
+     const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
+     ChannelId channelId, Process *process);
+diff --ruN a/stablehlo/stablehlo/reference/ProcessGrid.cpp b/stablehlo/stablehlo/reference/ProcessGrid.cpp
+--- stablehlo/stablehlo/reference/ProcessGrid.cpp
++++ stablehlo/stablehlo/reference/ProcessGrid.cpp
+@@ -49,8 +49,8 @@
+ 
+ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
+   for (auto processGroup : *this)
+-    for (auto id : processGroup)
+-      if (id == processId) return processGroup;
++    if (llvm::find(processGroup, processId) != processGroup.end())
++      return processGroup;
+ 
+   return std::nullopt;
+ }
+diff --ruN a/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir b/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+--- stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
++++ stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+@@ -0,0 +1,223 @@
++// RUN: stablehlo-translate --interpret -split-input-file %s
++
++module @cross_replica {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast], [@collective_broadcast],
++                [@collective_broadcast], [@collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_replica_multiple_output {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
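++    // Replica 2 heads group [2, 1, 0], so replicas 0-2 all receive its
++    // operand [[5, 6]]; replica 3 is in no group and yields zeros.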
"interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { ++ programs=[[@collective_broadcast], [@collective_broadcast], ++ [@collective_broadcast], [@collective_broadcast]] ++ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> ++ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ++ check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64> ++ func.return ++ } ++} ++ ++// ----- ++ ++module @cross_replica_single_replica { ++ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { ++ %result = "stablehlo.collective_broadcast"(%operand) { ++ replica_groups = dense<[[0]]> : tensor<1x1xi64>, ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor<1x2xi64>) -> tensor<1x2xi64> ++ return %result : tensor<1x2xi64> ++ } ++ func.func @main() { ++ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> ++ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> ++ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> ++ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> ++ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { ++ programs=[[@collective_broadcast, @collective_broadcast, ++ @collective_broadcast, @collective_broadcast]] ++ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> ++ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ++ check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> ++ func.return ++ } ++} ++ ++// ----- ++ ++module @cross_replica_multiple_partitions { ++ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { ++ %result = "stablehlo.collective_broadcast"(%operand) { ++ replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>, ++ channel_handle = #stablehlo.channel_handle ++ } : (tensor<1x2xi64>) -> tensor<1x2xi64> ++ return %result : tensor<1x2xi64> ++ } ++ func.func @main() { ++ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> ++ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> ++ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> ++ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> ++ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { ++ programs=[[@collective_broadcast, @collective_broadcast], ++ [@collective_broadcast, @collective_broadcast]] ++ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> ++ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ++ check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> ++ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> ++ func.return ++ } ++} ++ ++// ----- ++ ++module @cross_partition { ++ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { ++ %result = "stablehlo.collective_broadcast"(%operand) { ++ replica_groups = dense<[[2, 1]]> : 
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast,
++                 @collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_multiple_output {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast,
++                 @collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_single_partition {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast], [@collective_broadcast],
++                [@collective_broadcast], [@collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_multiple_replicas {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast],
++                [@collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}
diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch
index 80bfc3517bf..763df10ba51 100755
--- a/third_party/xla/third_party/stablehlo/temporary.patch
+++ b/third_party/xla/third_party/stablehlo/temporary.patch
@@ -163,6 +163,18 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
 #-------------------------------------------------------------------------------
 # Directory setup
 
+diff --ruN a/stablehlo/docs/status.md b/stablehlo/docs/status.md
+--- stablehlo/docs/status.md
++++ stablehlo/docs/status.md
+@@ -61,7 +61,7 @@
+ | ceil | yes | yes | yes | yes | yes |
+ | cholesky | yes | yes | yes | yes | revisit |
+ | clamp | yes | revisit | yes | yes | yes |
+-| collective_broadcast | yes | revisit | yes | no | no |
++| collective_broadcast | yes | revisit | yes | no | yes |
+ | collective_permute | yes | revisit | yes | no | yes |
+ | compare | yes | yes | yes | yes | yes |
+ | complex | yes | yes | yes | yes | yes |
 diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt
 --- stablehlo/stablehlo/CMakeLists.txt
 +++ stablehlo/stablehlo/CMakeLists.txt
@@ -2548,4 +2560,316 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
 +} // namespace experimental
 +} // namespace stablehlo
 +} // namespace mlir
+diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/reference/Ops.cpp
+--- stablehlo/stablehlo/reference/Ops.cpp
++++ stablehlo/stablehlo/reference/Ops.cpp
+@@ -328,6 +328,25 @@
+     auto operand = scope.findTensor(clzOp.getOperand());
+     auto result = evalClzOp(operand, clzOp.getType());
+     scope.add(clzOp.getResult(), result);
++  } else if (auto collectiveBroadcastOp =
++                 dyn_cast<CollectiveBroadcastOp>(op)) {
++    auto operand = scope.findTensor(collectiveBroadcastOp.getOperand());
++
++    auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups();
++    auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
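++    // replica_groups is a 2-D i64 attribute; flatten it into one vector of
++    // ids per group before querying the process grid.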
++    SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
++    auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
++    for (auto &replicaGroup : replicaGroups)
++      for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
++        replicaGroup.push_back(*replicaGroupsIt);
++
++    ChannelId channelId = 0;
++    if (auto channelHandle = collectiveBroadcastOp.getChannelHandle())
++      channelId = channelHandle->getHandle();
++
++    auto result =
++        evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process);
++    scope.add(collectiveBroadcastOp.getResult(), result);
+   } else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) {
+     auto operand = scope.findTensor(collectivePermuteOp.getOperand());
+ 
+@@ -1074,6 +1093,28 @@
+   return result;
+ }
+ 
++Tensor evalCollectiveBroadcastOp(
++    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
++    ChannelId channelId, Process *process) {
++  if (!process)
++    llvm::report_fatal_error(
++        "collective_broadcast is only supported when run via "
++        "interpreter.run_parallel");
++
++  ProcessGroups processGroups;
++  if (channelId <= 0) processGroups = process->crossReplica(replicaGroups);
++  if (channelId > 0) processGroups = process->crossPartition(replicaGroups);
++
++  auto processGroup = processGroups.findGroup(process->getId());
++  if (processGroup)
++    return process->rendezvous(*processGroup, channelId, operand)
++        .lookup((*processGroup)[0]);
++
++  return evalBroadcastInDimOp(
++      makeScalar(convert(operand.getElementType(), 0.0)), {},
++      operand.getType());
++}
++
+ Tensor evalCollectivePermuteOp(
+     const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
+     ChannelId channelId, Process *process) {
+diff --ruN a/stablehlo/stablehlo/reference/Ops.h b/stablehlo/stablehlo/reference/Ops.h
+--- stablehlo/stablehlo/reference/Ops.h
++++ stablehlo/stablehlo/reference/Ops.h
+@@ -62,6 +62,9 @@
+ Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max,
+                    ShapedType resultType);
+ Tensor evalClzOp(const Tensor &operand, ShapedType resultType);
++Tensor evalCollectiveBroadcastOp(
++    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
++    ChannelId channelId, Process *process);
+ Tensor evalCollectivePermuteOp(
+     const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
+     ChannelId channelId, Process *process);
+diff --ruN a/stablehlo/stablehlo/reference/ProcessGrid.cpp b/stablehlo/stablehlo/reference/ProcessGrid.cpp
+--- stablehlo/stablehlo/reference/ProcessGrid.cpp
++++ stablehlo/stablehlo/reference/ProcessGrid.cpp
+@@ -49,8 +49,8 @@
+ 
+ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
+   for (auto processGroup : *this)
+-    for (auto id : processGroup)
+-      if (id == processId) return processGroup;
++    if (llvm::find(processGroup, processId) != processGroup.end())
++      return processGroup;
+ 
+   return std::nullopt;
+ }
+diff --ruN a/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir b/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+--- stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
++++ stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+@@ -0,0 +1,223 @@
++// RUN: stablehlo-translate --interpret -split-input-file %s
++
++module @cross_replica {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
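++  // Four replicas run this program: replica 2 heads group [2, 1], so
++  // replicas 1 and 2 receive [[5, 6]] while the ungrouped replicas 0 and 3
++  // yield zeros.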
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast], [@collective_broadcast],
++                [@collective_broadcast], [@collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_replica_multiple_output {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast], [@collective_broadcast],
++                [@collective_broadcast], [@collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_replica_single_replica {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast,
++                 @collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_replica_multiple_partitions {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast],
++                [@collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast,
++                 @collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_multiple_output {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
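++    // One replica, four partitions: partition 2 heads group [2, 1, 0], so
++    // partitions 0-2 receive [[5, 6]] and the ungrouped partition 3 yields
++    // zeros.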
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast,
++                 @collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_single_partition {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast], [@collective_broadcast],
++                [@collective_broadcast], [@collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}
++
++// -----
++
++module @cross_partition_multiple_replicas {
++  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
++    %result = "stablehlo.collective_broadcast"(%operand) {
++      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
++      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
++    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
++    return %result : tensor<1x2xi64>
++  }
++  func.func @main() {
++    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
++    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
++    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
++    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
++    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
++      programs=[[@collective_broadcast, @collective_broadcast],
++                [@collective_broadcast, @collective_broadcast]]
++    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
++        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
++    check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64>
++    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
++    func.return
++  }
++}