mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Automated g4 rollback of changelist 165630063
PiperOrigin-RevId: 165646100
This commit is contained in:
committed by
TensorFlower Gardener
parent
0fab1a5d39
commit
a92bd5d5cb
@@ -14,7 +14,7 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
load(":build_defs.bzl", "runtime_copts", "runtime_logging_deps")
|
||||
load(":build_defs.bzl", "runtime_copts")
|
||||
|
||||
# Filegroup used to collect source files for dependency checking.
|
||||
filegroup(
|
||||
@@ -381,16 +381,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_matvec",
|
||||
srcs = ["runtime_matvec.cc"],
|
||||
hdrs = ["runtime_matvec.h"],
|
||||
copts = runtime_copts(),
|
||||
deps = [
|
||||
"//third_party/eigen3",
|
||||
] + runtime_logging_deps(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_matmul",
|
||||
srcs = ["runtime_matmul.cc"],
|
||||
@@ -398,7 +388,6 @@ cc_library(
|
||||
copts = runtime_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":runtime_matvec",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
@@ -428,7 +417,6 @@ cc_library(
|
||||
copts = runtime_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":runtime_matvec",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
@@ -440,7 +428,6 @@ cc_test(
|
||||
deps = [
|
||||
":cpu_runtime",
|
||||
":runtime_matmul",
|
||||
":runtime_single_threaded_matmul",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
||||
@@ -1,25 +1,11 @@
|
||||
"""build_defs for service/cpu."""
|
||||
|
||||
|
||||
def runtime_copts():
|
||||
"""Returns copts used for CPU runtime libraries."""
|
||||
return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
|
||||
"//tensorflow:android_arm": ["-mfpu=neon"],
|
||||
"//conditions:default": []
|
||||
}) + select({
|
||||
"//tensorflow:android": ["-O2"],
|
||||
"//conditions:default": []
|
||||
}))
|
||||
|
||||
|
||||
def runtime_logging_deps():
|
||||
"""Returns deps for building CPU runtime libraries with logging functions."""
|
||||
return select({
|
||||
"//tensorflow:android": [
|
||||
# This dependency is smaller than :android_tensorflow_lib
|
||||
"//tensorflow/core:android_tensorflow_lib_selective_registration",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
})
|
||||
return (["-DEIGEN_AVOID_STL_ARRAY"] +
|
||||
select({
|
||||
"//tensorflow:android_arm": ["-mfpu=neon"],
|
||||
"//conditions:default": []}) +
|
||||
select({
|
||||
"//tensorflow:android": ["-O2"],
|
||||
"//conditions:default": []}))
|
||||
|
||||
@@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
@@ -26,10 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@@ -78,8 +75,14 @@ void CheckMatrixMultiply(const Array2D<float>& a, const Array2D<float>& b,
|
||||
std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
|
||||
const Array2D<float>& b,
|
||||
bool transpose_lhs,
|
||||
bool transpose_rhs,
|
||||
bool single_threaded) {
|
||||
bool transpose_rhs) {
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
2);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_intra_op_thread_pool(&device);
|
||||
|
||||
CHECK_EQ(a.width(), b.height());
|
||||
int64 m = a.height();
|
||||
int64 n = b.width();
|
||||
@@ -95,81 +98,41 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
|
||||
// Since we're going to transpose c before returning it, swap the order of the
|
||||
// dimension sizes to ensure the returned array is properly dimensioned.
|
||||
auto c_transpose = MakeUnique<Array2D<float>>(n, m);
|
||||
if (single_threaded) {
|
||||
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(
|
||||
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
|
||||
m, n, k, transpose_lhs, transpose_rhs);
|
||||
} else {
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
2);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_intra_op_thread_pool(&device);
|
||||
|
||||
__xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
|
||||
a_transpose->data(), b_transpose->data(),
|
||||
m, n, k, transpose_lhs, transpose_rhs);
|
||||
}
|
||||
__xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
|
||||
a_transpose->data(), b_transpose->data(), m,
|
||||
n, k, transpose_lhs, transpose_rhs);
|
||||
return MaybeTransposeArray2D(*c_transpose, true);
|
||||
}
|
||||
|
||||
struct MatMulShape {
|
||||
int64 m;
|
||||
int64 k;
|
||||
int64 n;
|
||||
};
|
||||
TEST_F(CpuRuntimeTest, SmallEigenMatmul) {
|
||||
Array2D<float> a({{1.0f, 2.0f}, {3.0f, 4.0f}});
|
||||
Array2D<float> b({{5.0f, -1.0f, 3.0f}, {2.0f, 6.0f, 4.0f}});
|
||||
|
||||
MatMulShape MatMulShapes[] = {
|
||||
MatMulShape{2, 2, 3}, MatMulShape{256, 512, 1024},
|
||||
MatMulShape{128, 128, 1}, MatMulShape{1, 128, 128},
|
||||
MatMulShape{1, 32, 128}, MatMulShape{1, 32, 16},
|
||||
MatMulShape{32, 16, 1}, MatMulShape{32, 128, 1},
|
||||
};
|
||||
for (bool transpose_lhs : {false, true}) {
|
||||
for (bool transpose_rhs : {false, true}) {
|
||||
auto c = EigenMatrixMultiply(a, b, transpose_lhs, transpose_rhs);
|
||||
|
||||
// This takes 4 parameters:
|
||||
// * shape of the matmul
|
||||
// * transpose_lhs
|
||||
// * transpose_rhs
|
||||
// * single_threaded
|
||||
using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
|
||||
LOG(INFO) << "a = " << a.ToString();
|
||||
LOG(INFO) << "b = " << b.ToString();
|
||||
LOG(INFO) << "c = " << c->ToString();
|
||||
|
||||
class EigenMatMulTest
|
||||
: public CpuRuntimeTest,
|
||||
public ::testing::WithParamInterface<EigenMatMulTestParam> {
|
||||
public:
|
||||
static string Name(
|
||||
const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
|
||||
MatMulShape shape = std::get<0>(info.param);
|
||||
bool transpose_lhs = std::get<1>(info.param);
|
||||
bool transpose_rhs = std::get<2>(info.param);
|
||||
bool single_threaded = std::get<3>(info.param);
|
||||
|
||||
return tensorflow::strings::Printf(
|
||||
"MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
|
||||
transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
|
||||
single_threaded ? "single" : "multi");
|
||||
CheckMatrixMultiply(a, b, *c);
|
||||
}
|
||||
}
|
||||
}; // namespace xla
|
||||
|
||||
TEST_P(EigenMatMulTest, DoIt) {
|
||||
MatMulShape shape = std::get<0>(GetParam());
|
||||
bool transpose_lhs = std::get<1>(GetParam());
|
||||
bool transpose_rhs = std::get<2>(GetParam());
|
||||
bool single_threaded = std::get<3>(GetParam());
|
||||
|
||||
auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
|
||||
auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
|
||||
auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs,
|
||||
single_threaded);
|
||||
CheckMatrixMultiply(*a, *b, *c);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
|
||||
::testing::Combine(::testing::ValuesIn(MatMulShapes),
|
||||
::testing::Bool(), ::testing::Bool(),
|
||||
::testing::Bool()),
|
||||
EigenMatMulTest::Name);
|
||||
TEST_F(CpuRuntimeTest, LargeEigenMatmul) {
|
||||
auto a = MakeLinspaceArray2D(0.0, 1.0, 256, 512);
|
||||
auto b = MakeLinspaceArray2D(-2.0, 2.0, 512, 1024);
|
||||
|
||||
for (bool transpose_lhs : {false, true}) {
|
||||
for (bool transpose_rhs : {false, true}) {
|
||||
auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs);
|
||||
|
||||
CheckMatrixMultiply(*a, *b, *c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
@@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::int32;
|
||||
@@ -69,24 +68,14 @@ void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
|
||||
float* lhs, float* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs,
|
||||
int32 transpose_rhs) {
|
||||
if (m == 1 || n == 1) {
|
||||
// Despite being single threaded, this version of matrix * vector is faster.
|
||||
xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
} else {
|
||||
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
|
||||
void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
|
||||
double* lhs, double* rhs, int64 m,
|
||||
int64 n, int64 k, int32 transpose_lhs,
|
||||
int32 transpose_rhs) {
|
||||
if (m == 1 || n == 1) {
|
||||
// Despite being single threaded, this version of matrix * vector is faster.
|
||||
xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
} else {
|
||||
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using tensorflow::int32;
|
||||
using tensorflow::int64;
|
||||
|
||||
namespace {
|
||||
|
||||
// Does mat * x or mat^T * x.
//
// 'x_buf' is read as a vector of 'cols' elements and the product is written
// to 'out_buf' as a vector of 'rows' elements. When 'transpose' is zero,
// 'mat_buf' is mapped as a rows x cols matrix and multiplied directly; when it
// is nonzero, 'mat_buf' is mapped as a cols x rows matrix and its transpose is
// multiplied instead, so the result has 'rows' elements either way.
|
||||
template <typename T>
|
||||
void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols,
|
||||
int32 transpose) {
|
||||
// Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to
|
||||
// be faster (b/30223679). See also: the matmul op kernel in TensorFlow,
|
||||
// which implements the same optimization.
|
||||
using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
|
||||
using MatrixMap = Eigen::Map<Matrix>;
|
||||
|
||||
using Vector = Eigen::Matrix<T, Eigen::Dynamic, 1>;
|
||||
using VectorMap = Eigen::Map<Vector>;
|
||||
|
||||
// View the caller's buffers as Eigen vectors: input of length 'cols', output
// of length 'rows'.
auto x = VectorMap(x_buf, cols);
|
||||
auto out = VectorMap(out_buf, rows);
|
||||
|
||||
int64 mat_rows = rows;
|
||||
int64 mat_cols = cols;
|
||||
|
||||
// With 'transpose' set, the buffer holds the matrix whose transpose is
// applied, so the dimensions used to map it are swapped.
if (transpose) {
|
||||
std::swap(mat_rows, mat_cols);
|
||||
}
|
||||
|
||||
auto mat = MatrixMap(mat_buf, mat_rows, mat_cols);
|
||||
|
||||
// The product is assigned through the 'out' map, i.e. into 'out_buf'.
if (transpose) {
|
||||
out = mat.transpose() * x;
|
||||
} else {
|
||||
out = mat * x;
|
||||
}
|
||||
}
|
||||
|
||||
// Converts matmul-style args to matvec.
//
// Chooses which of 'lhs'/'rhs' is the matrix and which is the vector based on
// whether n == 1, then forwards to MatVec<T> with 'k' as the vector length and
// 'out' as the destination buffer.
|
||||
template <typename T>
|
||||
void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k,
|
||||
int32 transpose_lhs, int32 transpose_rhs) {
|
||||
// If the input is in the form x * A, where x is the vector, then bring A back
|
||||
// over to the left hand side. We make use of the identity
|
||||
//
|
||||
// (x * A)^T = A^T * x^T
|
||||
//
|
||||
// We do not need to take the transpose of x or of the result since taking
|
||||
// the transpose of a vector does not change the memory layout.
|
||||
const int64 cols = k;
|
||||
|
||||
T* mat;
|
||||
T* vec;
|
||||
int64 rows;
|
||||
bool transpose_mat;
|
||||
|
||||
// n == 1 selects the A * x form (rhs is the vector). Any other n is handled
// as the x * A form (lhs is the vector); the EigenMatVecF32/F64 wrappers
// DCHECK that m == 1 || n == 1 before calling here. If both m and n are 1,
// the n == 1 branch is taken.
bool is_mat_vec = (n == 1);
|
||||
|
||||
if (is_mat_vec) {
|
||||
mat = lhs;
|
||||
vec = rhs;
|
||||
rows = m;
|
||||
transpose_mat = transpose_lhs;
|
||||
} else {
|
||||
mat = rhs;
|
||||
vec = lhs;
|
||||
rows = n;
|
||||
// Per the identity above, A moves to the left as A^T, so the requested
// transpose flag for rhs is inverted.
transpose_mat = !transpose_rhs;
|
||||
}
|
||||
|
||||
MatVec<T>(out, mat, vec, rows, cols, transpose_mat);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace xla {
|
||||
|
||||
// float entry point for the matrix-vector product. DCHECKs that m == 1 or
// n == 1, then forwards all arguments unchanged to DispatchMatVec<float>.
void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
|
||||
DispatchMatVec<float>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
// double entry point for the matrix-vector product. DCHECKs that m == 1 or
// n == 1, then forwards all arguments unchanged to DispatchMatVec<double>.
void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n,
|
||||
int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
|
||||
DispatchMatVec<double>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
@@ -1,45 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are
|
||||
// pointers to buffers containing input matrices in column-major order. 'out' is
|
||||
// a pointer to a buffer sufficiently large to hold the result of the
|
||||
// operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and
|
||||
// out is m x n.
|
||||
//
|
||||
// This requires that m = 1 or n = 1.
|
||||
//
|
||||
// TODO(b/64684907): Compare runtime performance of these functions with dot
|
||||
// simplification.
|
||||
void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m,
|
||||
tensorflow::int64 n, tensorflow::int64 k,
|
||||
tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m,
|
||||
tensorflow::int64 n, tensorflow::int64 k,
|
||||
tensorflow::int32 transpose_lhs,
|
||||
tensorflow::int32 transpose_rhs);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
|
||||
@@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::int32;
|
||||
@@ -62,21 +61,13 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
|
||||
void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
|
||||
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
|
||||
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
if (m == 1 || n == 1) {
|
||||
xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
} else {
|
||||
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
|
||||
void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
|
||||
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
|
||||
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
|
||||
if (m == 1 || n == 1) {
|
||||
xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
|
||||
} else {
|
||||
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
|
||||
transpose_rhs);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user