Add interpreter for CollectiveBroadcastOp
This mirrors the upstream PR: https://github.com/openxla/stablehlo/pull/1983

PiperOrigin-RevId: 605152611
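For context, the semantics added to the interpreter can be summarized as follows: every process that appears in a replica group receives the operand held by the group's first ("root") process, and a process that appears in no group receives a zero-filled tensor of the operand's type. The sketch below is a minimal, standalone illustration of that rule only; the Tensor, ProcessId, and ReplicaGroups aliases and the findGroup/collectiveBroadcast helpers are hypothetical stand-ins, not the StableHLO reference API, and the real evalCollectiveBroadcastOp in the patch exchanges values through a rendezvous between processes rather than reading a shared operand table.

```cpp
// Hypothetical, self-contained sketch of collective_broadcast semantics.
// Names and types are illustrative only; they do not come from StableHLO.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Tensor = std::vector<int64_t>;  // flattened tensor values (illustrative)
using ProcessId = uint32_t;
using ReplicaGroups = std::vector<std::vector<ProcessId>>;

// Analogous to ProcessGroups::findGroup in the patch: return the group that
// contains `id`, if any.
std::optional<std::vector<ProcessId>> findGroup(const ReplicaGroups &groups,
                                                ProcessId id) {
  for (const auto &group : groups)
    for (ProcessId member : group)
      if (member == id) return group;
  return std::nullopt;
}

// Result of collective_broadcast for process `id`, given every process's operand.
Tensor collectiveBroadcast(const ReplicaGroups &groups, ProcessId id,
                           const std::vector<Tensor> &operands) {
  if (auto group = findGroup(groups, id))
    return operands[(*group)[0]];         // value held by the group's root
  return Tensor(operands[id].size(), 0);  // not in any group: zero-filled
}

int main() {
  // Mirrors the @cross_replica test data: 4 processes, replica_groups = [[2, 1]].
  ReplicaGroups groups = {{2, 1}};
  std::vector<Tensor> operands = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
  for (ProcessId id = 0; id < operands.size(); ++id) {
    Tensor result = collectiveBroadcast(groups, id, operands);
    std::cout << "process " << id << ": [" << result[0] << ", " << result[1]
              << "]\n";
  }
  // Expected output: [0, 0], [5, 6], [5, 6], [0, 0] for processes 0..3.
  return 0;
}
```

The MLIR tests added in the patch exercise exactly this behavior and are run through the reference interpreter via the RUN line in the new test file (stablehlo-translate --interpret -split-input-file).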
Committed by: TensorFlower Gardener
Parent: c0d9032ac0
Commit: ace8c3d12d
third_party/stablehlo/temporary.patch (vendored): 324 changed lines
@@ -163,6 +163,18 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt

#-------------------------------------------------------------------------------
# Directory setup
diff --ruN a/stablehlo/docs/status.md b/stablehlo/docs/status.md
--- stablehlo/docs/status.md
+++ stablehlo/docs/status.md
@@ -61,7 +61,7 @@
| ceil | yes | yes | yes | yes | yes |
| cholesky | yes | yes | yes | yes | revisit |
| clamp | yes | revisit | yes | yes | yes |
-| collective_broadcast | yes | revisit | yes | no | no |
+| collective_broadcast | yes | revisit | yes | no | yes |
| collective_permute | yes | revisit | yes | no | yes |
| compare | yes | yes | yes | yes | yes |
| complex | yes | yes | yes | yes | yes |
diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt
--- stablehlo/stablehlo/CMakeLists.txt
+++ stablehlo/stablehlo/CMakeLists.txt
@@ -2548,4 +2560,316 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/reference/Ops.cpp
--- stablehlo/stablehlo/reference/Ops.cpp
+++ stablehlo/stablehlo/reference/Ops.cpp
@@ -328,6 +328,25 @@
    auto operand = scope.findTensor(clzOp.getOperand());
    auto result = evalClzOp(operand, clzOp.getType());
    scope.add(clzOp.getResult(), result);
+  } else if (auto collectiveBroadcastOp =
+                 dyn_cast<CollectiveBroadcastOp>(op)) {
+    auto operand = scope.findTensor(collectiveBroadcastOp.getOperand());
+
+    auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups();
+    auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
+    SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
+    auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
+    for (auto &replicaGroup : replicaGroups)
+      for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
+        replicaGroup.push_back(*replicaGroupsIt);
+
+    ChannelId channelId = 0;
+    if (auto channelHandle = collectiveBroadcastOp.getChannelHandle())
+      channelId = channelHandle->getHandle();
+
+    auto result =
+        evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process);
+    scope.add(collectiveBroadcastOp.getResult(), result);
  } else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) {
    auto operand = scope.findTensor(collectivePermuteOp.getOperand());

@@ -1074,6 +1093,28 @@
  return result;
}

+Tensor evalCollectiveBroadcastOp(
+    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
+    ChannelId channelId, Process *process) {
+  if (!process)
+    llvm::report_fatal_error(
+        "collective_broadcast is only supported when run via "
+        "interpreter.run_parallel");
+
+  ProcessGroups processGroups;
+  if (channelId <= 0) processGroups = process->crossReplica(replicaGroups);
+  if (channelId > 0) processGroups = process->crossPartition(replicaGroups);
+
+  auto processGroup = processGroups.findGroup(process->getId());
+  if (processGroup)
+    return process->rendezvous(*processGroup, channelId, operand)
+        .lookup((*processGroup)[0]);
+
+  return evalBroadcastInDimOp(
+      makeScalar(convert(operand.getElementType(), 0.0)), {},
+      operand.getType());
+}
+
Tensor evalCollectivePermuteOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
    ChannelId channelId, Process *process) {
diff --ruN a/stablehlo/stablehlo/reference/Ops.h b/stablehlo/stablehlo/reference/Ops.h
--- stablehlo/stablehlo/reference/Ops.h
+++ stablehlo/stablehlo/reference/Ops.h
@@ -62,6 +62,9 @@
Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max,
                   ShapedType resultType);
Tensor evalClzOp(const Tensor &operand, ShapedType resultType);
+Tensor evalCollectiveBroadcastOp(
+    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
+    ChannelId channelId, Process *process);
Tensor evalCollectivePermuteOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
    ChannelId channelId, Process *process);
diff --ruN a/stablehlo/stablehlo/reference/ProcessGrid.cpp b/stablehlo/stablehlo/reference/ProcessGrid.cpp
--- stablehlo/stablehlo/reference/ProcessGrid.cpp
+++ stablehlo/stablehlo/reference/ProcessGrid.cpp
@@ -49,8 +49,8 @@

std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
  for (auto processGroup : *this)
-    for (auto id : processGroup)
-      if (id == processId) return processGroup;
+    if (llvm::find(processGroup, processId) != processGroup.end())
+      return processGroup;

  return std::nullopt;
}
diff --ruN a/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir b/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
--- stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+++ stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
@@ -0,0 +1,223 @@
+// RUN: stablehlo-translate --interpret -split-input-file %s
+
+module @cross_replica {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast], [@collective_broadcast],
+                [@collective_broadcast], [@collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_replica_multiple_output {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast], [@collective_broadcast],
+                [@collective_broadcast], [@collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_replica_single_replica {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast, @collective_broadcast,
+                 @collective_broadcast, @collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_replica_multiple_partitions {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast, @collective_broadcast],
+                [@collective_broadcast, @collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_partition {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast, @collective_broadcast,
+                 @collective_broadcast, @collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_partition_multiple_output {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast, @collective_broadcast,
+                 @collective_broadcast, @collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_partition_single_partition {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast], [@collective_broadcast],
+                [@collective_broadcast], [@collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+    func.return
+  }
+}
+
+// -----
+
+module @cross_partition_multiple_replicas {
+  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+    %result = "stablehlo.collective_broadcast"(%operand) {
+      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+    return %result : tensor<1x2xi64>
+  }
+  func.func @main() {
+    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+      programs=[[@collective_broadcast, @collective_broadcast],
+                [@collective_broadcast, @collective_broadcast]]
+    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+    check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64>
+    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+    func.return
+  }
+}