diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 400e1c017b2..78b55237d1e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -31,6 +31,12 @@ namespace tensorflow { namespace grappler { namespace { +static bool IsInvolution(const NodeDef& node) { + const std::unordered_set involution_ops = {"Conj", "Reciprocal", + "Neg", "LogicalNot"}; + return involution_ops.count(node.op()) > 0; +} + bool AreInversePermutations(gtl::ArraySlice a, gtl::ArraySlice b) { if (a.size() != b.size()) { @@ -394,10 +400,20 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const { string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, GraphDef* graph_def, NodeMap* node_map, std::vector* new_nodes) const { - // Remove inverse transposes. - if (node->op() == "Transpose") { + // Remove involutions applied twice. + if (IsInvolution(*node)) { + // An involution is a function f(x) that is its own inverse, + // i.e. f(f(x)) = x. const NodeDef* input = node_map->GetNode(node->input(0)); - if (input->op() == "Transpose") { + if (input->op() == node->op()) { + return input->input(0); + } + } + + // Remove inverse transposes. + if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") { + const NodeDef* input = node_map->GetNode(node->input(0)); + if (input->op() == node->op()) { const NodeDef* node_perm = node_map->GetNode(node->input(1)); const NodeDef* input_perm = node_map->GetNode(input->input(1)); std::vector node_perm_values; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 8edb34975f9..61c8b82ea0f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -109,6 +109,28 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { EXPECT_EQ("add1", new_add3.input(1)); } +TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + Output neg1 = ops::Neg(s.WithOpName("neg1"), c); + Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1); + Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); + Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); + Output id = ops::Identity(s.WithOpName("id"), recip2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(6, output.node_size()); + EXPECT_EQ("c", output.node(1).input(0)); + EXPECT_EQ("c", output.node(3).input(0)); + EXPECT_EQ("c", output.node(5).input(0)); +} + TEST_F(ArithmeticOptimizerTest, IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs =