#sdy support JAX callbacks through the Shardy XLA round-trip pipeline.

PiperOrigin-RevId: 713646485

Author: Bart Chrzaszcz
Date: 2025-01-09 06:13:29 -08:00
Committed by: TensorFlower Gardener
Parent: 8f94e73026
Commit: cf43bb53b5

17 changed files with 484 additions and 9 deletions

View File

@@ -88,6 +88,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/ir:register",
"@stablehlo//:stablehlo_ops",
],
)
@@ -119,6 +120,7 @@ xla_cc_binary(
deps = [
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls",
"//xla/service/spmd/shardy/mhlo_round_trip:export_ops",
"//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export",
@@ -132,6 +134,7 @@ xla_cc_binary(
"//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
"//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls",
"//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes",

View File

@@ -38,6 +38,14 @@ inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName =
inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName =
"SPMDShardToFullShape";
// The target name of the Python CPU callback custom call.
inline constexpr llvm::StringRef kPythonCpuCallbackCustomCallTargetName =
"xla_python_cpu_callback";
// The target name of the Python GPU callback custom call.
inline constexpr llvm::StringRef kPythonGpuCallbackCustomCallTargetName =
"xla_python_gpu_callback";
// The attribute name for backend config.
inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config";

View File

@@ -83,11 +83,28 @@ cc_library(
],
)
cc_library(
name = "export_callback_custom_calls",
srcs = ["export_callback_custom_calls.cc"],
hdrs = ["export_callback_custom_calls.h"],
deps = [
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)
cc_library(
name = "mhlo_export",
srcs = ["mhlo_export.cc"],
hdrs = ["mhlo_export.h"],
deps = [
":export_callback_custom_calls",
":export_ops",
":export_shardings",
":shard_map_export",

View File

@@ -0,0 +1,120 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
#include <memory>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
namespace xla {
namespace sdy {
namespace {
using ::mlir::ModuleOp;
using ::mlir::OperationPass;
using ::mlir::PassWrapper;
using ::mlir::StringRef;
using ::mlir::stablehlo::CustomCallOp;
// Attempts to replace the `CustomCallOp` with a tuple version of it, and a
// `GetTupleElementOp` that gets the first element of the tuple.
//
// This only happens if the op has a single result and the result type is not
// a tuple.
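//
// For example (a sketch mirroring this pass's tests; call attributes elided):
//
//   %r = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0)
//       : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
//
// becomes
//
//   %t = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0)
//       : (tensor<i64>, tensor<2xf64>) -> tuple<tensor<2xf64>>
//   %r = stablehlo.get_tuple_element %t[0]
//       : (tuple<tensor<2xf64>>) -> tensor<2xf64>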
void replaceCallbackWithTupleVersion(CustomCallOp customCall,
mlir::IRRewriter& rewriter) {
if (customCall.getNumResults() != 1 ||
mlir::isa<mlir::TupleType>(customCall->getResultTypes().front())) {
return;
}
CustomCallOp tupleCustomCall = cloneCustomCallWithNewResultTypes(
customCall,
mlir::TupleType::get(customCall->getContext(),
{customCall->getResultTypes()}),
rewriter);
auto getTupleElement = rewriter.create<mlir::stablehlo::GetTupleElementOp>(
customCall.getLoc(), customCall->getResultTypes().front(),
tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0));
getTupleElement->setAttr(kXlaShardingAttr,
customCall->getAttr(kXlaShardingAttr));
rewriter.replaceOp(customCall, getTupleElement);
}
class MhloRoundTripExportCallbackCustomCallsPass
: public PassWrapper<MhloRoundTripExportCallbackCustomCallsPass,
OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
MhloRoundTripExportCallbackCustomCallsPass)
void runOnOperation() final {
getOperation().walk([&](CustomCallOp customCall) {
if (!isPythonCallbackCustomCall(customCall)) {
return;
}
mlir::IRRewriter rewriter(customCall);
if (!customCall->use_empty()) {
replaceCallbackWithTupleVersion(customCall, rewriter);
return;
}
CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes(
customCall, mlir::TypeRange(), rewriter);
newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr({}));
rewriter.eraseOp(customCall);
});
}
StringRef getArgument() const override {
return "xla-sdy-mhlo-round-trip-export-callback-custom-calls";
}
StringRef getDescription() const override {
return "Converts the `CustomCallOp`s for host callbacks in XLA into the "
"pattern that the XLA compiler recognizes.";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<mlir::sdy::SdyDialect>();
}
};
} // namespace
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass() {
return std::make_unique<MhloRoundTripExportCallbackCustomCallsPass>();
}
void registerMhloRoundTripExportCallbackCustomCallsPass() {
mlir::registerPass(createMhloRoundTripExportCallbackCustomCallsPass);
}
} // namespace sdy
} // namespace xla

View File

@@ -0,0 +1,42 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
#include <memory>
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
namespace xla {
namespace sdy {
// Creates a pass that converts the `CustomCallOp`s for host callbacks in XLA
// into the pattern that the XLA compiler recognizes.
//
// The rest of the XLA pipeline expects host callback custom calls either to
// return a tuple that is read through a `get_tuple_element`, or to have no
// results. The import pass gave each callback a result (Shardy shardings
// expect at least one result so that a maximal sharding can be attached to
// the callbacks); this pass restores the form XLA expects.
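//
// For example (a sketch; see the pass's tests for the full attributes), a
// callback whose single result is used,
//   %r = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0)
//       : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
// is rewritten to a tuple-returning call followed by a
// `stablehlo.get_tuple_element`, while a callback whose result is unused is
// rewritten to return no results with `result_layouts = []`.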
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass();
// Registers the xla-sdy-mhlo-round-trip-export-callback-custom-calls pass.
void registerMhloRoundTripExportCallbackCustomCallsPass();
} // namespace sdy
} // namespace xla
#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_

View File

@@ -86,6 +86,8 @@ using ::mlir::success;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;
using ::mlir::stablehlo::CustomCallOp;
using ::mlir::sdy::AxisRefAttr;
using ::mlir::sdy::DimensionShardingAttr;
using ::mlir::sdy::kShardingAttr;
@@ -197,6 +199,7 @@ class ExportMhloShardingsPass
void runOnOperation() final {
ModuleOp moduleOp = getOperation();
mlir::SymbolTableCollection symbolTableCollection;
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);
@@ -208,10 +211,10 @@ class ExportMhloShardingsPass
}
}
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
// If they have a sharding annotation, we need to move it into
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
moduleOp.walk([&](mlir::stablehlo::CustomCallOp customCall) {
moduleOp.walk([&](CustomCallOp customCall) {
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
// If they have a sharding annotation, we need to move it into
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
StringRef callTargetName = customCall.getCallTargetName();
if (callTargetName != "mhlo.erf" && callTargetName != "mhlo.topk") {
return;

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
@@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
pm.addPass(createMhloRoundTripShardMapExportPass());
pm.addPass(createExportNamedComputationsPass());
pm.addPass(createExportMhloShardingsPass());
pm.addPass(createMhloRoundTripExportCallbackCustomCallsPass());
}
void registerMhloExportPipeline() {

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h"
@@ -36,6 +37,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
@@ -66,12 +68,14 @@ int main(int argc, char** argv) {
xla::sdy::registerMhloExportPipeline();
xla::sdy::registerMhloExportShardingsPass();
xla::sdy::registerMhloRoundTripExportCallbackCustomCallsPass();
xla::sdy::registerMhloRoundTripShardMapExportPass();
xla::sdy::registerExportNamedComputationsPass();
xla::sdy::registerExportOpsPass();
xla::sdy::registerSdyRoundTripMhloToHloToMhloPass();
xla::sdy::registerSdyRoundTripExportShardyAttrsPass();
xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass();
xla::sdy::registerSdyRoundTripImportShardyAttrsPass();
xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass();
xla::sdy::registerSdyRoundTripExportOpsPass();

View File

@@ -126,6 +126,22 @@ cc_library(
],
)
cc_library(
name = "import_callback_custom_calls",
srcs = ["import_callback_custom_calls.cc"],
hdrs = ["import_callback_custom_calls.h"],
deps = [
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)
cc_library(
name = "pipelines",
srcs = ["pipelines.cc"],
@@ -133,6 +149,7 @@ cc_library(
deps = [
":export_ops",
":export_shardy_attrs",
":import_callback_custom_calls",
":import_shardy_attrs",
":remove_size_one_axes",
":shard_map_export",
@@ -143,6 +160,5 @@ cc_library(
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

View File

@@ -0,0 +1,91 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include <memory>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/utils.h"
namespace xla {
namespace sdy {
namespace {
using ::mlir::ModuleOp;
using ::mlir::StringRef;
using ::mlir::stablehlo::CustomCallOp;
class SdyRoundTripImportCallbackCustomCallsPass
: public mlir::PassWrapper<SdyRoundTripImportCallbackCustomCallsPass,
mlir::OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
SdyRoundTripImportCallbackCustomCallsPass)
void runOnOperation() final {
getOperation().walk([&](CustomCallOp op) {
if (op->getNumResults() != 0 || !isPythonCallbackCustomCall(op)) {
return;
}
mlir::IRRewriter rewriter(op);
// Shardy needs at least one op result to attach a sharding annotation to.
// Since the callback has no results, and we need to mark the callbacks as
// having a maximal sharding, we add a dummy result and set the result
// layout to the 0th operand's layout.
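// For example (a sketch mirroring this pass's tests):
//   stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0)
//       {result_layouts = []} : (tensor<i64>, tensor<f64>) -> ()
// becomes
//   stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0)
//       {result_layouts = [dense<> : tensor<0xindex>]}
//       : (tensor<i64>, tensor<f64>) -> tensor<i64>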
CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes(
op, op->getOperand(0).getType(), rewriter);
newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr(
{op.getOperandLayoutsAttr().getValue().front()}));
rewriter.eraseOp(op);
});
}
StringRef getArgument() const override {
return "xla-sdy-round-trip-import-callback-custom-calls";
}
StringRef getDescription() const override {
return "Modifies the return types of XLA host callback custom calls to be "
"compatible with SDY";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<mlir::stablehlo::StablehloDialect>();
}
};
} // namespace
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass() {
return std::make_unique<SdyRoundTripImportCallbackCustomCallsPass>();
}
void registerSdyRoundTripImportCallbackCustomCallsPass() {
mlir::registerPass(createSdyRoundTripImportCallbackCustomCallsPass);
}
} // namespace sdy
} // namespace xla

View File

@@ -0,0 +1,41 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
#include <memory>
#include "mlir/Pass/Pass.h"
namespace xla {
namespace sdy {
// Creates the pass to modify the return types of XLA host callback custom calls
// to be compatible with SDY.
//
// Shardy shardings require an op to have at least one result, and the XLA host
// callback custom calls are not guaranteed to return a value.
// To allow the custom calls to carry a maximal sharding, we change them to
// return a dummy value.
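//
// The MHLO round-trip export pass
// (`createMhloRoundTripExportCallbackCustomCallsPass`) undoes this change on
// export, restoring the callback pattern that the XLA compiler recognizes.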
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass();
// Registers the xla-sdy-round-trip-import-callback-custom-calls pass.
void registerSdyRoundTripImportCallbackCustomCallsPass();
} // namespace sdy
} // namespace xla
#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_

View File

@@ -65,8 +65,6 @@ using ::mlir::StringRef;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;
using ::mlir::stablehlo::CustomCallOp;
using ::mlir::sdy::kShardingAttr;
using ::mlir::sdy::kShardingRuleAttr;
using ::mlir::sdy::MeshAttr;
@@ -74,6 +72,8 @@ using ::mlir::sdy::OpShardingRuleAttr;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;
namespace stablehlo = ::mlir::stablehlo;
// Restores the Shardy attributes on a module that was previously exported
// from Shardy and is now being round-tripped back.
// This should happen after the meshes were created from the `ModuleOp` attrs
@@ -108,13 +108,19 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
if (!dictAttr) {
return;
}
// `SendOp` and `RecvOp` can have a sharding when doing TPU callbacks
// through JAX.
if (mlir::isa<stablehlo::SendOp, stablehlo::RecvOp>(op)) {
op->setAttr(kShardingAttr, parseStringAttr<TensorShardingPerValueAttr>(
dictAttr, kShardingRoundTripAttr));
}
// NOTE: we only set the sharding on known custom calls. For any other op
// that has a `kShardingRoundTripAttr` we discard it. XLA sometimes creates
// new instructions, copying over the operand's frontend attrs, which may
// mean the shapes are wrong when the new instruction is, for example, a
// reshape. This does mean we can't fully round-trip between HLO and MLIR
// after SDY propagation.
if (auto customCallOp = mlir::dyn_cast<CustomCallOp>(op)) {
if (auto customCallOp = mlir::dyn_cast<stablehlo::CustomCallOp>(op)) {
StringRef targetName = customCallOp.getCallTargetName();
if (targetName == kFuncResultShardingTargetName) {
// This is a temporary CustomCallOp that holds the sharding from a
@@ -139,7 +145,8 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
}
if (targetName == kShardingCustomCallTargetName ||
targetName == kSPMDFullToShardShapeCallTargetName ||
targetName == kSPMDShardToFullShapeCallTargetName) {
targetName == kSPMDShardToFullShapeCallTargetName ||
isPythonCallbackCustomCall(customCallOp)) {
customCallOp->setAttr(kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr, kShardingRoundTripAttr));

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
@@ -49,6 +50,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {
void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) {
addCommonPreImportPasses(pm);
pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass());
pm.addPass(createSdyRoundTripImportShardyAttrsPass());
pm.addPass(createSdyRoundTripShardMapImportPass());
pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass());

View File

@@ -246,6 +246,74 @@ func.func @custom_call_erf_topk(
return %1#0 : tensor<16x2xf32>
}
// CHECK-LABEL: @callback_transform_to_tuple
func.func @callback_transform_to_tuple(%arg0: tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) -> (tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) {
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor<i64>, tensor<2xf64>) -> tuple<tensor<2xf64>>
// CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple<tensor<2xf64>>) -> tensor<2xf64>
// CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64>
%1 = stablehlo.constant dense<56560393354880> : tensor<i64>
%2 = stablehlo.custom_call @xla_python_cpu_callback(%1, %arg0) {api_version = 2 : i32, backend_config = "56560393354880", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{}]>]>, xla_shape = "(f64[2]{0})"} : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
return %2 : tensor<2xf64>
}
// CHECK-LABEL: @callback_no_result
func.func private @callback_no_result(%arg0: tensor<f64>) {
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
// CHECK-SAME: result_layouts = []
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> ()
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, tensor<f64>) -> tuple<>
return
}
// CHECK-LABEL: @callback_result_unused
func.func private @callback_result_unused(%arg0: tensor<f64>) {
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
// CHECK-SAME: result_layouts = []
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> ()
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, tensor<f64>) -> tensor<i64>
return
}
// CHECK-LABEL: @callback_tuple_result_token_used
func.func public @callback_tuple_result_token_used(%arg0: !stablehlo.token, %arg1: tensor<2xi64>) -> !stablehlo.token {
%c = stablehlo.constant dense<56238119409280> : tensor<i64>
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0, %arg1) {
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238119409280",
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>],
// CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>]
// CHECK-SAME: } : (tensor<i64>, !stablehlo.token, tensor<2xi64>) -> tuple<!stablehlo.token>
// CHECK-NEXT: %[[TOKEN:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] : (tuple<!stablehlo.token>) -> !stablehlo.token
// CHECK-NEXT: return %[[TOKEN]] : !stablehlo.token
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0, %arg1) {api_version = 2 : i32, backend_config = "56238119409280", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, !stablehlo.token, tensor<2xi64>) -> tuple<!stablehlo.token>
%1 = stablehlo.get_tuple_element %0[0] : (tuple<!stablehlo.token>) -> !stablehlo.token
return %1 : !stablehlo.token
}
// CHECK-LABEL: @callback_no_tuple_result_used
func.func @callback_no_tuple_result_used(%arg0: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor<i64>, tensor<2xf64>) -> tuple<tensor<2xf64>>
// CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple<tensor<2xf64>>) -> tensor<2xf64>
// CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64>
%c = stablehlo.constant dense<18990036333952> : tensor<i64>
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "18990036333952", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{?}]>]>, xla_shape = "(f64[2]{0})"} : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// CHECK-LABEL: func private @foo
// CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}
// CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) {

View File

@@ -241,3 +241,18 @@ func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> ()
return %arg0 : tensor<8x8xf32>
}
// -----
func.func @callback_no_result(%arg0: tensor<f64>) {
// CHECK: %[[C:.*]] = sdy.constant
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
// CHECK-SAME: has_side_effect = true,
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
// CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>]
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> tensor<i64>
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []} : (tensor<i64>, tensor<f64>) -> ()
return
}

View File

@@ -30,9 +30,12 @@ limitations under the License.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/spmd/shardy/constants.h"
@@ -50,6 +53,7 @@ using ::mlir::StringRef;
using xla::sdy::kFrontendAttributesAttr;
using ::mlir::func::FuncOp;
using ::mlir::stablehlo::CustomCallOp;
DictionaryAttr getFrontendAttrs(Operation* op) {
return op->getAttrOfType<DictionaryAttr>(kFrontendAttributesAttr);
@@ -185,5 +189,25 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) {
context->loadAllAvailableDialects();
}
CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op,
mlir::TypeRange resultTypes,
mlir::IRRewriter& rewriter) {
auto customCallOp = rewriter.create<CustomCallOp>(
op.getLoc(), resultTypes, op.getOperands(), op.getCallTargetNameAttr(),
op.getHasSideEffectAttr(), op.getBackendConfigAttr(),
op.getApiVersionAttr(), op.getCalledComputations(),
op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(),
op.getOutputOperandAliases());
customCallOp->setDiscardableAttrs(mlir::DictionaryAttr::get(
op->getContext(), llvm::to_vector(op->getDiscardableAttrs())));
return customCallOp;
}
bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op) {
mlir::StringRef targetName = op.getCallTargetName();
return targetName == kPythonCpuCallbackCustomCallTargetName ||
targetName == kPythonGpuCallbackCustomCallTargetName;
}
} // namespace sdy
} // namespace xla

View File

@@ -28,7 +28,10 @@ limitations under the License.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace xla {
namespace sdy {
@@ -101,6 +104,15 @@ std::optional<AttrTy> tryGetFrontendAttr(mlir::Operation* op,
return std::nullopt;
}
// Builds a new `stablehlo.custom_call` with the same operands and attributes
// as `op` but with new `resultTypes`.
mlir::stablehlo::CustomCallOp cloneCustomCallWithNewResultTypes(
mlir::stablehlo::CustomCallOp op, mlir::TypeRange resultTypes,
mlir::IRRewriter& rewriter);
// Whether `op` is a Python callback custom call.
bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op);
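// Example usage (a sketch of how the callback passes combine these helpers):
//   if (isPythonCallbackCustomCall(op)) {
//     mlir::stablehlo::CustomCallOp tupleCall =
//         cloneCustomCallWithNewResultTypes(
//             op, mlir::TupleType::get(op->getContext(), op->getResultTypes()),
//             rewriter);
//   }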
} // namespace sdy
} // namespace xla