diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index a4db4abd7b0..217fb3b7811 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -569,9 +569,11 @@ class ResourceScatterUpdateOp : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
// TODO(apassos) add the other types here.
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
- REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
- scatter_op::UpdateOp::ADD);
+#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
+ scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
+ scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index c4802a1cc1e..cdfbec85cf1 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -311,7 +311,7 @@ the same location, their contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
-

+
resource: Should be from a `Variable` node.
@@ -319,4 +319,44 @@ indices: A tensor of indices into the first dimension of `ref`.
updates: A tensor of updated values to add to `ref`.
)doc");
+REGISTER_OP("ResourceScatterUpdate")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeAndType handle_shape_and_type;
+ TF_RETURN_IF_ERROR(
+ ValidateVariableResourceHandle(c, &handle_shape_and_type));
+ ShapeHandle var_shape = handle_shape_and_type.shape;
+ ShapeHandle indices_shape = c->input(1);
+
+ ShapeHandle unused_updates_shape;
+ ShapeHandle concat;
+ ShapeHandle var_subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
+ TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
+ TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Assigns sparse updates to the variable referenced by `resource`.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+
+resource: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to add to `ref`.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 24ba1329f38..7922e3838f4 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -483,6 +483,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.destroy_resource_op(var._handle,
ignore_lookup_error=False)
+ def testScatterUpdate(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
+ state_ops.scatter_update(v, [1], [3.0])
+ self.assertAllEqual([1.0, 3.0], v.numpy())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 5b9ca7c0b96..dbab07da426 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -297,3 +297,55 @@ def count_up_to(ref, limit, name=None):
return gen_state_ops.count_up_to(ref, limit=limit, name=name)
return gen_state_ops.resource_count_up_to(
ref.handle, limit, T=ref.dtype, name=name)
+
+
+def scatter_update(ref, indices, updates, use_locking=True, name=None):
+ # pylint: disable=line-too-long
+ r"""Applies sparse updates to a variable reference.
+
+ This operation computes
+
+ ```python
+ # Scalar indices
+ ref[indices, ...] = updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+ ```
+
+ This operation outputs `ref` after the update is done.
+ This makes it easier to chain operations that need to use the reset value.
+
+ If values in `ref` is to be updated more than once, because there are
+ duplicate entries in `indices`, the order at which the updates happen
+ for each value is undefined.
+
+ Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+
+

+
+
+ Args:
+ ref: A `Variable`.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into the first dimension of `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to store in `ref`.
+ use_locking: An optional `bool`. Defaults to `True`.
+ If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_update(ref, indices, updates,
+ use_locking=use_locking, name=name)
+ return gen_resource_variable_ops.resource_scatter_update(
+ ref.handle, indices, updates, name=name)