From 4aa90bfd39832570e84ab049f4c099359f2f608a Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 31 Oct 2017 16:47:47 -0700 Subject: [PATCH] [XLA] Add HLO matchers that check parameter numbers and GTE indices. This lets you do EXPECT_THAT(foo, op::Parameter(42)); and EXPECT_THAT(bar, op::GetTupleElement(baz, 8)); PiperOrigin-RevId: 174113597 --- .../compiler/xla/service/hlo_matchers.cc | 29 ++++++++ .../compiler/xla/service/hlo_matchers.h | 69 ++++++++++++++++++- 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 0660d5a1820..4255d608662 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const { } } +bool HloParameterMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->parameter_number() != parameter_number_) { + *listener << "has wrong parameter number (got " + << instruction->parameter_number() << ", want " + << parameter_number_ << ")"; + return false; + } + return true; +} + +bool HloGetTupleElementMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->tuple_index() != tuple_index_) { + *listener << "has wrong tuple index (got " << instruction->tuple_index() + << ", want " << tuple_index_ << ")"; + return false; + } + return true; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index bc5ed029a45..4d4010b0253 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface { std::vector<::testing::Matcher> operands_; }; +// Custom matcher for parameters, which accepts a parameter number. +class HloParameterMatcher : public HloMatcher { + public: + explicit HloParameterMatcher(int64 parameter_number) + : HloMatcher(HloOpcode::kParameter, /*operands=*/{}), + parameter_number_(parameter_number) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 parameter_number_; +}; + +// Custom matcher for get-tuple-element instructions, which accepts a tuple +// index to match. +class HloGetTupleElementMatcher : public HloMatcher { + public: + explicit HloGetTupleElementMatcher( + ::testing::Matcher operand, int64 tuple_index) + : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), + tuple_index_(tuple_index) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 tuple_index_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -72,7 +102,6 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); -HLO_MATCHER(GetTupleElement); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -90,7 +119,6 @@ HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); -HLO_MATCHER(Parameter); HLO_MATCHER(Power); HLO_MATCHER(Recv); HLO_MATCHER(Reduce); @@ -115,6 +143,43 @@ HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); HLO_MATCHER(While); + +// The special cases below let you check additional information about the +// HloInstruction, beyond just its opcode and operands. In all cases you can +// still use the generic matcher which doesn't check this info. +// +// Feel free to add additional custom matchers below. + +// - Parameter(N) matches parameter number N. +// - Parameter() matches any parameter. +inline ::testing::Matcher Parameter( + int64 parameter_number) { + return ::testing::MakeMatcher( + new ::xla::testing::HloParameterMatcher(parameter_number)); +} +inline ::testing::Matcher Parameter() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); +} + +// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th +// tuple element of operand, while GetTupleElement(operand) matches any GTE +// operation on operand, and GetTupleElement() matches any GTE operation at all. +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand, int64 tuple_index) { + return ::testing::MakeMatcher( + new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index)); +} +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand) { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand})); +} +inline ::testing::Matcher GetTupleElement() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); +} + #undef HLO_MATCHER } // namespace opcode_matchers