mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
#sdy support JAX callbacks through the Shardy XLA round-trip pipeline.
PiperOrigin-RevId: 713646485
This commit is contained in:
committed by
TensorFlower Gardener
parent
8f94e73026
commit
cf43bb53b5
@@ -88,6 +88,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Support",
|
||||
"@shardy//shardy/dialect/sdy/ir:dialect",
|
||||
"@shardy//shardy/dialect/sdy/ir:register",
|
||||
"@stablehlo//:stablehlo_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -119,6 +120,7 @@ xla_cc_binary(
|
||||
deps = [
|
||||
"//xla/mlir_hlo",
|
||||
"//xla/mlir_hlo:mhlo_passes",
|
||||
"//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls",
|
||||
"//xla/service/spmd/shardy/mhlo_round_trip:export_ops",
|
||||
"//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",
|
||||
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export",
|
||||
@@ -132,6 +134,7 @@ xla_cc_binary(
|
||||
"//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
|
||||
"//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes",
|
||||
|
||||
@@ -38,6 +38,14 @@ inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName =
|
||||
inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName =
|
||||
"SPMDShardToFullShape";
|
||||
|
||||
// The target name of the Python CPU callback custom call.
|
||||
inline constexpr llvm::StringRef kPythonCpuCallbackCustomCallTargetName =
|
||||
"xla_python_cpu_callback";
|
||||
|
||||
// The target name of the Python GPU callback custom call.
|
||||
inline constexpr llvm::StringRef kPythonGpuCallbackCustomCallTargetName =
|
||||
"xla_python_gpu_callback";
|
||||
|
||||
// The attribute name for backend config.
|
||||
inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config";
|
||||
|
||||
|
||||
@@ -83,11 +83,28 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "export_callback_custom_calls",
|
||||
srcs = ["export_callback_custom_calls.cc"],
|
||||
hdrs = ["export_callback_custom_calls.h"],
|
||||
deps = [
|
||||
"//xla/service/spmd/shardy:constants",
|
||||
"//xla/service/spmd/shardy:utils",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@shardy//shardy/dialect/sdy/ir:dialect",
|
||||
"@stablehlo//:stablehlo_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mhlo_export",
|
||||
srcs = ["mhlo_export.cc"],
|
||||
hdrs = ["mhlo_export.h"],
|
||||
deps = [
|
||||
":export_callback_custom_calls",
|
||||
":export_ops",
|
||||
":export_shardings",
|
||||
":shard_map_export",
|
||||
|
||||
120
third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc
vendored
Normal file
120
third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "shardy/dialect/sdy/ir/dialect.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "xla/service/spmd/shardy/constants.h"
|
||||
#include "xla/service/spmd/shardy/utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace sdy {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mlir::ModuleOp;
|
||||
using ::mlir::OperationPass;
|
||||
using ::mlir::PassWrapper;
|
||||
using ::mlir::StringRef;
|
||||
|
||||
using ::mlir::stablehlo::CustomCallOp;
|
||||
|
||||
// Attempts to replace the `CustomCallOp` with a tuple version of it, and a
|
||||
// `GetTupleElementOp` that gets the first element of the tuple.
|
||||
//
|
||||
// This only happens if the op has a single result and the result type is not
|
||||
// a tuple.
|
||||
void replaceCallbackWithTupleVersion(CustomCallOp customCall,
|
||||
mlir::IRRewriter& rewriter) {
|
||||
if (customCall.getNumResults() != 1 ||
|
||||
mlir::isa<mlir::TupleType>(customCall->getResultTypes().front())) {
|
||||
return;
|
||||
}
|
||||
CustomCallOp tupleCustomCall = cloneCustomCallWithNewResultTypes(
|
||||
customCall,
|
||||
mlir::TupleType::get(customCall->getContext(),
|
||||
{customCall->getResultTypes()}),
|
||||
rewriter);
|
||||
auto getTupleElement = rewriter.create<mlir::stablehlo::GetTupleElementOp>(
|
||||
customCall.getLoc(), customCall->getResultTypes().front(),
|
||||
tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0));
|
||||
getTupleElement->setAttr(kXlaShardingAttr,
|
||||
customCall->getAttr(kXlaShardingAttr));
|
||||
rewriter.replaceOp(customCall, getTupleElement);
|
||||
}
|
||||
|
||||
class MhloRoundTripExportCallbackCustomCallsPass
|
||||
: public PassWrapper<MhloRoundTripExportCallbackCustomCallsPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
MhloRoundTripExportCallbackCustomCallsPass)
|
||||
|
||||
void runOnOperation() final {
|
||||
getOperation().walk([&](CustomCallOp customCall) {
|
||||
if (!isPythonCallbackCustomCall(customCall)) {
|
||||
return;
|
||||
}
|
||||
mlir::IRRewriter rewriter(customCall);
|
||||
if (!customCall->use_empty()) {
|
||||
replaceCallbackWithTupleVersion(customCall, rewriter);
|
||||
return;
|
||||
}
|
||||
CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes(
|
||||
customCall, mlir::TypeRange(), rewriter);
|
||||
newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr({}));
|
||||
rewriter.eraseOp(customCall);
|
||||
return;
|
||||
});
|
||||
}
|
||||
|
||||
StringRef getArgument() const override {
|
||||
return "xla-sdy-mhlo-round-trip-export-callback-custom-calls";
|
||||
}
|
||||
|
||||
StringRef getDescription() const override {
|
||||
return "Converts the `CustomCallOp`s for host callbacks in XLA into the "
|
||||
"pattern that the XLA compiler recognizes.";
|
||||
}
|
||||
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const final {
|
||||
registry.insert<mlir::sdy::SdyDialect>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass() {
|
||||
return std::make_unique<MhloRoundTripExportCallbackCustomCallsPass>();
|
||||
}
|
||||
|
||||
void registerMhloRoundTripExportCallbackCustomCallsPass() {
|
||||
mlir::registerPass(createMhloRoundTripExportCallbackCustomCallsPass);
|
||||
}
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
42
third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h
vendored
Normal file
42
third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
namespace xla {
|
||||
namespace sdy {
|
||||
|
||||
// Creates a pass that converts the `CustomCallOp`s for host callbacks in XLA
|
||||
// into the pattern that the XLA compiler recognizes.
|
||||
//
|
||||
// The rest of the XLA pipeline expects host callback custom calls to either be
|
||||
// a tuple with a get_tuple_element or no results (which we changed due to
|
||||
// shardy shardings expecting at least one result, and needing to attach a
|
||||
// maximal sharding to the callbacks).
|
||||
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass();
|
||||
|
||||
// Registers the xla-sdy-mhlo-round-trip-export-callback-custom-calls pass.
|
||||
void registerMhloRoundTripExportCallbackCustomCallsPass();
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
@@ -86,6 +86,8 @@ using ::mlir::success;
|
||||
using ::mlir::SymbolTable;
|
||||
using ::mlir::func::FuncOp;
|
||||
|
||||
using ::mlir::stablehlo::CustomCallOp;
|
||||
|
||||
using ::mlir::sdy::AxisRefAttr;
|
||||
using ::mlir::sdy::DimensionShardingAttr;
|
||||
using ::mlir::sdy::kShardingAttr;
|
||||
@@ -197,6 +199,7 @@ class ExportMhloShardingsPass
|
||||
|
||||
void runOnOperation() final {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
|
||||
mlir::SymbolTableCollection symbolTableCollection;
|
||||
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);
|
||||
|
||||
@@ -208,10 +211,10 @@ class ExportMhloShardingsPass
|
||||
}
|
||||
}
|
||||
|
||||
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
|
||||
// If they have a sharding annotation, we need to move it into
|
||||
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
|
||||
moduleOp.walk([&](mlir::stablehlo::CustomCallOp customCall) {
|
||||
moduleOp.walk([&](CustomCallOp customCall) {
|
||||
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
|
||||
// If they have a sharding annotation, we need to move it into
|
||||
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
|
||||
StringRef callTargetName = customCall.getCallTargetName();
|
||||
if (callTargetName != "mhlo.erf" && callTargetName != "mhlo.topk") {
|
||||
return;
|
||||
|
||||
@@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
|
||||
@@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
|
||||
pm.addPass(createMhloRoundTripShardMapExportPass());
|
||||
pm.addPass(createExportNamedComputationsPass());
|
||||
pm.addPass(createExportMhloShardingsPass());
|
||||
pm.addPass(createMhloRoundTripExportCallbackCustomCallsPass());
|
||||
}
|
||||
|
||||
void registerMhloExportPipeline() {
|
||||
|
||||
@@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
|
||||
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
|
||||
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h"
|
||||
@@ -36,6 +37,7 @@ limitations under the License.
|
||||
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
|
||||
@@ -66,12 +68,14 @@ int main(int argc, char** argv) {
|
||||
|
||||
xla::sdy::registerMhloExportPipeline();
|
||||
xla::sdy::registerMhloExportShardingsPass();
|
||||
xla::sdy::registerMhloRoundTripExportCallbackCustomCallsPass();
|
||||
xla::sdy::registerMhloRoundTripShardMapExportPass();
|
||||
xla::sdy::registerExportNamedComputationsPass();
|
||||
xla::sdy::registerExportOpsPass();
|
||||
|
||||
xla::sdy::registerSdyRoundTripMhloToHloToMhloPass();
|
||||
xla::sdy::registerSdyRoundTripExportShardyAttrsPass();
|
||||
xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass();
|
||||
xla::sdy::registerSdyRoundTripImportShardyAttrsPass();
|
||||
xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass();
|
||||
xla::sdy::registerSdyRoundTripExportOpsPass();
|
||||
|
||||
@@ -126,6 +126,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "import_callback_custom_calls",
|
||||
srcs = ["import_callback_custom_calls.cc"],
|
||||
hdrs = ["import_callback_custom_calls.h"],
|
||||
deps = [
|
||||
"//xla/service/spmd/shardy:utils",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@shardy//shardy/dialect/sdy/ir:dialect",
|
||||
"@stablehlo//:stablehlo_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pipelines",
|
||||
srcs = ["pipelines.cc"],
|
||||
@@ -133,6 +149,7 @@ cc_library(
|
||||
deps = [
|
||||
":export_ops",
|
||||
":export_shardy_attrs",
|
||||
":import_callback_custom_calls",
|
||||
":import_shardy_attrs",
|
||||
":remove_size_one_axes",
|
||||
":shard_map_export",
|
||||
@@ -143,6 +160,5 @@ cc_library(
|
||||
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
91
third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc
vendored
Normal file
91
third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "xla/service/spmd/shardy/utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace sdy {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mlir::ModuleOp;
|
||||
using ::mlir::StringRef;
|
||||
using ::mlir::stablehlo::CustomCallOp;
|
||||
|
||||
class SdyRoundTripImportCallbackCustomCallsPass
|
||||
: public mlir::PassWrapper<SdyRoundTripImportCallbackCustomCallsPass,
|
||||
mlir::OperationPass<ModuleOp>> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
SdyRoundTripImportCallbackCustomCallsPass)
|
||||
|
||||
void runOnOperation() final {
|
||||
getOperation().walk([&](CustomCallOp op) {
|
||||
if (op->getNumResults() != 0 || !isPythonCallbackCustomCall(op)) {
|
||||
return;
|
||||
}
|
||||
mlir::IRRewriter rewriter(op);
|
||||
// Shardy needs at least one op result to have a sharding annotation.
|
||||
// Since the callback has no results, and we need to say the callbacks
|
||||
// have a maximal sharding, we add a dummy result and set the result
|
||||
// layout to the 0th operand layout.
|
||||
CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes(
|
||||
op, op->getOperand(0).getType(), rewriter);
|
||||
newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr(
|
||||
{op.getOperandLayoutsAttr().getValue().front()}));
|
||||
rewriter.eraseOp(op);
|
||||
});
|
||||
}
|
||||
|
||||
StringRef getArgument() const override {
|
||||
return "xla-sdy-round-trip-import-callback-custom-calls";
|
||||
}
|
||||
|
||||
StringRef getDescription() const override {
|
||||
return "Modifies the return types of XLA host callback custom calls to be "
|
||||
"compatible with SDY";
|
||||
}
|
||||
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const final {
|
||||
registry.insert<mlir::stablehlo::StablehloDialect>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass() {
|
||||
return std::make_unique<SdyRoundTripImportCallbackCustomCallsPass>();
|
||||
}
|
||||
|
||||
void registerSdyRoundTripImportCallbackCustomCallsPass() {
|
||||
mlir::registerPass(createSdyRoundTripImportCallbackCustomCallsPass);
|
||||
}
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
41
third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h
vendored
Normal file
41
third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace xla {
|
||||
namespace sdy {
|
||||
|
||||
// Creates the pass to modify the return types of XLA host callback custom calls
|
||||
// to be compatible with SDY.
|
||||
//
|
||||
// Shardy shardings require an op to have at least one result, and the XLA host
|
||||
// callback custom calls are not guaranteed to return a value.
|
||||
// To allow the custom calls to have a maximal sharding, we change the return
|
||||
// type to return a dummy value.
|
||||
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass();
|
||||
|
||||
// Registers the xla-sdy-round-trip-import-callback-custom-calls pass.
|
||||
void registerSdyRoundTripImportCallbackCustomCallsPass();
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
|
||||
@@ -65,8 +65,6 @@ using ::mlir::StringRef;
|
||||
using ::mlir::SymbolTable;
|
||||
using ::mlir::func::FuncOp;
|
||||
|
||||
using ::mlir::stablehlo::CustomCallOp;
|
||||
|
||||
using ::mlir::sdy::kShardingAttr;
|
||||
using ::mlir::sdy::kShardingRuleAttr;
|
||||
using ::mlir::sdy::MeshAttr;
|
||||
@@ -74,6 +72,8 @@ using ::mlir::sdy::OpShardingRuleAttr;
|
||||
using ::mlir::sdy::TensorShardingAttr;
|
||||
using ::mlir::sdy::TensorShardingPerValueAttr;
|
||||
|
||||
namespace stablehlo = ::mlir::stablehlo;
|
||||
|
||||
// Builds the shardy attributes coming from Shardy previously. This means
|
||||
// the module was exported from Shardy and we are now round-tripping back.
|
||||
// This should happen after the meshes were created from the `ModuleOp` attrs
|
||||
@@ -108,13 +108,19 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
|
||||
if (!dictAttr) {
|
||||
return;
|
||||
}
|
||||
// `SendOp` and `RecvOp` can have a sharding when doing TPU callbacks
|
||||
// through JAX.
|
||||
if (mlir::isa<stablehlo::SendOp, stablehlo::RecvOp>(op)) {
|
||||
op->setAttr(kShardingAttr, parseStringAttr<TensorShardingPerValueAttr>(
|
||||
dictAttr, kShardingRoundTripAttr));
|
||||
}
|
||||
// NOTE: we are only setting the sharding on known custom-calls. For any
|
||||
// other op that has a `kShardingRoundTripAttr` we discard it. XLA sometimes
|
||||
// creates new instructions, copying over the operand's frontend attrs,
|
||||
// which may mean the shapes are wrong when the new instruction is a reshape
|
||||
// for example. This does mean we can't fully round-trip b/w HLO and MLIR
|
||||
// after SDY propagation.
|
||||
if (auto customCallOp = mlir::dyn_cast<CustomCallOp>(op)) {
|
||||
if (auto customCallOp = mlir::dyn_cast<stablehlo::CustomCallOp>(op)) {
|
||||
StringRef targetName = customCallOp.getCallTargetName();
|
||||
if (targetName == kFuncResultShardingTargetName) {
|
||||
// This is a temporary CustomCallOp that holds the sharding from a
|
||||
@@ -139,7 +145,8 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
|
||||
}
|
||||
if (targetName == kShardingCustomCallTargetName ||
|
||||
targetName == kSPMDFullToShardShapeCallTargetName ||
|
||||
targetName == kSPMDShardToFullShapeCallTargetName) {
|
||||
targetName == kSPMDShardToFullShapeCallTargetName ||
|
||||
isPythonCallbackCustomCall(customCallOp)) {
|
||||
customCallOp->setAttr(kShardingAttr,
|
||||
parseStringAttr<TensorShardingPerValueAttr>(
|
||||
dictAttr, kShardingRoundTripAttr));
|
||||
|
||||
@@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
|
||||
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
|
||||
@@ -49,6 +50,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {
|
||||
|
||||
void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) {
|
||||
addCommonPreImportPasses(pm);
|
||||
pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass());
|
||||
pm.addPass(createSdyRoundTripImportShardyAttrsPass());
|
||||
pm.addPass(createSdyRoundTripShardMapImportPass());
|
||||
pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass());
|
||||
|
||||
@@ -246,6 +246,74 @@ func.func @custom_call_erf_topk(
|
||||
return %1#0 : tensor<16x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @callback_transform_to_tuple
|
||||
func.func @callback_transform_to_tuple(%arg0: tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) -> (tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) {
|
||||
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
|
||||
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor<i64>, tensor<2xf64>) -> tuple<tensor<2xf64>>
|
||||
// CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple<tensor<2xf64>>) -> tensor<2xf64>
|
||||
// CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64>
|
||||
%1 = stablehlo.constant dense<56560393354880> : tensor<i64>
|
||||
%2 = stablehlo.custom_call @xla_python_cpu_callback(%1, %arg0) {api_version = 2 : i32, backend_config = "56560393354880", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{}]>]>, xla_shape = "(f64[2]{0})"} : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
|
||||
return %2 : tensor<2xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @callback_no_result
|
||||
func.func private @callback_no_result(%arg0: tensor<f64>) {
|
||||
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
|
||||
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
|
||||
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
|
||||
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
|
||||
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
|
||||
// CHECK-SAME: result_layouts = []
|
||||
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> ()
|
||||
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
|
||||
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, tensor<f64>) -> tuple<>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @callback_result_unused
|
||||
func.func private @callback_result_unused(%arg0: tensor<f64>) {
|
||||
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
|
||||
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
|
||||
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
|
||||
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
|
||||
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
|
||||
// CHECK-SAME: result_layouts = []
|
||||
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> ()
|
||||
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
|
||||
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, tensor<f64>) -> tensor<i64>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @callback_tuple_result_token_used
|
||||
func.func public @callback_tuple_result_token_used(%arg0: !stablehlo.token, %arg1: tensor<2xi64>) -> !stablehlo.token {
|
||||
%c = stablehlo.constant dense<56238119409280> : tensor<i64>
|
||||
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
|
||||
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0, %arg1) {
|
||||
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238119409280",
|
||||
// CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}",
|
||||
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>],
|
||||
// CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>]
|
||||
// CHECK-SAME: } : (tensor<i64>, !stablehlo.token, tensor<2xi64>) -> tuple<!stablehlo.token>
|
||||
// CHECK-NEXT: %[[TOKEN:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] : (tuple<!stablehlo.token>) -> !stablehlo.token
|
||||
// CHECK-NEXT: return %[[TOKEN]] : !stablehlo.token
|
||||
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0, %arg1) {api_version = 2 : i32, backend_config = "56238119409280", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor<i64>, !stablehlo.token, tensor<2xi64>) -> tuple<!stablehlo.token>
|
||||
%1 = stablehlo.get_tuple_element %0[0] : (tuple<!stablehlo.token>) -> !stablehlo.token
|
||||
return %1 : !stablehlo.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @callback_no_tuple_result_used
|
||||
func.func @callback_no_tuple_result_used(%arg0: tensor<2xf64>) -> tensor<2xf64> {
|
||||
// CHECK-NEXT: %[[C:.*]] = stablehlo.constant
|
||||
// CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor<i64>, tensor<2xf64>) -> tuple<tensor<2xf64>>
|
||||
// CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple<tensor<2xf64>>) -> tensor<2xf64>
|
||||
// CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64>
|
||||
%c = stablehlo.constant dense<18990036333952> : tensor<i64>
|
||||
%0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "18990036333952", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{?}]>]>, xla_shape = "(f64[2]{0})"} : (tensor<i64>, tensor<2xf64>) -> tensor<2xf64>
|
||||
return %0 : tensor<2xf64>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func private @foo
|
||||
// CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}
|
||||
// CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) {
|
||||
|
||||
@@ -241,3 +241,18 @@ func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
|
||||
stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> ()
|
||||
return %arg0 : tensor<8x8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// A side-effecting Python CPU callback that is written with no results and an
// empty `result_layouts`. The expected output gives the custom call a single
// `tensor<i64>` result with a corresponding entry in `result_layouts`, and its
// constant operand is expected to be an `sdy.constant`.
func.func @callback_no_result(%arg0: tensor<f64>) {
|
||||
// CHECK: %[[C:.*]] = sdy.constant
|
||||
// CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {
|
||||
// CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176",
|
||||
// CHECK-SAME: has_side_effect = true,
|
||||
// CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>],
|
||||
// CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>]
|
||||
// CHECK-SAME: } : (tensor<i64>, tensor<f64>) -> tensor<i64>
|
||||
%c = stablehlo.constant dense<56238273106176> : tensor<i64>
|
||||
stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []} : (tensor<i64>, tensor<f64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
24
third_party/xla/xla/service/spmd/shardy/utils.cc
vendored
24
third_party/xla/xla/service/spmd/shardy/utils.cc
vendored
@@ -30,9 +30,12 @@ limitations under the License.
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeRange.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "shardy/dialect/sdy/ir/register.h"
|
||||
#include "shardy/dialect/sdy/ir/utils.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
|
||||
#include "xla/service/spmd/shardy/constants.h"
|
||||
|
||||
@@ -50,6 +53,7 @@ using ::mlir::StringRef;
|
||||
using xla::sdy::kFrontendAttributesAttr;
|
||||
|
||||
using ::mlir::func::FuncOp;
|
||||
using ::mlir::stablehlo::CustomCallOp;
|
||||
|
||||
DictionaryAttr getFrontendAttrs(Operation* op) {
|
||||
return op->getAttrOfType<DictionaryAttr>(kFrontendAttributesAttr);
|
||||
@@ -185,5 +189,25 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) {
|
||||
context->loadAllAvailableDialects();
|
||||
}
|
||||
|
||||
// Creates a copy of `op` whose results have the types in `resultTypes`.
//
// The copy reuses `op`'s location and operands and carries over its call
// target name, side-effect flag, backend config, API version, called
// computations, operand/result layouts and output-operand aliases, as well as
// all of its discardable attributes.
//
// The new op is built through `rewriter`. `op` itself is neither replaced nor
// erased here; rewiring its uses and removing it is left to the caller.
CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op,
                                               mlir::TypeRange resultTypes,
                                               mlir::IRRewriter& rewriter) {
  auto customCallOp = rewriter.create<CustomCallOp>(
      op.getLoc(), resultTypes, op.getOperands(), op.getCallTargetNameAttr(),
      op.getHasSideEffectAttr(), op.getBackendConfigAttr(),
      op.getApiVersionAttr(), op.getCalledComputations(),
      op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(),
      op.getOutputOperandAliases());
  // Discardable attributes are not among the builder arguments above, so they
  // are copied onto the new op explicitly.
  customCallOp->setDiscardableAttrs(mlir::DictionaryAttr::get(
      op->getContext(), llvm::to_vector(op->getDiscardableAttrs())));
  return customCallOp;
}
|
||||
|
||||
// Returns true iff the call target name of `op` equals the Python CPU callback
// target or the Python GPU callback target.
bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op) {
  const mlir::StringRef callTarget = op.getCallTargetName();
  if (callTarget == kPythonCpuCallbackCustomCallTargetName) {
    return true;
  }
  return callTarget == kPythonGpuCallbackCustomCallTargetName;
}
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
|
||||
12
third_party/xla/xla/service/spmd/shardy/utils.h
vendored
12
third_party/xla/xla/service/spmd/shardy/utils.h
vendored
@@ -28,7 +28,10 @@ limitations under the License.
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeRange.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
|
||||
namespace xla {
|
||||
namespace sdy {
|
||||
@@ -101,6 +104,15 @@ std::optional<AttrTy> tryGetFrontendAttr(mlir::Operation* op,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Builds a new `stablehlo.custom_call` with the same operands and attributes
|
||||
// as `op` but with new `resultTypes`.
// Discardable attributes of `op` are copied onto the new op as well. The new
// op is created via `rewriter`; `op` itself is not replaced or erased by this
// function, so the caller decides what happens to it.
|
||||
mlir::stablehlo::CustomCallOp cloneCustomCallWithNewResultTypes(
|
||||
mlir::stablehlo::CustomCallOp op, mlir::TypeRange resultTypes,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
// Whether `op` is a Python callback custom call, i.e. its call target name is
// either the Python CPU callback target or the Python GPU callback target.
|
||||
bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op);
|
||||
|
||||
} // namespace sdy
|
||||
} // namespace xla
|
||||
|
||||
|
||||
Reference in New Issue
Block a user