diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index c9ff269e0c9..98e5a026a32 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/debug_options_flags.h" @@ -46,15 +48,9 @@ GTEST_API_ int main(int argc, char** argv) { } pattern = argv[i + 1]; } - // Unfortunately Google's internal benchmark infrastructure has a - // different API than Tensorflow's. + ::benchmark::Initialize(&argc, argv); testing::InitGoogleTest(&argc, argv); -#if defined(PLATFORM_GOOGLE) - absl::SetFlag(&FLAGS_benchmarks, pattern); - RunSpecifiedBenchmarks(); -#else - tensorflow::testing::Benchmark::Run(pattern); -#endif + benchmark::RunSpecifiedBenchmarks(); return 0; } } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bfd6c38496c..86b85a2f50b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -403,6 +403,7 @@ cc_library( ":protos_all_cc", "//tensorflow/core/platform/default/build_config:gtest", "//tensorflow/core/kernels:required", + "@com_google_benchmark//:benchmark", "@com_google_googletest//:gtest", ] + tf_additional_test_deps(), ) diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index b5cf34db2f5..0b86a85dfa7 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -60,9 +59,8 @@ Benchmark::Benchmark(const string& device, Graph* g, options = &default_options; } - old_benchmark_api_ = old_benchmark_api; CHECK(!old_benchmark_api) << "Expected new API only"; - if (old_benchmark_api_) testing::StopTiming(); + old_benchmark_api_ = false; string t = absl::AsciiStrToUpper(device); // Allow NewDevice to allocate a new threadpool with different number of // threads for each new benchmark. @@ -139,8 +137,7 @@ Benchmark::~Benchmark() { } } - -void Benchmark::Run(::testing::benchmark::State& state) { +void Benchmark::Run(benchmark::State& state) { RunWithRendezvousArgs({}, {}, state); } @@ -160,7 +157,7 @@ string GetRendezvousKey(const Node* node) { void Benchmark::RunWithRendezvousArgs( const std::vector>& inputs, - const std::vector& outputs, ::testing::benchmark::State& state) { + const std::vector& outputs, benchmark::State& state) { CHECK(!old_benchmark_api_) << "This method should only be called with new benchmark API"; if (!device_ || state.max_iterations == 0) { diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h index fd0d36499dd..89aec6b2057 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -24,14 +24,9 @@ limitations under the License. #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -namespace testing { -namespace benchmark { -class State; -} // namespace benchmark -} // namespace testing - namespace tensorflow { class Device; @@ -62,11 +57,11 @@ class Benchmark { ~Benchmark(); - void Run(::testing::benchmark::State& state); + void Run(benchmark::State& state); void RunWithRendezvousArgs( const std::vector>& inputs, - const std::vector& outputs, ::testing::benchmark::State& state); + const std::vector& outputs, benchmark::State& state); private: thread::ThreadPool* pool_ = nullptr; // Not owned. diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 90400f85c1d..fc0cd1b2326 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -826,6 +826,7 @@ cc_library( testonly = True, hdrs = ["test_benchmark.h"], deps = [ + "@com_google_benchmark//:benchmark", ":platform", ] + tf_platform_deps("test_benchmark"), ) @@ -1063,7 +1064,6 @@ cc_library( ":stacktrace_handler", ":test", ":test_benchmark", - "//tensorflow/core/platform/default/build_config:test_main", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h index 51731815128..0001ddc7cb2 100644 --- a/tensorflow/core/platform/test_benchmark.h +++ b/tensorflow/core/platform/test_benchmark.h @@ -17,13 +17,38 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ #define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ +#include "benchmark/benchmark.h" // from @com_google_benchmark // IWYU pragma: export #include "tensorflow/core/platform/platform.h" -#if defined(PLATFORM_GOOGLE) -#include "tensorflow/core/platform/google/test_benchmark.h" // IWYU pragma: export -#else -#include "tensorflow/core/platform/default/test_benchmark.h" // IWYU pragma: export -#endif // PLATFORM_GOOGLE +// FIXME(vyng): Remove this. +// Background: During the benchmark-migration projects, all benchmarks were made +// to use "testing::benchmark::" prefix because that is what the internal +// Google benchmark library use. +namespace testing { +namespace benchmark { +using ::benchmark::State; // NOLINT +} // namespace benchmark +} // namespace testing +namespace tensorflow { +namespace testing { + +namespace internal { +void UseCharPointer(char const volatile*); +} + +inline void RunBenchmarks() { benchmark::RunSpecifiedBenchmarks(); } + +template +void DoNotOptimize(const T& var) { +#if defined(_MSC_VER) + internal::UseCharPointer(reinterpret_cast(&var)); + _ReadWriteBarrier(); +#else + asm volatile("" : "+m"(const_cast(var))); +#endif +} +} // namespace testing +} // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc index eb89d5eba79..7171f737fc8 100644 --- a/tensorflow/core/platform/test_main.cc +++ b/tensorflow/core/platform/test_main.cc @@ -18,15 +18,11 @@ limitations under the License. // the --benchmark_filter flag which specifies which benchmarks to run, // we will either run benchmarks or run the gtest tests in the program. -#include "tensorflow/core/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__) -// main() is supplied by gunit_main -#else - #include +#include #include "absl/strings/match.h" +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/stacktrace_handler.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -35,14 +31,23 @@ GTEST_API_ int main(int argc, char** argv) { std::cout << "Running main() from test_main.cc\n"; tensorflow::testing::InstallStacktraceHandler(); - testing::InitGoogleTest(&argc, argv); + for (int i = 1; i < argc; i++) { - if (absl::StartsWith(argv[i], "--benchmarks=")) { - const char* pattern = argv[i] + strlen("--benchmarks="); - tensorflow::testing::Benchmark::Run(pattern); + if (absl::StartsWith(argv[i], "--benchmark_filter=")) { + ::benchmark::Initialize(&argc, argv); + + // XXX: Must be called after benchmark's init because + // InitGoogleTest eventually calls absl::ParseCommandLine() which would + // complain that benchmark_filter flag is not known because that flag is + // defined by the benchmark library via its own command-line flag + // facility, which is not known to absl flags. + // FIXME(vyng): Fix this mess once we make benchmark use absl flags + testing::InitGoogleTest(&argc, argv); + ::benchmark::RunSpecifiedBenchmarks(); return 0; } } + + testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } -#endif diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 419023461ee..b00d5e5aac8 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -19,6 +19,7 @@ load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", "def_fi load("//third_party/FP16:workspace.bzl", FP16 = "repo") load("//third_party/absl:workspace.bzl", absl = "repo") load("//third_party/aws:workspace.bzl", aws = "repo") +load("//third_party/benchmark:workspace.bzl", benchmark = "repo") load("//third_party/clog:workspace.bzl", clog = "repo") load("//third_party/cpuinfo:workspace.bzl", cpuinfo = "repo") load("//third_party/dlpack:workspace.bzl", dlpack = "repo") @@ -55,6 +56,7 @@ def _initialize_third_party(): FP16() absl() aws() + benchmark() clog() cpuinfo() dlpack() diff --git a/third_party/benchmark/BUILD b/third_party/benchmark/BUILD new file mode 100644 index 00000000000..e69de29bb2d diff --git a/third_party/benchmark/BUILD.bazel b/third_party/benchmark/BUILD.bazel new file mode 100644 index 00000000000..88f17fd6e6f --- /dev/null +++ b/third_party/benchmark/BUILD.bazel @@ -0,0 +1,55 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache + +exports_files(["LICENSE"]) + +config_setting( + name = "qnx", + constraint_values = ["@platforms//os:qnx"], + values = { + "cpu": "x64_qnx", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], + values = { + "cpu": "x64_windows", + }, + visibility = [":__subpackages__"], +) + +cc_library( + name = "benchmark", + srcs = glob( + [ + "src/*.cc", + "src/*.h", + ], + exclude = ["src/benchmark_main.cc"], + ), + hdrs = ["include/benchmark/benchmark.h"], + linkopts = select({ + ":windows": ["-DEFAULTLIB:shlwapi.lib"], + "//conditions:default": ["-pthread"], + }), + strip_include_prefix = "include", + visibility = ["//visibility:public"], +) + +cc_library( + name = "benchmark_main", + srcs = ["src/benchmark_main.cc"], + hdrs = ["include/benchmark/benchmark.h"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = [":benchmark"], +) + +cc_library( + name = "benchmark_internal_headers", + hdrs = glob(["src/*.h"]), +) diff --git a/third_party/benchmark/workspace.bzl b/third_party/benchmark/workspace.bzl new file mode 100644 index 00000000000..dde5964b927 --- /dev/null +++ b/third_party/benchmark/workspace.bzl @@ -0,0 +1,18 @@ +"""Provides the repo macro to import google benchmark""" + +load("//third_party:repo.bzl", "tf_http_archive") + +def repo(): + """Imports benchmark.""" + BM_COMMIT = "64cb55e91067860548cb95e012a38f2e5b71e026" + BM_SHA256 = "480bb4f1ffa402e5782a20dc8986f5c86b87c497195dc53c9067e502ff45ef57" + tf_http_archive( + name = "com_google_benchmark", + sha256 = BM_SHA256, + strip_prefix = "benchmark-{commit}".format(commit = BM_COMMIT), + build_file = "//third_party/benchmark:BUILD.bazel", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/benchmark/archive/{commit}.tar.gz".format(commit = BM_COMMIT), + "https://github.com/google/benchmark/archive/{commit}.tar.gz".format(commit = BM_COMMIT), + ], + )