PiperOrigin-RevId: 839259976
This commit is contained in:
A. Unique TensorFlower
2025-12-02 07:28:07 -08:00
committed by TensorFlower Gardener
parent d9f02c1e55
commit f605aa89b3
7 changed files with 24 additions and 137 deletions

View File

@@ -1096,14 +1096,6 @@ absl::StatusOr<TensorValue> EmitPad(
.getResult());
}
absl::StatusOr<TensorValue> EmitTiledDynamicSlice(
mlir::ImplicitLocOpBuilder& b,
const TiledHloInstruction& tiled_dynamic_slice,
absl::flat_hash_map<const TiledHloInstruction*, TensorValue>& values) {
// Slicing happens in `ComputeOffsetsForTile` when this value is emitted.
return values[tiled_dynamic_slice.operand(0)];
}
absl::StatusOr<TensorValue> EmitTiledHloInstruction(
mlir::ImplicitLocOpBuilder& b, const HloFusionInstruction* fusion,
const TiledHloInstruction& tiled_hlo,
@@ -1236,7 +1228,9 @@ absl::StatusOr<TensorValue> EmitTiledHloInstruction(
}
if (hlo->opcode() == HloOpcode::kDynamicSlice) {
return EmitTiledDynamicSlice(b, tiled_hlo, values);
// Dynamic slice is implemented as a load and does not require any further
// processing.
return values[tiled_hlo.operand(0)];
}
return absl::UnimplementedError(

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "xla/backends/gpu/codegen/triton/support.h"
#include <string>
#include <variant>
#include <vector>
#include "absl/algorithm/container.h"
@@ -652,8 +653,10 @@ CodegenDecision IsTritonSupportedInstructionImpl(
case HloOpcode::kParameter:
return CodegenDecision::Allow();
case HloOpcode::kDynamicSlice:
return IsTritonSupportedDynamicSlice(
*Cast<HloDynamicSliceInstruction>(&instr));
// TODO(b/417172838): enable this once we confirm that no benchmarks were
// regressed.
return CodegenDecision::Forbid(
"dynamic slice is supported but not enabled yet");
case HloOpcode::kBitcast:
if (ShapeUtil::ElementsIn(instr.operand(0)->shape()) !=
ShapeUtil::ElementsIn(instr.shape())) {
@@ -701,6 +704,7 @@ namespace internal {
bool IsTritonUnsupportedOpcode(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kDynamicReshape:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kGather:
case HloOpcode::kRaggedDot:
@@ -739,26 +743,6 @@ absl::Status EnsureTritonSupportsComputeCapability(
return absl::OkStatus();
}
CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr) {
for (const HloInstruction* index_operand : instr.index_operands()) {
switch (index_operand->shape().element_type()) {
case S8:
case S16:
case S32:
case S64:
break; // supported
default:
return CodegenDecision::Forbid(
"Dynamic slice is only supported S8, S16, S32, or S64 offsets.");
}
}
if (instr.shape().element_type() == PrimitiveType::S4) {
return CodegenDecision::Forbid("S4 is not supported.");
}
return CodegenDecision::Allow();
}
CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
CodegenDecision decision =

View File

@@ -21,7 +21,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/instruction_fusion.h"
#include "xla/shape.h"
@@ -67,11 +66,6 @@ CodegenDecision IsTritonSupportedComputation(
// `kTritonGemmFusionKind`.
bool IsTritonFusedComputation(const HloComputation& computation);
// TODO(b/393299275): this function is only exposed for
// triton_tiling_propagation.cc. If possible it should be removed.
CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr);
namespace internal {
// TODO(b/363981282): Remove the function below once all ops are tested via
// HLOs. This is exposed for testing purposes only and will be removed in the

View File

@@ -156,15 +156,6 @@ std::vector<xla::PrimitiveType> AllOpSupportedTypes(HloOpcode opcode) {
return result;
}
std::vector<xla::PrimitiveType> AllIntegralDataTypes() {
std::vector<xla::PrimitiveType> result;
absl::c_copy_if(AllXlaDataTypes(), std::back_inserter(result),
[&](PrimitiveType data_type) {
return primitive_util::IsIntegralType(data_type);
});
return result;
}
std::vector<PrecisionConfig::Algorithm> AllPrecisionAlgorithms() {
std::vector<PrecisionConfig::Algorithm> algorithms;
const tsl::protobuf::EnumDescriptor* algorithm_descriptor =
@@ -3099,54 +3090,6 @@ INSTANTIATE_TEST_SUITE_P(SortSuite, SortTest,
AllTestCombinationsForOpcodes({HloOpcode::kSort}),
TritonSupportTestTypeAndOpcodeAndDeviceToString);
using DynamicSliceTest = TritonSupportTestWithTypeAndDeviceParam;
TEST_P(DynamicSliceTest, OperandTypes) {
auto [data_type, cc] = GetParam();
const std::string kHloTestTemplate = R"(
ENTRY triton_computation {
operand = $0[256,256] parameter(0)
start_1 = s32[] parameter(1)
start_2 = s32[] constant(0)
ROOT dynamic_slice_op = $0[32,256] dynamic-slice(operand, start_1, start_2),
dynamic_slice_sizes={32,256}
})";
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
kHloTestTemplate, data_type,
HloOpcode::kDynamicSlice));
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
}
INSTANTIATE_TEST_SUITE_P(
DynamicSliceSuite, DynamicSliceTest,
::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()),
::testing::ValuesIn(AllDevicesToTest())),
TritonSupportTestTypeAndDeviceToString);
using DynamicSliceOffsetTypesTest = TritonSupportTestWithTypeAndDeviceParam;
TEST_P(DynamicSliceOffsetTypesTest, DynamicSlice2D) {
auto [data_type, cc] = GetParam();
const std::string kHloTestTemplate = R"(
ENTRY triton_computation {
operand = f32[256,256] parameter(0)
start_1 = $0[] parameter(1)
start_2 = $0[] parameter(2)
ROOT dynamic_slice_op = f32[32,64] dynamic-slice(operand, start_1, start_2),
dynamic_slice_sizes={32,64}
})";
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
kHloTestTemplate, data_type,
HloOpcode::kDynamicSlice));
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
}
INSTANTIATE_TEST_SUITE_P(
DynamicSliceOffsetTypesSuite, DynamicSliceOffsetTypesTest,
::testing::Combine(::testing::ValuesIn(AllIntegralDataTypes()),
::testing::ValuesIn(AllDevicesToTest())),
TritonSupportTestTypeAndDeviceToString);
using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
TEST_P(RecvOpsTest, RecvAndRecvDone) {
@@ -3534,6 +3477,7 @@ constexpr std::array kUnsupportedOps = {
// clang-format off
// go/keep-sorted start
HloOpcode::kDynamicReshape,
HloOpcode::kDynamicSlice,
HloOpcode::kDynamicUpdateSlice,
HloOpcode::kGather,
HloOpcode::kRaggedDot,
@@ -3593,7 +3537,6 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
ret.emplace(HloOpcode::kCustomCall);
ret.emplace(HloOpcode::kDomain);
ret.emplace(HloOpcode::kDot);
ret.emplace(HloOpcode::kDynamicSlice);
ret.emplace(HloOpcode::kFft);
ret.emplace(HloOpcode::kFusion);
ret.emplace(HloOpcode::kGetDimensionSize);

View File

@@ -44,7 +44,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/layout.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
#include "xla/service/gpu/ir_emission_utils.h"
@@ -274,36 +273,6 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
std::get<DotRequirements>(combined_reqs)};
}
// Checks if a dynamic slice can be fused.
bool CanFuseDynamicSlice(const HloDynamicSliceInstruction& dynamic_slice,
const se::GpuComputeCapability& gpu_version) {
if (CodegenDecision decision =
IsTritonSupportedInstruction(dynamic_slice, gpu_version);
!decision.CanFuse()) {
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
<< " to the output due to the decision: " << decision.Explain();
return false;
}
// TODO(b/417172838): this check replicates the legacy emitter behavior.
// New emitter might support all dimensions but we should verify that.
const HloInstruction* input = dynamic_slice.operand(0);
Layout in_layout = input->shape().layout();
int64_t majormost_dim_id =
in_layout.minor_to_major(in_layout.minor_to_major().size() - 1);
for (int i = 0; i < input->shape().dimensions().size(); ++i) {
if (i == majormost_dim_id) {
continue;
}
if (input->shape().dimensions(i) != dynamic_slice.slice_sizes(i)) {
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
<< " to the output due to the unsupported dynamic slice on "
"non-major-most dimension.";
return false;
}
}
return true;
}
class FusionPlanBuilder {
public:
// Builds and returns the FusionPlan. Clears internal state.
@@ -445,12 +414,10 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
// replaces unsupported F8E8M0FNU with u8. We should have a more principled
// way check if we will be able to emit the triton code for the fusion.
if (original_hlo.opcode() == HloOpcode::kDynamicSlice) {
const HloDynamicSliceInstruction& dynamic_slice =
*Cast<HloDynamicSliceInstruction>(&original_hlo);
if (!CanFuseDynamicSlice(dynamic_slice, gpu_version)) {
fusion_builder.SetShouldFuseNode(node_id, false);
continue;
}
// TODO(b/417172838): support dynamic slice op.
fusion_builder.SetShouldFuseNode(node_id, false);
LOG(INFO) << "Not fusing dynamic slice: " << original_hlo.ToString();
continue;
}
auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(

View File

@@ -264,7 +264,8 @@ ENTRY e {
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
}
TEST_F(GemmFusionTest, DynamicSliceIsFused) {
// TODO(b/417172838): support dynamic slice op.
TEST_F(GemmFusionTest, DISABLED_DynamicSliceIsFused) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
ENTRY e {
@@ -288,7 +289,8 @@ ENTRY e {
m::Parameter(), m::Constant()))));
}
TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
// TODO(b/417172838): support dynamic slice op.
TEST_F(GemmFusionTest, DISABLED_DynamicSlicesAreFusedEvenIfTheyShareIndices) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
ENTRY e {
@@ -319,7 +321,8 @@ ENTRY e {
m::Parameter(), m::Parameter()))));
}
TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
// TODO(b/417172838): support dynamic slice op.
TEST_F(GemmFusionTest, DISABLED_DoNotFuseDynamicSliceOfNonMajorFragments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
ENTRY e {
@@ -338,7 +341,9 @@ ENTRY e {
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
}
TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
// TODO(b/417172838): support dynamic slice op.
TEST_F(GemmFusionTest,
DISABLED_CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
ENTRY e {

View File

@@ -914,7 +914,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo,
properties);
} else if (hlo.opcode() == HloOpcode::kDynamicSlice &&
direction == TransformDirection::kOutputToInput) {
if (CodegenDecision decision = IsTritonSupportedDynamicSlice(
if (CodegenDecision decision = legacy_triton::IsTritonSupportedDynamicSlice(
*Cast<HloDynamicSliceInstruction>(&hlo));
!decision.CanFuse()) {
// CodegenDecision is actually the same type as FusionDecision.