diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index a4db4abd7b0..217fb3b7811 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -569,9 +569,11 @@ class ResourceScatterUpdateOp : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); // TODO(apassos) add the other types here. -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ - REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ - scatter_op::UpdateOp::ADD); +#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ + scatter_op::UpdateOp::ADD); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \ + scatter_op::UpdateOp::ASSIGN); // Registers CPU kernels. #define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index c4802a1cc1e..cdfbec85cf1 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -311,7 +311,7 @@ the same location, their contributions add. Requires `updates.shape = indices.shape + ref.shape[1:]`.
- +
resource: Should be from a `Variable` node. @@ -319,4 +319,44 @@ indices: A tensor of indices into the first dimension of `ref`. updates: A tensor of updated values to add to `ref`. )doc"); +REGISTER_OP("ResourceScatterUpdate") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn([](InferenceContext* c) { + ShapeAndType handle_shape_and_type; + TF_RETURN_IF_ERROR( + ValidateVariableResourceHandle(c, &handle_shape_and_type)); + ShapeHandle var_shape = handle_shape_and_type.shape; + ShapeHandle indices_shape = c->input(1); + + ShapeHandle unused_updates_shape; + ShapeHandle concat; + ShapeHandle var_subshape; + TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); + TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); + TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); + return Status::OK(); + }) + .Doc(R"doc( +Assigns sparse updates to the variable referenced by `resource`. + +This operation computes + + # Scalar indices + ref[indices, ...] = updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] = updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + +resource: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to add to `ref`. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 24ba1329f38..7922e3838f4 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -483,6 +483,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.destroy_resource_op(var._handle, ignore_lookup_error=False) + def testScatterUpdate(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update") + state_ops.scatter_update(v, [1], [3.0]) + self.assertAllEqual([1.0, 3.0], v.numpy()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 5b9ca7c0b96..dbab07da426 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -297,3 +297,55 @@ def count_up_to(ref, limit, name=None): return gen_state_ops.count_up_to(ref, limit=limit, name=name) return gen_state_ops.resource_count_up_to( ref.handle, limit, T=ref.dtype, name=name) + + +def scatter_update(ref, indices, updates, use_locking=True, name=None): + # pylint: disable=line-too-long + r"""Applies sparse updates to a variable reference. + + This operation computes + + ```python + # Scalar indices + ref[indices, ...] = updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] = updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + ``` + + This operation outputs `ref` after the update is done. + This makes it easier to chain operations that need to use the reset value. + + If values in `ref` is to be updated more than once, because there are + duplicate entries in `indices`, the order at which the updates happen + for each value is undefined. + + Requires `updates.shape = indices.shape + ref.shape[1:]`. + +
+ +
+ + Args: + ref: A `Variable`. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A tensor of indices into the first dimension of `ref`. + updates: A `Tensor`. Must have the same type as `ref`. + A tensor of updated values to store in `ref`. + use_locking: An optional `bool`. Defaults to `True`. + If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_update(ref, indices, updates, + use_locking=use_locking, name=name) + return gen_resource_variable_ops.resource_scatter_update( + ref.handle, indices, updates, name=name)