diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 750e18a9ca0..bf8dcb2fcf2 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -909,8 +909,13 @@ void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) { // No known send_device. The runtime will detect it later. return; } - int64 incarnation = opts.get_incarnation(send_device); - AddNodeAttr("send_device_incarnation", incarnation, ndef); + int64 incarnation = PartitionOptions::kIllegalIncarnation; + if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() || + (incarnation == PartitionOptions::kIllegalIncarnation)) { + incarnation = opts.get_incarnation(send_device); + SetAttrValue(incarnation, + &((*ndef->mutable_attr())["send_device_incarnation"])); + } } // Sets attribute send_device_incarnation of all Send/Recv nodes in diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 9c49b0b67b1..3c12ed2689e 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -445,6 +445,41 @@ TEST_F(GraphPartitionTest, Functions) { ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"}); } +TEST_F(GraphPartitionTest, SetIncarnation) { + GraphDef gdef; + const char* const kSendRecvAttrs = R"proto( + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'client_terminated' value { b: false } } + attr { key: 'recv_device' value { s: 'B' } } + attr { key: 'send_device' value { s: 'A' } } + attr { key: 'send_device_incarnation' value { i: 0 } } + attr { key: 'tensor_name' value { s: 'test' } } +)proto"; + CHECK(protobuf::TextFormat::ParseFromString( + StrCat("node { name: 'A/Pi' op: 'Const' ", + " attr { key: 'dtype' value { type: DT_FLOAT } } ", + " attr { key: 'value' value { tensor { ", + " dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }", + "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}", + "node { name: 'B' op: '_Recv' ", kSendRecvAttrs, + " attr { key: 'tensor_type' value { type:DT_FLOAT}}}"), + &gdef)); + gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION); + Partition(gdef, &partitions_); + EXPECT_EQ(2, partitions_.size()); + + for (const auto& kv : partitions_) { + const GraphDef& gdef = kv.second; + for (const NodeDef& ndef : gdef.node()) { + if (ndef.name() == "A" || ndef.name() == "B") { + int64 val; + TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val)); + EXPECT_EQ(val, 100); // Send device is "A". + } + } + } +} + TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { // Create placeholders, shuffle them so the order in the graph is not strictly // increasing.