From ce60a770afc4a74b7a7b544fcd956a0f26df8f5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Jun 2025 11:27:36 -0700 Subject: [PATCH] #HLODiff Add unmatched computations to summary->computation_diff_patterns PiperOrigin-RevId: 775770118 --- .../hlo/tools/hlo_diff/hlo_diff_summary.cc | 8 +++ .../tools/hlo_diff/hlo_diff_summary_test.cc | 70 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc index cf6b528b29c..c9a0a13b50e 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc @@ -186,15 +186,23 @@ FindConnectedComponents( absl::flat_hash_map computation_summary) { ConnectedComponentsFinder cc; + std::vector> unmatched_computations; absl::flat_hash_map> result; for (const auto& [computation, computation_match_info] : computation_summary) { if (computation_match_info.main_matched_computation != nullptr) { cc.AddEdge(computation, computation_match_info.main_matched_computation); + } else { + // main_matched_computation is nullptr means all instructions in the + // computation are unmatched. + unmatched_computations.push_back({computation}); } } std::vector> connected_component_groups = cc.FindConnectedComponents(); + connected_component_groups.insert(connected_component_groups.end(), + unmatched_computations.begin(), + unmatched_computations.end()); for (const auto& component_group : connected_component_groups) { bool all_unchanged = true; diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc index 5b325a5120b..e66d1db8965 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc @@ -37,6 +37,7 @@ namespace { using ::testing::ExplainMatchResult; using ::testing::FieldsAre; +using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Pointee; using ::testing::Property; @@ -375,6 +376,75 @@ TEST_F(HloDiffTest, FindConnectedComponentsWorks) { /*right_unmatched_instruction_count=*/6)))); } +TEST_F(HloDiffTest, FindConnectedComponentsWorksForIsolatedComputations) { + // Create left module with entry computation containing the following + // structure: + // [Param 0] ---> ┌-------┐ + // | add_0 | + // [Param 1] ---> └-------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { +parameter.0 = f32[] parameter(0) +parameter.1 = f32[] parameter(1) +add.0 = f32[] add(parameter.0, parameter.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + // Create left module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | sub_0 | + // [Const 1] ---> └-------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { +constant.0 = f32[] constant(0) +constant.1 = f32[] constant(1) +subtract.0 = f32[] subtract(constant.0, constant.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + // Root nodes are matched by default before the matcher is called. + mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + std::unique_ptr diff_result = + ConstructDiffResult(*graph_l, *graph_r, *mappings); + std::unique_ptr diff_summary = + ConstructDiffSummary(*module_l, *module_r, *diff_result); + EXPECT_THAT(diff_summary->computation_diff_patterns, + UnorderedElementsAre( + FieldsAre( + /*fingerprint=*/619838372110990418U, + /*computation_groups=*/ + UnorderedElementsAre(FieldsAre( + /*left_computations=*/UnorderedElementsAre(Pointee( + Property(&HloComputation::name, "entry"))), + /*right_computations=*/IsEmpty())), + /*diff_metrics=*/ + FieldsAre(/*changed_instruction_count=*/0, + /*left_unmatched_instruction_count=*/3, + /*right_unmatched_instruction_count=*/0)), + FieldsAre( + /*fingerprint=*/591642684880638740U, + /*computation_groups=*/ + UnorderedElementsAre(FieldsAre( + /*left_computations=*/IsEmpty(), + /*right_computations=*/UnorderedElementsAre(Pointee( + Property(&HloComputation::name, "entry"))))), + /*diff_metrics=*/ + FieldsAre(/*changed_instruction_count=*/0, + /*left_unmatched_instruction_count=*/0, + /*right_unmatched_instruction_count=*/3)))); +} + MATCHER(EqualsComputationGroup, "") { const ComputationGroup& a = std::get<0>(arg); const ComputationGroup& b = std::get<1>(arg);