mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
committed by
TensorFlower Gardener
parent
d9f02c1e55
commit
f605aa89b3
@@ -1096,14 +1096,6 @@ absl::StatusOr<TensorValue> EmitPad(
|
||||
.getResult());
|
||||
}
|
||||
|
||||
// Returns the value already stored in `values` for operand 0 of
// `tiled_dynamic_slice`. The body builds nothing itself; the builder `b` is
// not used.
// NOTE(review): the lookup is written as `values[...]`; presumably the operand
// has always been emitted before this call -- confirm, since operator[] on a
// hash map inserts a default entry when the key is missing.
absl::StatusOr<TensorValue> EmitTiledDynamicSlice(
|
||||
mlir::ImplicitLocOpBuilder& b,
|
||||
const TiledHloInstruction& tiled_dynamic_slice,
|
||||
absl::flat_hash_map<const TiledHloInstruction*, TensorValue>& values) {
|
||||
// Slicing happens in `ComputeOffsetsForTile` when this value is emitted.
|
||||
return values[tiled_dynamic_slice.operand(0)];
|
||||
}
|
||||
|
||||
absl::StatusOr<TensorValue> EmitTiledHloInstruction(
|
||||
mlir::ImplicitLocOpBuilder& b, const HloFusionInstruction* fusion,
|
||||
const TiledHloInstruction& tiled_hlo,
|
||||
@@ -1236,7 +1228,9 @@ absl::StatusOr<TensorValue> EmitTiledHloInstruction(
|
||||
}
|
||||
|
||||
if (hlo->opcode() == HloOpcode::kDynamicSlice) {
|
||||
return EmitTiledDynamicSlice(b, tiled_hlo, values);
|
||||
// Dynamic slice is implemented as a load and does not require any further
|
||||
// processing.
|
||||
return values[tiled_hlo.operand(0)];
|
||||
}
|
||||
|
||||
return absl::UnimplementedError(
|
||||
|
||||
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "xla/backends/gpu/codegen/triton/support.h"
|
||||
|
||||
#include <string>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
@@ -652,8 +653,10 @@ CodegenDecision IsTritonSupportedInstructionImpl(
|
||||
case HloOpcode::kParameter:
|
||||
return CodegenDecision::Allow();
|
||||
case HloOpcode::kDynamicSlice:
|
||||
return IsTritonSupportedDynamicSlice(
|
||||
*Cast<HloDynamicSliceInstruction>(&instr));
|
||||
// TODO(b/417172838): enable this once we confirm that no benchmarks were
|
||||
// regressed.
|
||||
return CodegenDecision::Forbid(
|
||||
"dynamic slice is supported but not enabled yet");
|
||||
case HloOpcode::kBitcast:
|
||||
if (ShapeUtil::ElementsIn(instr.operand(0)->shape()) !=
|
||||
ShapeUtil::ElementsIn(instr.shape())) {
|
||||
@@ -701,6 +704,7 @@ namespace internal {
|
||||
bool IsTritonUnsupportedOpcode(HloOpcode opcode) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kDynamicReshape:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kGather:
|
||||
case HloOpcode::kRaggedDot:
|
||||
@@ -739,26 +743,6 @@ absl::Status EnsureTritonSupportsComputeCapability(
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Decides whether the dynamic-slice instruction `instr` is accepted.
// Returns Forbid if any index operand has an element type other than S8, S16,
// S32 or S64, or if the result element type of `instr` is S4. Returns Allow
// otherwise.
CodegenDecision IsTritonSupportedDynamicSlice(
|
||||
const HloDynamicSliceInstruction& instr) {
|
||||
// Check the element type of every index (offset) operand.
for (const HloInstruction* index_operand : instr.index_operands()) {
|
||||
switch (index_operand->shape().element_type()) {
|
||||
case S8:
|
||||
case S16:
|
||||
case S32:
|
||||
case S64:
|
||||
break;  // supported
|
||||
default:
|
||||
return CodegenDecision::Forbid(
|
||||
"Dynamic slice is only supported S8, S16, S32, or S64 offsets.");
|
||||
}
|
||||
}
|
||||
// The result element type is checked separately from the offsets.
if (instr.shape().element_type() == PrimitiveType::S4) {
|
||||
return CodegenDecision::Forbid("S4 is not supported.");
|
||||
}
|
||||
return CodegenDecision::Allow();
|
||||
}
|
||||
|
||||
CodegenDecision IsTritonSupportedInstruction(
|
||||
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
|
||||
CodegenDecision decision =
|
||||
|
||||
@@ -21,7 +21,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/service/instruction_fusion.h"
|
||||
#include "xla/shape.h"
|
||||
@@ -67,11 +66,6 @@ CodegenDecision IsTritonSupportedComputation(
|
||||
// `kTritonGemmFusionKind`.
|
||||
bool IsTritonFusedComputation(const HloComputation& computation);
|
||||
|
||||
// TODO(b/393299275): this function is only exposed for
|
||||
// triton_tiling_propagation.cc. If possible it should be removed.
|
||||
CodegenDecision IsTritonSupportedDynamicSlice(
|
||||
const HloDynamicSliceInstruction& instr);
|
||||
|
||||
namespace internal {
|
||||
// TODO(b/363981282): Remove the function below once all ops are tested via
|
||||
// HLOs. This is exposed for testing purposes only and will be removed in the
|
||||
|
||||
@@ -156,15 +156,6 @@ std::vector<xla::PrimitiveType> AllOpSupportedTypes(HloOpcode opcode) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the entries of AllXlaDataTypes() for which
// primitive_util::IsIntegralType() is true.
std::vector<xla::PrimitiveType> AllIntegralDataTypes() {
|
||||
std::vector<xla::PrimitiveType> result;
|
||||
// Filter the full type list, appending each integral type to `result`.
absl::c_copy_if(AllXlaDataTypes(), std::back_inserter(result),
|
||||
[&](PrimitiveType data_type) {
|
||||
return primitive_util::IsIntegralType(data_type);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<PrecisionConfig::Algorithm> AllPrecisionAlgorithms() {
|
||||
std::vector<PrecisionConfig::Algorithm> algorithms;
|
||||
const tsl::protobuf::EnumDescriptor* algorithm_descriptor =
|
||||
@@ -3099,54 +3090,6 @@ INSTANTIATE_TEST_SUITE_P(SortSuite, SortTest,
|
||||
AllTestCombinationsForOpcodes({HloOpcode::kSort}),
|
||||
TritonSupportTestTypeAndOpcodeAndDeviceToString);
|
||||
|
||||
// Support test for dynamic-slice, parameterized over (data type, device).
// In the HLO template `$0` appears as the element type of the sliced operand
// and of the result, while both start indices are s32 (one parameter, one
// constant).
// NOTE(review): presumably ParseTemplateAndGetInstruction substitutes
// `data_type` for `$0` -- confirm against the helper.
using DynamicSliceTest = TritonSupportTestWithTypeAndDeviceParam;
|
||||
|
||||
TEST_P(DynamicSliceTest, OperandTypes) {
|
||||
auto [data_type, cc] = GetParam();
|
||||
const std::string kHloTestTemplate = R"(
|
||||
ENTRY triton_computation {
|
||||
operand = $0[256,256] parameter(0)
|
||||
start_1 = s32[] parameter(1)
|
||||
start_2 = s32[] constant(0)
|
||||
ROOT dynamic_slice_op = $0[32,256] dynamic-slice(operand, start_1, start_2),
|
||||
dynamic_slice_sizes={32,256}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
|
||||
kHloTestTemplate, data_type,
|
||||
HloOpcode::kDynamicSlice));
|
||||
// The support check is run with output tile sizes {2, 4} on device `cc`.
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
|
||||
}
|
||||
|
||||
// Instantiated for every combination of AllXlaDataTypes() and
// AllDevicesToTest().
INSTANTIATE_TEST_SUITE_P(
|
||||
DynamicSliceSuite, DynamicSliceTest,
|
||||
::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()),
|
||||
::testing::ValuesIn(AllDevicesToTest())),
|
||||
TritonSupportTestTypeAndDeviceToString);
|
||||
|
||||
// Support test for the offset (start index) types of dynamic-slice,
// parameterized over (data type, device). In the HLO template `$0` appears as
// the element type of both start-index parameters, while the sliced operand
// and the result are fixed to f32.
// NOTE(review): presumably ParseTemplateAndGetInstruction substitutes
// `data_type` for `$0` -- confirm against the helper.
using DynamicSliceOffsetTypesTest = TritonSupportTestWithTypeAndDeviceParam;
|
||||
|
||||
TEST_P(DynamicSliceOffsetTypesTest, DynamicSlice2D) {
|
||||
auto [data_type, cc] = GetParam();
|
||||
const std::string kHloTestTemplate = R"(
|
||||
ENTRY triton_computation {
|
||||
operand = f32[256,256] parameter(0)
|
||||
start_1 = $0[] parameter(1)
|
||||
start_2 = $0[] parameter(2)
|
||||
ROOT dynamic_slice_op = f32[32,64] dynamic-slice(operand, start_1, start_2),
|
||||
dynamic_slice_sizes={32,64}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
|
||||
kHloTestTemplate, data_type,
|
||||
HloOpcode::kDynamicSlice));
|
||||
// The support check is run with output tile sizes {2, 4} on device `cc`.
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
|
||||
}
|
||||
|
||||
// Instantiated for every combination of AllIntegralDataTypes() and
// AllDevicesToTest().
INSTANTIATE_TEST_SUITE_P(
|
||||
DynamicSliceOffsetTypesSuite, DynamicSliceOffsetTypesTest,
|
||||
::testing::Combine(::testing::ValuesIn(AllIntegralDataTypes()),
|
||||
::testing::ValuesIn(AllDevicesToTest())),
|
||||
TritonSupportTestTypeAndDeviceToString);
|
||||
|
||||
using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
|
||||
|
||||
TEST_P(RecvOpsTest, RecvAndRecvDone) {
|
||||
@@ -3534,6 +3477,7 @@ constexpr std::array kUnsupportedOps = {
|
||||
// clang-format off
|
||||
// go/keep-sorted start
|
||||
HloOpcode::kDynamicReshape,
|
||||
HloOpcode::kDynamicSlice,
|
||||
HloOpcode::kDynamicUpdateSlice,
|
||||
HloOpcode::kGather,
|
||||
HloOpcode::kRaggedDot,
|
||||
@@ -3593,7 +3537,6 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
|
||||
ret.emplace(HloOpcode::kCustomCall);
|
||||
ret.emplace(HloOpcode::kDomain);
|
||||
ret.emplace(HloOpcode::kDot);
|
||||
ret.emplace(HloOpcode::kDynamicSlice);
|
||||
ret.emplace(HloOpcode::kFft);
|
||||
ret.emplace(HloOpcode::kFusion);
|
||||
ret.emplace(HloOpcode::kGetDimensionSize);
|
||||
|
||||
@@ -44,7 +44,6 @@ limitations under the License.
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/hlo/ir/hlo_print_options.h"
|
||||
#include "xla/layout.h"
|
||||
#include "xla/service/gpu/backend_configs.pb.h"
|
||||
#include "xla/service/gpu/cublas_padding_requirements.h"
|
||||
#include "xla/service/gpu/ir_emission_utils.h"
|
||||
@@ -274,36 +273,6 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
|
||||
std::get<DotRequirements>(combined_reqs)};
|
||||
}
|
||||
|
||||
// Checks if a dynamic slice can be fused.
//
// Returns false (logging the reason at VLOG level 5) when either:
//  - IsTritonSupportedInstruction(dynamic_slice, gpu_version) yields a
//    decision for which CanFuse() is false, or
//  - any dimension of operand 0 other than the major-most one has a slice size
//    different from the operand's size in that dimension.
// Returns true otherwise.
|
||||
bool CanFuseDynamicSlice(const HloDynamicSliceInstruction& dynamic_slice,
|
||||
const se::GpuComputeCapability& gpu_version) {
|
||||
if (CodegenDecision decision =
|
||||
IsTritonSupportedInstruction(dynamic_slice, gpu_version);
|
||||
!decision.CanFuse()) {
|
||||
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
|
||||
<< " to the output due to the decision: " << decision.Explain();
|
||||
return false;
|
||||
}
|
||||
// TODO(b/417172838): this check replicates the legacy emitter behavior.
|
||||
// New emitter might support all dimensions but we should verify that.
|
||||
const HloInstruction* input = dynamic_slice.operand(0);
|
||||
Layout in_layout = input->shape().layout();
|
||||
// The last entry of minor_to_major is taken as the major-most dimension.
// NOTE(review): this indexes with size() - 1, so it presumably expects an
// operand with at least one dimension -- confirm.
int64_t majormost_dim_id =
|
||||
in_layout.minor_to_major(in_layout.minor_to_major().size() - 1);
|
||||
// Every dimension except the major-most one must be sliced at full size.
for (int i = 0; i < input->shape().dimensions().size(); ++i) {
|
||||
if (i == majormost_dim_id) {
|
||||
continue;
|
||||
}
|
||||
if (input->shape().dimensions(i) != dynamic_slice.slice_sizes(i)) {
|
||||
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
|
||||
<< " to the output due to the unsupported dynamic slice on "
|
||||
"non-major-most dimension.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
class FusionPlanBuilder {
|
||||
public:
|
||||
// Builds and returns the FusionPlan. Clears internal state.
|
||||
@@ -445,12 +414,10 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
|
||||
// replaces unsupported F8E8M0FNU with u8. We should have a more principled
|
||||
// way to check if we will be able to emit the triton code for the fusion.
|
||||
if (original_hlo.opcode() == HloOpcode::kDynamicSlice) {
|
||||
const HloDynamicSliceInstruction& dynamic_slice =
|
||||
*Cast<HloDynamicSliceInstruction>(&original_hlo);
|
||||
if (!CanFuseDynamicSlice(dynamic_slice, gpu_version)) {
|
||||
fusion_builder.SetShouldFuseNode(node_id, false);
|
||||
continue;
|
||||
}
|
||||
// TODO(b/417172838): support dynamic slice op.
|
||||
fusion_builder.SetShouldFuseNode(node_id, false);
|
||||
LOG(INFO) << "Not fusing dynamic slice: " << original_hlo.ToString();
|
||||
continue;
|
||||
}
|
||||
|
||||
auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
|
||||
|
||||
@@ -264,7 +264,8 @@ ENTRY e {
|
||||
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
|
||||
}
|
||||
|
||||
TEST_F(GemmFusionTest, DynamicSliceIsFused) {
|
||||
// TODO(b/417172838): support dynamic slice op.
|
||||
TEST_F(GemmFusionTest, DISABLED_DynamicSliceIsFused) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
ENTRY e {
|
||||
@@ -288,7 +289,8 @@ ENTRY e {
|
||||
m::Parameter(), m::Constant()))));
|
||||
}
|
||||
|
||||
TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
|
||||
// TODO(b/417172838): support dynamic slice op.
|
||||
TEST_F(GemmFusionTest, DISABLED_DynamicSlicesAreFusedEvenIfTheyShareIndices) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
ENTRY e {
|
||||
@@ -319,7 +321,8 @@ ENTRY e {
|
||||
m::Parameter(), m::Parameter()))));
|
||||
}
|
||||
|
||||
TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
|
||||
// TODO(b/417172838): support dynamic slice op.
|
||||
TEST_F(GemmFusionTest, DISABLED_DoNotFuseDynamicSliceOfNonMajorFragments) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
ENTRY e {
|
||||
@@ -338,7 +341,9 @@ ENTRY e {
|
||||
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
|
||||
}
|
||||
|
||||
TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
|
||||
// TODO(b/417172838): support dynamic slice op.
|
||||
TEST_F(GemmFusionTest,
|
||||
DISABLED_CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
ENTRY e {
|
||||
|
||||
@@ -914,7 +914,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo,
|
||||
properties);
|
||||
} else if (hlo.opcode() == HloOpcode::kDynamicSlice &&
|
||||
direction == TransformDirection::kOutputToInput) {
|
||||
if (CodegenDecision decision = IsTritonSupportedDynamicSlice(
|
||||
if (CodegenDecision decision = legacy_triton::IsTritonSupportedDynamicSlice(
|
||||
*Cast<HloDynamicSliceInstruction>(&hlo));
|
||||
!decision.CanFuse()) {
|
||||
// CodegenDecision is actually the same type as FusionDecision.
|
||||
|
||||
Reference in New Issue
Block a user