From fa84a1e83cf7087fda43c0bf648bedaa69cd672e Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Thu, 24 Jun 2021 12:16:18 -0700 Subject: [PATCH] Fix forward test failure in multi-GPU tests PiperOrigin-RevId: 381309867 Change-Id: I4a58cca3e84c6816d5a501a98ed3850940ca5f6d --- tensorflow/python/distribute/ps_values.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/distribute/ps_values.py b/tensorflow/python/distribute/ps_values.py index 465cd3a51f1..59f34424248 100644 --- a/tensorflow/python/distribute/ps_values.py +++ b/tensorflow/python/distribute/ps_values.py @@ -200,6 +200,8 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor): # TODO(josh11b): Test saving & restoring. def _gather_saveables_for_checkpoint(self): + if isinstance(self._v, CachingVariable): + return self._v._gather_saveables_for_checkpoint() # pylint:disable=protected-access return {trackable.VARIABLE_VALUE_KEY: self._v} def _map_resources(self, save_options):