diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index ed23c8c2dd7..e000a067061 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -360,6 +360,21 @@ class HloDotDumper { string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); + // For most instructions, GetNodeForEdge(instr) returns instr. + // + // The exception is fusion nodes. For these, we walk up the chain of nested + // fusion nodes starting at instr until we reach a node that either (a) isn't + // a fusion node, or (b) is a fusion node for which + // ShouldShowFusionSubcomputation is false. + // + // We do this because fusion nodes are expanded inline -- if + // ShouldShowFusionSubcomputation is true, the fusion node won't be present in + // the graph. + // + // In general when you want to draw an edge from A to B, you should actually + // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B). + const HloInstruction* GetNodeForEdge(const HloInstruction* instr); + // If instr has just one computation and it's trivial (e.g. "return param0 + // param1"), returns a string you can put into the node's body that names the // subcomputation, e.g. "Subcomputation: add". @@ -595,16 +610,15 @@ tooltip = " "; // belongs to a fusion node, it's drawn in place of the fusion instruction, // so there's no need to link those. if (parent_instr->opcode() != HloOpcode::kFusion) { - VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to " - << parent_instr->name() << " as " << next_edge_id_; - edge_ids_.insert( - {{subcomp->root_instruction(), parent_instr}, next_edge_id_++}); + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back( - Printf(edge_fmt, InstructionId(subcomp->root_instruction()), - InstructionId(parent_instr), SubcomputationId(subcomp), - subcomp->name(), parent_instr->name())); + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } string computation = @@ -633,15 +647,7 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { } string HloDotDumper::DumpRootTag() { - HloInstruction* from = computation_->root_instruction(); - - // Fusion nodes are expanded inline, so if root is an expanded fusion node, - // walk up the graph until we find a node that isn't. - while (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } - + const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); auto from_id = InstructionId(from); if (!filter_.Show(from)) { @@ -1080,13 +1086,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { - // Fusion nodes' subcomputations are displayed inline, so if 'from' is a - // fusion node and the node's subcomputation is shown, we draw our edge - // starting at the fusion node's root instead of at the fusion node itself. - if (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } + from = GetNodeForEdge(from); + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || ShouldMergeIntoUsers(from)) { return; @@ -1154,6 +1155,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr( return Join(lines, "
"); } +const HloInstruction* HloDotDumper::GetNodeForEdge( + const HloInstruction* instr) { + while (instr->opcode() == HloOpcode::kFusion && + ShouldShowFusionSubcomputation(instr)) { + instr = instr->fused_expression_root(); + } + return instr; +} + tensorflow::mutex& RendererMutex() { static tensorflow::mutex* mu = new tensorflow::mutex; return *mu;