mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[tensorflow training input] If SparseTensors are used in batch* ops, ensure restoration.
This forces the ST restore op to be called if any tensors are accessed at the output of the batch, thus fixing a memory leak. Solution suggested by Derek Murray. Fixes #13999. PiperOrigin-RevId: 173904309
This commit is contained in:
committed by
TensorFlower Gardener
parent
7fd2616026
commit
85f8d92408
@@ -574,7 +574,23 @@ def _restore_sparse_tensors(stored_list, sparse_info_list):
|
||||
rank=(info.rank + 1).value)
|
||||
if info.sparse else s
|
||||
for (s, info) in zip(stored_list, sparse_info_list)]
|
||||
return tensors if received_sequence else tensors[0]
|
||||
has_st = any(isinstance(x, sparse_tensor.SparseTensor) for x in tensors)
|
||||
if has_st:
|
||||
t_values = [
|
||||
x.values if isinstance(x, sparse_tensor.SparseTensor)
|
||||
else x
|
||||
for x in tensors]
|
||||
with_deps = lambda x: control_flow_ops.with_dependencies(t_values, x)
|
||||
ensure_restore_tensors = [
|
||||
sparse_tensor.SparseTensor(indices=with_deps(x.indices),
|
||||
values=with_deps(x.values),
|
||||
dense_shape=with_deps(x.dense_shape))
|
||||
if isinstance(x, sparse_tensor.SparseTensor)
|
||||
else with_deps(x)
|
||||
for x in tensors]
|
||||
else:
|
||||
ensure_restore_tensors = tensors
|
||||
return ensure_restore_tensors if received_sequence else tensors[0]
|
||||
|
||||
|
||||
def _validate(tensor_list):
|
||||
|
||||
Reference in New Issue
Block a user