[tensorflow training input] If SparseTensors are used in batch* ops, ensure restoration.

This forces the ST restore op to be called if any tensors are accessed at the output
of the batch, thus fixing a memory leak.

Solution suggested by Derek Murray.

Fixes #13999.

PiperOrigin-RevId: 173904309
This commit is contained in:
Eugene Brevdo
2017-10-30 09:31:50 -07:00
committed by TensorFlower Gardener
parent 7fd2616026
commit 85f8d92408

View File

@@ -574,7 +574,23 @@ def _restore_sparse_tensors(stored_list, sparse_info_list):
rank=(info.rank + 1).value)
if info.sparse else s
for (s, info) in zip(stored_list, sparse_info_list)]
return tensors if received_sequence else tensors[0]
has_st = any(isinstance(x, sparse_tensor.SparseTensor) for x in tensors)
if has_st:
t_values = [
x.values if isinstance(x, sparse_tensor.SparseTensor)
else x
for x in tensors]
with_deps = lambda x: control_flow_ops.with_dependencies(t_values, x)
ensure_restore_tensors = [
sparse_tensor.SparseTensor(indices=with_deps(x.indices),
values=with_deps(x.values),
dense_shape=with_deps(x.dense_shape))
if isinstance(x, sparse_tensor.SparseTensor)
else with_deps(x)
for x in tensors]
else:
ensure_restore_tensors = tensors
return ensure_restore_tensors if received_sequence else tensors[0]
def _validate(tensor_list):