scatter_update for resource variables

PiperOrigin-RevId: 173963715
This commit is contained in:
Alexandre Passos
2017-10-30 15:58:31 -07:00
committed by TensorFlower Gardener
parent 8f7903b4c3
commit 89120eb688
4 changed files with 104 additions and 4 deletions

View File

@@ -569,9 +569,11 @@ class ResourceScatterUpdateOp : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
// TODO(apassos) add the other types here.
#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD);
#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \

View File

@@ -311,7 +311,7 @@ the same location, their contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
</div>
resource: Should be from a `Variable` node.
@@ -319,4 +319,44 @@ indices: A tensor of indices into the first dimension of `ref`.
updates: A tensor of updated values to add to `ref`.
)doc");
REGISTER_OP("ResourceScatterUpdate")
.Input("resource: resource")
.Input("indices: Tindices")
.Input("updates: dtype")
.Attr("dtype: numbertype")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
ValidateVariableResourceHandle(c, &handle_shape_and_type));
ShapeHandle var_shape = handle_shape_and_type.shape;
ShapeHandle indices_shape = c->input(1);
ShapeHandle unused_updates_shape;
ShapeHandle concat;
ShapeHandle var_subshape;
TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
return Status::OK();
})
.Doc(R"doc(
Assigns sparse updates to the variable referenced by `resource`.
This operation computes
# Scalar indices
ref[indices, ...] = updates[...]
# Vector indices (for each i)
ref[indices[i], ...] = updates[i, ...]
# High rank indices (for each i, ..., j)
ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
resource: Should be from a `Variable` node.
indices: A tensor of indices into the first dimension of `ref`.
updates: A tensor of updated values to add to `ref`.
)doc");
} // namespace tensorflow

View File

@@ -483,6 +483,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.destroy_resource_op(var._handle,
ignore_lookup_error=False)
def testScatterUpdate(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
state_ops.scatter_update(v, [1], [3.0])
self.assertAllEqual([1.0, 3.0], v.numpy())
if __name__ == "__main__":
test.main()

View File

@@ -297,3 +297,55 @@ def count_up_to(ref, limit, name=None):
return gen_state_ops.count_up_to(ref, limit=limit, name=name)
return gen_state_ops.resource_count_up_to(
ref.handle, limit, T=ref.dtype, name=name)
def scatter_update(ref, indices, updates, use_locking=True, name=None):
# pylint: disable=line-too-long
r"""Applies sparse updates to a variable reference.
This operation computes
```python
# Scalar indices
ref[indices, ...] = updates[...]
# Vector indices (for each i)
ref[indices[i], ...] = updates[i, ...]
# High rank indices (for each i, ..., j)
ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
```
This operation outputs `ref` after the update is done.
This makes it easier to chain operations that need to use the reset value.
If values in `ref` is to be updated more than once, because there are
duplicate entries in `indices`, the order at which the updates happen
for each value is undefined.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
</div>
Args:
ref: A `Variable`.
indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
A tensor of indices into the first dimension of `ref`.
updates: A `Tensor`. Must have the same type as `ref`.
A tensor of updated values to store in `ref`.
use_locking: An optional `bool`. Defaults to `True`.
If True, the assignment will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
name: A name for the operation (optional).
Returns:
Same as `ref`. Returned as a convenience for operations that want
to use the updated values after the update is done.
"""
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_update(ref, indices, updates,
use_locking=use_locking, name=name)
return gen_resource_variable_ops.resource_scatter_update(
ref.handle, indices, updates, name=name)