Automated g4 rollback of changelist 165630063

PiperOrigin-RevId: 165646100
This commit is contained in:
A. Unique TensorFlower
2017-08-17 16:31:35 -07:00
committed by TensorFlower Gardener
parent 0fab1a5d39
commit a92bd5d5cb
7 changed files with 50 additions and 289 deletions

View File

@@ -14,7 +14,7 @@ package_group(
],
)
load(":build_defs.bzl", "runtime_copts", "runtime_logging_deps")
load(":build_defs.bzl", "runtime_copts")
# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -381,16 +381,6 @@ cc_library(
],
)
cc_library(
name = "runtime_matvec",
srcs = ["runtime_matvec.cc"],
hdrs = ["runtime_matvec.h"],
copts = runtime_copts(),
deps = [
"//third_party/eigen3",
] + runtime_logging_deps(),
)
cc_library(
name = "runtime_matmul",
srcs = ["runtime_matmul.cc"],
@@ -398,7 +388,6 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
":runtime_matvec",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
@@ -428,7 +417,6 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
":runtime_matvec",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
@@ -440,7 +428,6 @@ cc_test(
deps = [
":cpu_runtime",
":runtime_matmul",
":runtime_single_threaded_matmul",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",

View File

@@ -1,25 +1,11 @@
"""build_defs for service/cpu."""
def runtime_copts():
"""Returns copts used for CPU runtime libraries."""
return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
"//tensorflow:android_arm": ["-mfpu=neon"],
"//conditions:default": []
}) + select({
"//tensorflow:android": ["-O2"],
"//conditions:default": []
}))
def runtime_logging_deps():
"""Returns deps for building CPU runtime libraries with logging functions."""
return select({
"//tensorflow:android": [
# This dependency is smaller than :android_tensorflow_lib
"//tensorflow/core:android_tensorflow_lib_selective_registration",
],
"//conditions:default": [
"//tensorflow/core:lib",
],
})
return (["-DEIGEN_AVOID_STL_ARRAY"] +
select({
"//tensorflow:android_arm": ["-mfpu=neon"],
"//conditions:default": []}) +
select({
"//tensorflow:android": ["-O2"],
"//conditions:default": []}))

View File

@@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include <string>
#include <tuple>
#define EIGEN_USE_THREADS
@@ -26,10 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -78,8 +75,14 @@ void CheckMatrixMultiply(const Array2D<float>& a, const Array2D<float>& b,
std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
const Array2D<float>& b,
bool transpose_lhs,
bool transpose_rhs,
bool single_threaded) {
bool transpose_rhs) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
2);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
ExecutableRunOptions run_options;
run_options.set_intra_op_thread_pool(&device);
CHECK_EQ(a.width(), b.height());
int64 m = a.height();
int64 n = b.width();
@@ -95,81 +98,41 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it. Swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
auto c_transpose = MakeUnique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
m, n, k, transpose_lhs, transpose_rhs);
} else {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
2);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
ExecutableRunOptions run_options;
run_options.set_intra_op_thread_pool(&device);
__xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
a_transpose->data(), b_transpose->data(),
m, n, k, transpose_lhs, transpose_rhs);
}
__xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
a_transpose->data(), b_transpose->data(), m,
n, k, transpose_lhs, transpose_rhs);
return MaybeTransposeArray2D(*c_transpose, true);
}
struct MatMulShape {
int64 m;
int64 k;
int64 n;
};
TEST_F(CpuRuntimeTest, SmallEigenMatmul) {
Array2D<float> a({{1.0f, 2.0f}, {3.0f, 4.0f}});
Array2D<float> b({{5.0f, -1.0f, 3.0f}, {2.0f, 6.0f, 4.0f}});
MatMulShape MatMulShapes[] = {
MatMulShape{2, 2, 3}, MatMulShape{256, 512, 1024},
MatMulShape{128, 128, 1}, MatMulShape{1, 128, 128},
MatMulShape{1, 32, 128}, MatMulShape{1, 32, 16},
MatMulShape{32, 16, 1}, MatMulShape{32, 128, 1},
};
for (bool transpose_lhs : {false, true}) {
for (bool transpose_rhs : {false, true}) {
auto c = EigenMatrixMultiply(a, b, transpose_lhs, transpose_rhs);
// This takes 4 parameters:
// * shape of the matmul
// * transpose_lhs
// * transpose_rhs
// * single_threaded
using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
LOG(INFO) << "a = " << a.ToString();
LOG(INFO) << "b = " << b.ToString();
LOG(INFO) << "c = " << c->ToString();
class EigenMatMulTest
: public CpuRuntimeTest,
public ::testing::WithParamInterface<EigenMatMulTestParam> {
public:
static string Name(
const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
MatMulShape shape = std::get<0>(info.param);
bool transpose_lhs = std::get<1>(info.param);
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
return tensorflow::strings::Printf(
"MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
single_threaded ? "single" : "multi");
CheckMatrixMultiply(a, b, *c);
}
}
}; // namespace xla
TEST_P(EigenMatMulTest, DoIt) {
MatMulShape shape = std::get<0>(GetParam());
bool transpose_lhs = std::get<1>(GetParam());
bool transpose_rhs = std::get<2>(GetParam());
bool single_threaded = std::get<3>(GetParam());
auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs,
single_threaded);
CheckMatrixMultiply(*a, *b, *c);
}
INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
::testing::Combine(::testing::ValuesIn(MatMulShapes),
::testing::Bool(), ::testing::Bool(),
::testing::Bool()),
EigenMatMulTest::Name);
TEST_F(CpuRuntimeTest, LargeEigenMatmul) {
auto a = MakeLinspaceArray2D(0.0, 1.0, 256, 512);
auto b = MakeLinspaceArray2D(-2.0, 2.0, 512, 1024);
for (bool transpose_lhs : {false, true}) {
for (bool transpose_rhs : {false, true}) {
auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs);
CheckMatrixMultiply(*a, *b, *c);
}
}
}
} // namespace
} // namespace xla

View File

@@ -19,7 +19,6 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -69,24 +68,14 @@ void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
float* lhs, float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
if (m == 1 || n == 1) {
// Despite being single threaded, this version of matrix * vector is faster.
xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
} else {
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
if (m == 1 || n == 1) {
// Despite being single threaded, this version of matrix * vector is faster.
xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
} else {
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}

View File

@@ -1,110 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::int32;
using tensorflow::int64;
namespace {
// Does mat * x or mat^T * x.
template <typename T>
void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols,
int32 transpose) {
// Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to
// be faster (b/30223679). See also: the matmul op kernel in TensorFlow,
// which implements the same optimization.
using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
using MatrixMap = Eigen::Map<Matrix>;
using Vector = Eigen::Matrix<T, Eigen::Dynamic, 1>;
using VectorMap = Eigen::Map<Vector>;
auto x = VectorMap(x_buf, cols);
auto out = VectorMap(out_buf, rows);
int64 mat_rows = rows;
int64 mat_cols = cols;
if (transpose) {
std::swap(mat_rows, mat_cols);
}
auto mat = MatrixMap(mat_buf, mat_rows, mat_cols);
if (transpose) {
out = mat.transpose() * x;
} else {
out = mat * x;
}
}
// Converts matmul-style args to matvec.
template <typename T>
void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k,
int32 transpose_lhs, int32 transpose_rhs) {
// If the input is in the form x * A, where x is the vector, then bring A back
// over to the left hand side. We make use of the identity
//
// (x * A)^T = A^T * x^T
//
// We do not need to take the transpose of x or of the result since taking
// the transpose of a vector does not change the memory layout.
const int64 cols = k;
T* mat;
T* vec;
int64 rows;
bool transpose_mat;
bool is_mat_vec = (n == 1);
if (is_mat_vec) {
mat = lhs;
vec = rhs;
rows = m;
transpose_mat = transpose_lhs;
} else {
mat = rhs;
vec = lhs;
rows = n;
transpose_mat = !transpose_rhs;
}
MatVec<T>(out, mat, vec, rows, cols, transpose_mat);
}
} // namespace
namespace xla {
void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
DispatchMatVec<float>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
}
void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
DispatchMatVec<double>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
}
} // namespace xla

View File

@@ -1,45 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
#include "tensorflow/core/platform/types.h"
namespace xla {
// Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are
// pointers to buffers containing input matrices in column-major order. 'out' is
// a pointer to a buffer sufficiently large to hold the result of the
// operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and
// out is m x n.
//
// This requires that m = 1 or n = 1.
//
// TODO(b/64684907): Compare runtime performance of these functions with dot
// simplification.
void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m,
tensorflow::int64 n, tensorflow::int64 k,
tensorflow::int32 transpose_lhs,
tensorflow::int32 transpose_rhs);
void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m,
tensorflow::int64 n, tensorflow::int64 k,
tensorflow::int32 transpose_lhs,
tensorflow::int32 transpose_rhs);
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_

View File

@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -62,21 +61,13 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
if (m == 1 || n == 1) {
xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
} else {
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
if (m == 1 || n == 1) {
xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
} else {
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}