diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index ed23c8c2dd7..e000a067061 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -360,6 +360,21 @@ class HloDotDumper {
string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr);
+ // For most instructions, GetNodeForEdge(instr) returns instr.
+ //
+ // The exception is fusion nodes. For these, we repeatedly step from the
+ // fusion node to its fused_expression_root(), starting at instr, until we
+ // reach a node that either (a) isn't a fusion node, or (b) is a fusion node
+ // for which ShouldShowFusionSubcomputation is false.
+ //
+ // We do this because fusion nodes are expanded inline -- if
+ // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
+ // the graph.
+ //
+ // In general when you want to draw an edge from A to B, you should actually
+ // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
+ const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
+
// If instr has just one computation and it's trivial (e.g. "return param0 +
// param1"), returns a string you can put into the node's body that names the
// subcomputation, e.g. "Subcomputation: add".
@@ -595,16 +610,15 @@ tooltip = " ";
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
- VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to "
- << parent_instr->name() << " as " << next_edge_id_;
- edge_ids_.insert(
- {{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
+ const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
+ VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
+ << " as " << next_edge_id_;
+ edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
const char* edge_fmt =
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
- edges_.push_back(
- Printf(edge_fmt, InstructionId(subcomp->root_instruction()),
- InstructionId(parent_instr), SubcomputationId(subcomp),
- subcomp->name(), parent_instr->name()));
+ edges_.push_back(Printf(
+ edge_fmt, InstructionId(from), InstructionId(parent_instr),
+ SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
}
string computation =
@@ -633,15 +647,7 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
}
string HloDotDumper::DumpRootTag() {
- HloInstruction* from = computation_->root_instruction();
-
- // Fusion nodes are expanded inline, so if root is an expanded fusion node,
- // walk up the graph until we find a node that isn't.
- while (from->opcode() == HloOpcode::kFusion &&
- ShouldShowFusionSubcomputation(from)) {
- from = from->fused_expression_root();
- }
-
+ const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
auto from_id = InstructionId(from);
if (!filter_.Show(from)) {
@@ -1080,13 +1086,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
- // Fusion nodes' subcomputations are displayed inline, so if 'from' is a
- // fusion node and the node's subcomputation is shown, we draw our edge
- // starting at the fusion node's root instead of at the fusion node itself.
- if (from->opcode() == HloOpcode::kFusion &&
- ShouldShowFusionSubcomputation(from)) {
- from = from->fused_expression_root();
- }
+ from = GetNodeForEdge(from);
+
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
ShouldMergeIntoUsers(from)) {
return;
@@ -1154,6 +1155,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
return Join(lines, "
");
}
+const HloInstruction* HloDotDumper::GetNodeForEdge(
+ const HloInstruction* instr) {
+ while (instr->opcode() == HloOpcode::kFusion &&
+ ShouldShowFusionSubcomputation(instr)) {
+ instr = instr->fused_expression_root();  // descend into the shown fusion
+ }
+ return instr;  // non-fusion node, or fusion whose subcomputation isn't shown
+}
+
tensorflow::mutex& RendererMutex() {
static tensorflow::mutex* mu = new tensorflow::mutex;
return *mu;