mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
#HLODiff Add unmatched computations to summary->computation_diff_patterns
PiperOrigin-RevId: 775770118
This commit is contained in:
committed by
TensorFlower Gardener
parent
35a1ffb49c
commit
ce60a770af
@@ -186,15 +186,23 @@ FindConnectedComponents(
|
||||
absl::flat_hash_map<const HloComputation*, const ComputationSummary>
|
||||
computation_summary) {
|
||||
ConnectedComponentsFinder cc;
|
||||
std::vector<std::vector<const HloComputation*>> unmatched_computations;
|
||||
absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>> result;
|
||||
for (const auto& [computation, computation_match_info] :
|
||||
computation_summary) {
|
||||
if (computation_match_info.main_matched_computation != nullptr) {
|
||||
cc.AddEdge(computation, computation_match_info.main_matched_computation);
|
||||
} else {
|
||||
// main_matched_computation is nullptr means all instructions in the
|
||||
// computation are unmatched.
|
||||
unmatched_computations.push_back({computation});
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<const HloComputation*>> connected_component_groups =
|
||||
cc.FindConnectedComponents();
|
||||
connected_component_groups.insert(connected_component_groups.end(),
|
||||
unmatched_computations.begin(),
|
||||
unmatched_computations.end());
|
||||
|
||||
for (const auto& component_group : connected_component_groups) {
|
||||
bool all_unchanged = true;
|
||||
|
||||
@@ -37,6 +37,7 @@ namespace {
|
||||
|
||||
using ::testing::ExplainMatchResult;
|
||||
using ::testing::FieldsAre;
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::Pair;
|
||||
using ::testing::Pointee;
|
||||
using ::testing::Property;
|
||||
@@ -375,6 +376,75 @@ TEST_F(HloDiffTest, FindConnectedComponentsWorks) {
|
||||
/*right_unmatched_instruction_count=*/6))));
|
||||
}
|
||||
|
||||
TEST_F(HloDiffTest, FindConnectedComponentsWorksForIsolatedComputations) {
|
||||
// Create left module with entry computation containing the following
|
||||
// structure:
|
||||
// [Param 0] ---> ┌-------┐
|
||||
// | add_0 |
|
||||
// [Param 1] ---> └-------┘
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module_l,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
HloModule module, is_scheduled=true
|
||||
|
||||
ENTRY entry {
|
||||
parameter.0 = f32[] parameter(0)
|
||||
parameter.1 = f32[] parameter(1)
|
||||
add.0 = f32[] add(parameter.0, parameter.1)
|
||||
}
|
||||
)"));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
|
||||
HloGumgraph::Create(module_l.get()));
|
||||
// Create left module with entry computation containing the following
|
||||
// structure:
|
||||
// [Const 0] ---> ┌-------┐
|
||||
// | sub_0 |
|
||||
// [Const 1] ---> └-------┘
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module_r,
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
HloModule module, is_scheduled=true
|
||||
|
||||
ENTRY entry {
|
||||
constant.0 = f32[] constant(0)
|
||||
constant.1 = f32[] constant(1)
|
||||
subtract.0 = f32[] subtract(constant.0, constant.1)
|
||||
}
|
||||
)"));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
|
||||
HloGumgraph::Create(module_r.get()));
|
||||
auto mappings = std::make_unique<HloGumgraphMappings>();
|
||||
// Root nodes are matched by default before the matcher is called.
|
||||
mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
|
||||
MatcherType::kManual);
|
||||
std::unique_ptr<const DiffResult> diff_result =
|
||||
ConstructDiffResult(*graph_l, *graph_r, *mappings);
|
||||
std::unique_ptr<const DiffSummary> diff_summary =
|
||||
ConstructDiffSummary(*module_l, *module_r, *diff_result);
|
||||
EXPECT_THAT(diff_summary->computation_diff_patterns,
|
||||
UnorderedElementsAre(
|
||||
FieldsAre(
|
||||
/*fingerprint=*/619838372110990418U,
|
||||
/*computation_groups=*/
|
||||
UnorderedElementsAre(FieldsAre(
|
||||
/*left_computations=*/UnorderedElementsAre(Pointee(
|
||||
Property(&HloComputation::name, "entry"))),
|
||||
/*right_computations=*/IsEmpty())),
|
||||
/*diff_metrics=*/
|
||||
FieldsAre(/*changed_instruction_count=*/0,
|
||||
/*left_unmatched_instruction_count=*/3,
|
||||
/*right_unmatched_instruction_count=*/0)),
|
||||
FieldsAre(
|
||||
/*fingerprint=*/591642684880638740U,
|
||||
/*computation_groups=*/
|
||||
UnorderedElementsAre(FieldsAre(
|
||||
/*left_computations=*/IsEmpty(),
|
||||
/*right_computations=*/UnorderedElementsAre(Pointee(
|
||||
Property(&HloComputation::name, "entry"))))),
|
||||
/*diff_metrics=*/
|
||||
FieldsAre(/*changed_instruction_count=*/0,
|
||||
/*left_unmatched_instruction_count=*/0,
|
||||
/*right_unmatched_instruction_count=*/3))));
|
||||
}
|
||||
|
||||
MATCHER(EqualsComputationGroup, "") {
|
||||
const ComputationGroup& a = std::get<0>(arg);
|
||||
const ComputationGroup& b = std::get<1>(arg);
|
||||
|
||||
Reference in New Issue
Block a user