#HLODiff Add unmatched computations to summary->computation_diff_patterns

PiperOrigin-RevId: 775770118
This commit is contained in:
A. Unique TensorFlower
2025-06-25 11:27:36 -07:00
committed by TensorFlower Gardener
parent 35a1ffb49c
commit ce60a770af
2 changed files with 78 additions and 0 deletions

View File

@@ -186,15 +186,23 @@ FindConnectedComponents(
absl::flat_hash_map<const HloComputation*, const ComputationSummary>
computation_summary) {
ConnectedComponentsFinder cc;
std::vector<std::vector<const HloComputation*>> unmatched_computations;
absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>> result;
for (const auto& [computation, computation_match_info] :
computation_summary) {
if (computation_match_info.main_matched_computation != nullptr) {
cc.AddEdge(computation, computation_match_info.main_matched_computation);
} else {
// main_matched_computation is nullptr means all instructions in the
// computation are unmatched.
unmatched_computations.push_back({computation});
}
}
std::vector<std::vector<const HloComputation*>> connected_component_groups =
cc.FindConnectedComponents();
connected_component_groups.insert(connected_component_groups.end(),
unmatched_computations.begin(),
unmatched_computations.end());
for (const auto& component_group : connected_component_groups) {
bool all_unchanged = true;

View File

@@ -37,6 +37,7 @@ namespace {
using ::testing::ExplainMatchResult;
using ::testing::FieldsAre;
using ::testing::IsEmpty;
using ::testing::Pair;
using ::testing::Pointee;
using ::testing::Property;
@@ -375,6 +376,75 @@ TEST_F(HloDiffTest, FindConnectedComponentsWorks) {
/*right_unmatched_instruction_count=*/6))));
}
TEST_F(HloDiffTest, FindConnectedComponentsWorksForIsolatedComputations) {
// Create left module with entry computation containing the following
// structure:
// [Param 0] ---> ┌-------┐
// | add_0 |
// [Param 1] ---> └-------┘
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module_l,
ParseAndReturnVerifiedModule(R"(
HloModule module, is_scheduled=true
ENTRY entry {
parameter.0 = f32[] parameter(0)
parameter.1 = f32[] parameter(1)
add.0 = f32[] add(parameter.0, parameter.1)
}
)"));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
HloGumgraph::Create(module_l.get()));
// Create left module with entry computation containing the following
// structure:
// [Const 0] ---> ┌-------┐
// | sub_0 |
// [Const 1] ---> └-------┘
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module_r,
ParseAndReturnVerifiedModule(R"(
HloModule module, is_scheduled=true
ENTRY entry {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
subtract.0 = f32[] subtract(constant.0, constant.1)
}
)"));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
HloGumgraph::Create(module_r.get()));
auto mappings = std::make_unique<HloGumgraphMappings>();
// Root nodes are matched by default before the matcher is called.
mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
MatcherType::kManual);
std::unique_ptr<const DiffResult> diff_result =
ConstructDiffResult(*graph_l, *graph_r, *mappings);
std::unique_ptr<const DiffSummary> diff_summary =
ConstructDiffSummary(*module_l, *module_r, *diff_result);
EXPECT_THAT(diff_summary->computation_diff_patterns,
UnorderedElementsAre(
FieldsAre(
/*fingerprint=*/619838372110990418U,
/*computation_groups=*/
UnorderedElementsAre(FieldsAre(
/*left_computations=*/UnorderedElementsAre(Pointee(
Property(&HloComputation::name, "entry"))),
/*right_computations=*/IsEmpty())),
/*diff_metrics=*/
FieldsAre(/*changed_instruction_count=*/0,
/*left_unmatched_instruction_count=*/3,
/*right_unmatched_instruction_count=*/0)),
FieldsAre(
/*fingerprint=*/591642684880638740U,
/*computation_groups=*/
UnorderedElementsAre(FieldsAre(
/*left_computations=*/IsEmpty(),
/*right_computations=*/UnorderedElementsAre(Pointee(
Property(&HloComputation::name, "entry"))))),
/*diff_metrics=*/
FieldsAre(/*changed_instruction_count=*/0,
/*left_unmatched_instruction_count=*/0,
/*right_unmatched_instruction_count=*/3))));
}
MATCHER(EqualsComputationGroup, "") {
const ComputationGroup& a = std::get<0>(arg);
const ComputationGroup& b = std::get<1>(arg);