Address redundant H2D/D2D/D2H transfers for int32 eager operations.

PiperOrigin-RevId: 452869929

Committed by: TensorFlower Gardener
Parent: 989ea1677c
Commit: 0416617eca
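The pattern being optimized is easiest to see from eager Python. Before this change, int32 eager outputs were pinned to host memory (except on TPU), so a chain of int32 ops under a GPU scope paid a host round trip per step. A minimal sketch of the affected pattern, assuming a visible GPU (this mirrors the example cited in the diff's own comments):

    import tensorflow as tf

    # Chain of int32 eager ops under an explicit GPU scope. Prior to this
    # change, each int32 result was treated as host memory, forcing redundant
    # H2D/D2H transfers between consecutive ops.
    with tf.device("/GPU:0"):
      x = tf.random.uniform((2, 2), maxval=5, dtype=tf.int32)
      y = tf.random.uniform((2, 2), maxval=5, dtype=tf.int32)
      z = tf.bitwise.bitwise_and(x, y)  # can now consume x and y in place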
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/eager/execute.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <vector>
 
@@ -26,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -290,7 +292,32 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
   return {x, tensorflow::FingerprintCat64(a.high64, x)};
 }
 
+const KernelDef* GetKernelDef(const EagerOperation& op, const NodeDef* node_def,
+                              const Device* op_device) {
+  if (node_def == nullptr || op_device == nullptr) return nullptr;
+  const KernelDef* kernel_def = nullptr;
+  Status s = FindKernelDef(DeviceType(op_device->device_type()), *node_def,
+                           &kernel_def,
+                           /*kernel_class_name=*/nullptr);
+  if (!s.ok()) return nullptr;
+  return kernel_def;
+}
+
+bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def,
+                     const Device* op_device, const KernelDef* kernel_def,
+                     const int port_id) {
+  if (op.is_function()) return false;
+  if (node_def == nullptr) return false;
+  if (kernel_def == nullptr || op_device == nullptr) return false;
+  const auto& host_memory_args = kernel_def->host_memory_arg();
+  const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def;
+  const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id);
+  return std::find(host_memory_args.begin(), host_memory_args.end(),
+                   op_def.input_arg(arg_id).name()) != host_memory_args.end();
+}
+
 Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx,
+                         const bool is_host_memory_arg,
                          TensorHandle* tensor_handle, Device** result) {
   Device* cpu_device = ctx.HostCPU();
   string device_name;
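The new `IsHostMemoryArg` consults the kernel's `host_memory_arg` list so that inputs a device kernel reads from host memory keep their host placement. A familiar public-API instance of such an argument, as a hedged illustration (TensorFlow registers the GPU `Reshape` kernel with its `shape` input in `HostMemory`):

    import tensorflow as tf

    with tf.device("/GPU:0"):
      x = tf.random.uniform((4, 4))
      new_shape = tf.constant([2, 8], dtype=tf.int32)
      # Even on GPU, the `shape` input of reshape is consumed from host
      # memory, so pinning it on-device would only add a D2H copy later.
      y = tf.reshape(x, new_shape)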
@@ -318,16 +345,19 @@ Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx,
   Device* device = tensor_handle->device();
   const bool is_tpu = device != nullptr && device->device_type() == "TPU";
   // int32 return values can be placed on TPUs.
+  // int32 return values can be placed on device for eager operations.
   const bool use_host_memory =
-      is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
-             : MTypeFromDType(tensor_handle->dtype);
+      is_tpu || (!op.is_function() && device != cpu_device &&
+                 !is_host_memory_arg)
+          ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
+          : MTypeFromDType(tensor_handle->dtype);
   if (use_host_memory) {
     *result = cpu_device;
   } else {
     // Eager ops executing as functions should have their preferred inputs set
     // to the op's device. This allows us to avoid expensive D2H copies if a
     // mirror of the tensor already exists on the op's device.
-    if (!op.is_function() && device != nullptr && device != cpu_device) {
+    if (!op.is_function() && device != cpu_device && !is_host_memory_arg) {
       device = absl::get<Device*>(op.Device());
     }
     *result = (device == nullptr ? cpu_device : device);
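With the widened `use_host_memory` condition, a non-function eager op whose int32 input already lives on an accelerator (and is not a host-memory arg) keeps that input on the device. A quick check from Python of the expected behavior after this change, assuming a GPU is present:

    import tensorflow as tf

    with tf.device("/GPU:0"):
      x = tf.random.uniform((2, 2), maxval=5, dtype=tf.int32)
    # The int32 result should now report a GPU placement rather than being
    # copied back to the host CPU.
    print(x.device)  # e.g. ".../device:GPU:0"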
@@ -840,17 +870,35 @@ Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) {
   return AddMixedTypeListAttrs(*wrapped_op, op_attrs, opdef);
 }
 
-bool IntArgsAndRetvalsOnDevice(EagerOperation* op) {
-  // Most TF ops expect and generate int32 tensors on the host (or a TPU/XLA
-  // device). This is not the case with IteratorGetNext since it is possible to
-  // build int32 datasets that produce outputs on device when using
-  // prefetch_to_device.
-  // When running call ops, by default we assume that the int32 outputs are on a
-  // host (except for the XLA/TPU case). So we need to special case
-  // IteratorGetNext such that its eager behavior matches the wrapped one.
-  // TODO(b/208435025): Remove this if we end up deciding that int32 outputs
-  // from IteratorGetNext should indeed live on host.
-  return op->Name() == "IteratorGetNext";
+// Necessary condition to place int args/retvals on device but not sufficient.
+// For eager operations return values can be placed on the device for use
+// by subsequent eager ops. E.g.
+// with tf.device("/GPU:0"):
+//   x = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
+//   y = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
+//   z = tf.bitwise.bitwise_and(x, y)
+// In the above example `z` can use the outputs of `x` and `y` without needing
+// an H2D copy if x and y are left on-device.
+bool IntArgsAndRetvalsOnDevice(EagerOperation* op,
+                               const KernelDef* kernel_def) {
+  // We choose to leave `EagerConsts`
+  // on HOST to avoid `shape` and other arguments that are traditionally pinned
+  // to HostMemory from being placed on-device and then being copied to host via
+  // an expensive D2H transfer.
+  if (op->Name() == "_EagerConst") return false;
+
+  // Check if any of the Op's output_arg(s) are pinned to Host.
+  if (kernel_def == nullptr) return false;
+  const OpDef& op_def = OpRegistry::Global()->LookUp(op->Name())->op_def;
+  for (const string& host_memory_arg : kernel_def->host_memory_arg()) {
+    for (const auto& output_arg : op_def.output_arg()) {
+      if (output_arg.name() == host_memory_arg) {
+        return false;
+      }
+    }
+  }
+
+  return true;
 }
 
 StatusOr<Fprint128> GetKernelCacheKey(
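The `_EagerConst` carve-out above interacts with host-pinned inputs such as the `shape` of `tf.random.normal`: a small int32 constant lowered through `_EagerConst` stays on host, so shape arithmetic feeding a host-memory input avoids a D2H transfer. A sketch of that pattern (the same one the new `_benchmark_random_normal` below measures):

    import tensorflow as tf

    with tf.device("/GPU:0"):
      mat = tf.constant([3], dtype=tf.int32)  # materialized via _EagerConst
      s = mat + mat                           # int32 shape arithmetic
      sample = tf.random.normal(shape=s)      # `shape` is a host-memory input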
@@ -941,6 +989,7 @@ Status GetOrCreateKernelAndDevice(
     core::RefCountPtr<KernelAndDevice>* out_kernel) {
   EagerContext& ctx = op->EagerContext();
   Device* device = absl::get<Device*>(op->Device());
+  const KernelDef* kernel_def = nullptr;
 
   // Set the EagerOperation's device prior to extracting the input_dev_ptrs to
   // avoid any redundant H2D/D2H copies.
@@ -974,11 +1023,21 @@ Status GetOrCreateKernelAndDevice(
     input_dev_ptrs.reserve(op->Inputs().size());
     const absl::InlinedVector<TensorHandle*, 4>* inputs;
     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
+    Device* op_device = nullptr;
+    const NodeDef* node_def = nullptr;
+    if (!op->is_function()) {
+      op_device = absl::get<Device*>(op->Device());
+      node_def = &op->MutableAttrs()->BuildNodeDef();
+      kernel_def = GetKernelDef(*op, node_def, op_device);
+    }
     for (int i = 0, end = inputs->size(); i < end; ++i) {
       TensorHandle* input = (*inputs)[i];
 
       Device* input_device;
-      TF_RETURN_IF_ERROR(GetDeviceForInput(*op, ctx, input, &input_device));
+      bool is_host_memory_arg =
+          IsHostMemoryArg(*op, node_def, op_device, kernel_def, i);
+      TF_RETURN_IF_ERROR(GetDeviceForInput(*op, ctx, is_host_memory_arg, input,
+                                           &input_device));
       VLOG(1) << op->Name() << ":input:" << i << " " << input_device->name();
       input_dev_ptrs.push_back(input_device);
       CompositeDevice* composite_device = nullptr;
@@ -1093,8 +1152,13 @@ Status GetOrCreateKernelAndDevice(
       allow_small_function_optimizations = true;
       allow_control_flow_sync_execution = true;
       shape_inference_on_tfe_dialect_import = false;
-      int_args_and_retvals_on_device = IntArgsAndRetvalsOnDevice(op);
+      int_args_and_retvals_on_device =
+          IntArgsAndRetvalsOnDevice(op, kernel_def);
       op = wrapped_op;
+      if (int_args_and_retvals_on_device) {
+        op->MutableAttrs()->Set(FunctionLibraryDefinition::kIntsOnDeviceAttr,
+                                true);
+      }
     }
     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
 
@@ -1101,6 +1101,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   opts.allow_small_function_optimizations = data->enable_sync_execution;
   opts.allow_control_flow_sync_execution =
       options.allow_control_flow_sync_execution;
+  AttrValue ints_on_device_attr;
+  ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
+  shard.mutable_attr()->insert(
+      {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
   auto attrs = AttrSlice(&shard.attr());
   VLOG(1) << "Start instantiating component function " << unique_name
           << " on device " << target;
@@ -751,9 +751,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
   const OpDef& sig = fdef.signature();
   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
 
+  const AttrValue* attr_values_ints_on_device =
+      attr_values.Find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
   bool ints_on_device =
-      fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
-      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
+      (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
+       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) ||
+      (attr_values_ints_on_device != nullptr &&
+       attr_values_ints_on_device->b());
 
   FunctionInstantiationHelper helper(get_function, result);
   Status s;
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import critical_section_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -54,8 +55,17 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
   def __init__(self):
     super().__init__()
     self._m_2_by_2 = random_ops.random_uniform((2, 2))
+    self._m_2_by_2_int32 = random_ops.random_uniform((2, 2),
+                                                     maxval=5,
+                                                     dtype=dtypes.int32)
     self._m_100_by_100 = random_ops.random_uniform((100, 100))
+    self._m_100_by_100_int32 = random_ops.random_uniform((100, 100),
+                                                         maxval=5,
+                                                         dtype=dtypes.int32)
     self._m_1000_by_1000 = random_ops.random_uniform((1000, 1000))
+    self._m_1000_by_1000_int32 = random_ops.random_uniform((1000, 1000),
+                                                           maxval=5,
+                                                           dtype=dtypes.int32)
 
   def _get_benchmark_name(self):
     """Copied from benchmarks_test.py."""
@@ -94,24 +104,67 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
       func = lambda: math_ops.matmul(mat, mat)
       self._run(func, num_iters=5000)
 
+  def _benchmark_bitwise_and(self, mat, device):
+    if device == GPU and not context.num_gpus():
+      return
+    with context.device(device):
+      if device == GPU:
+        mat = mat.gpu()
+      func = lambda: bitwise_ops.bitwise_and(mat, mat)
+      self._run(func, num_iters=5000)
+
+  def _benchmark_random_normal(self, device):
+    if device == GPU and not context.num_gpus():
+      return
+    with context.device(device):
+
+      def func():
+        mat = constant_op.constant([3], dtypes.int32)
+        s = mat + mat
+        random_ops.random_normal(shape=s)
+
+      self._run(func, num_iters=5000)
+
+  # This only needs to be tested on GPU where redundant data transfers can occur
+  def benchmark_random_normal_GPU(self):
+    self._benchmark_random_normal(GPU)
+
   def benchmark_tf_matmul_2_by_2_CPU(self):
     self._benchmark_matmul(self._m_2_by_2, CPU)
 
+  def benchmark_tf_bitwise_and_2_by_2_CPU(self):
+    self._benchmark_bitwise_and(self._m_2_by_2_int32, CPU)
+
   def benchmark_tf_matmul_2_by_2_GPU(self):
     self._benchmark_matmul(self._m_2_by_2, GPU)
 
+  def benchmark_tf_bitwise_and_2_by_2_GPU(self):
+    self._benchmark_bitwise_and(self._m_2_by_2_int32, GPU)
+
   def benchmark_tf_matmul_100_by_100_CPU(self):
     self._benchmark_matmul(self._m_100_by_100, CPU)
 
+  def benchmark_tf_bitwise_and_100_by_100_CPU(self):
+    self._benchmark_bitwise_and(self._m_100_by_100_int32, CPU)
+
   def benchmark_tf_matmul_100_by_100_GPU(self):
     self._benchmark_matmul(self._m_100_by_100, GPU)
 
+  def benchmark_tf_bitwise_and_100_by_100_GPU(self):
+    self._benchmark_bitwise_and(self._m_100_by_100_int32, GPU)
+
   def benchmark_tf_matmul_1000_by_1000_CPU(self):
     self._benchmark_matmul(self._m_1000_by_1000, CPU)
 
+  def benchmark_tf_bitwise_and_1000_by_1000_CPU(self):
+    self._benchmark_bitwise_and(self._m_1000_by_1000_int32, CPU)
+
   def benchmark_tf_matmul_1000_by_1000_GPU(self):
     self._benchmark_matmul(self._m_1000_by_1000, GPU)
+
+  def benchmark_tf_bitwise_and_1000_by_1000_GPU(self):
+    self._benchmark_bitwise_and(self._m_1000_by_1000_int32, GPU)
 
 
 @test_util.with_eager_op_as_function
 class RunEagerOpAsFunctionTest(test.TestCase):