mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
scatter_update for resource variables
PiperOrigin-RevId: 173963715
This commit is contained in:
committed by
TensorFlower Gardener
parent
8f7903b4c3
commit
89120eb688
@@ -569,9 +569,11 @@ class ResourceScatterUpdateOp : public OpKernel {
|
||||
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
|
||||
|
||||
// TODO(apassos) add the other types here.
|
||||
#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
|
||||
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
|
||||
scatter_op::UpdateOp::ADD);
|
||||
#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
|
||||
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
|
||||
scatter_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
|
||||
scatter_op::UpdateOp::ASSIGN);
|
||||
|
||||
// Registers CPU kernels.
|
||||
#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
|
||||
|
||||
@@ -311,7 +311,7 @@ the same location, their contributions add.
|
||||
Requires `updates.shape = indices.shape + ref.shape[1:]`.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
|
||||
<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
|
||||
</div>
|
||||
|
||||
resource: Should be from a `Variable` node.
|
||||
@@ -319,4 +319,44 @@ indices: A tensor of indices into the first dimension of `ref`.
|
||||
updates: A tensor of updated values to add to `ref`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ResourceScatterUpdate")
|
||||
.Input("resource: resource")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: dtype")
|
||||
.Attr("dtype: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeAndType handle_shape_and_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateVariableResourceHandle(c, &handle_shape_and_type));
|
||||
ShapeHandle var_shape = handle_shape_and_type.shape;
|
||||
ShapeHandle indices_shape = c->input(1);
|
||||
|
||||
ShapeHandle unused_updates_shape;
|
||||
ShapeHandle concat;
|
||||
ShapeHandle var_subshape;
|
||||
TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Assigns sparse updates to the variable referenced by `resource`.
|
||||
|
||||
This operation computes
|
||||
|
||||
# Scalar indices
|
||||
ref[indices, ...] = updates[...]
|
||||
|
||||
# Vector indices (for each i)
|
||||
ref[indices[i], ...] = updates[i, ...]
|
||||
|
||||
# High rank indices (for each i, ..., j)
|
||||
ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
|
||||
|
||||
resource: Should be from a `Variable` node.
|
||||
indices: A tensor of indices into the first dimension of `ref`.
|
||||
updates: A tensor of updated values to add to `ref`.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@@ -483,6 +483,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
resource_variable_ops.destroy_resource_op(var._handle,
|
||||
ignore_lookup_error=False)
|
||||
|
||||
def testScatterUpdate(self):
|
||||
with context.eager_mode():
|
||||
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
|
||||
state_ops.scatter_update(v, [1], [3.0])
|
||||
self.assertAllEqual([1.0, 3.0], v.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
@@ -297,3 +297,55 @@ def count_up_to(ref, limit, name=None):
|
||||
return gen_state_ops.count_up_to(ref, limit=limit, name=name)
|
||||
return gen_state_ops.resource_count_up_to(
|
||||
ref.handle, limit, T=ref.dtype, name=name)
|
||||
|
||||
|
||||
def scatter_update(ref, indices, updates, use_locking=True, name=None):
|
||||
# pylint: disable=line-too-long
|
||||
r"""Applies sparse updates to a variable reference.
|
||||
|
||||
This operation computes
|
||||
|
||||
```python
|
||||
# Scalar indices
|
||||
ref[indices, ...] = updates[...]
|
||||
|
||||
# Vector indices (for each i)
|
||||
ref[indices[i], ...] = updates[i, ...]
|
||||
|
||||
# High rank indices (for each i, ..., j)
|
||||
ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
|
||||
```
|
||||
|
||||
This operation outputs `ref` after the update is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
|
||||
If values in `ref` is to be updated more than once, because there are
|
||||
duplicate entries in `indices`, the order at which the updates happen
|
||||
for each value is undefined.
|
||||
|
||||
Requires `updates.shape = indices.shape + ref.shape[1:]`.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
|
||||
</div>
|
||||
|
||||
Args:
|
||||
ref: A `Variable`.
|
||||
indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
|
||||
A tensor of indices into the first dimension of `ref`.
|
||||
updates: A `Tensor`. Must have the same type as `ref`.
|
||||
A tensor of updated values to store in `ref`.
|
||||
use_locking: An optional `bool`. Defaults to `True`.
|
||||
If True, the assignment will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
Same as `ref`. Returned as a convenience for operations that want
|
||||
to use the updated values after the update is done.
|
||||
"""
|
||||
if ref.dtype._is_ref_dtype:
|
||||
return gen_state_ops.scatter_update(ref, indices, updates,
|
||||
use_locking=use_locking, name=name)
|
||||
return gen_resource_variable_ops.resource_scatter_update(
|
||||
ref.handle, indices, updates, name=name)
|
||||
|
||||
Reference in New Issue
Block a user