Rename file to reflect what it does.

PiperOrigin-RevId: 484815575
This commit is contained in:
Xinyi Wang
2022-10-29 21:37:51 -07:00
committed by TensorFlower Gardener
parent 244da5285a
commit 35e982d715
6 changed files with 29 additions and 30 deletions

View File

@@ -14,7 +14,7 @@ py_library(
],
srcs_version = "PY3",
deps = [
":gce_util",
":failure_handling_util",
"//tensorflow/python:lib",
"//tensorflow/python:variables",
"//tensorflow/python/checkpoint",
@@ -34,9 +34,9 @@ py_library(
)
py_library(
name = "gce_util",
name = "failure_handling_util",
srcs = [
"gce_util.py",
"failure_handling_util.py",
],
srcs_version = "PY3",
deps = [
@@ -82,7 +82,7 @@ tf_py_test(
],
deps = [
":failure_handling_lib",
":gce_util",
":failure_handling_util",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:multi_worker_test_base",

View File

@@ -34,7 +34,7 @@ from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.failure_handling import failure_handling
from tensorflow.python.distribute.failure_handling import gce_util
from tensorflow.python.distribute.failure_handling import failure_handling_util
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
@@ -137,7 +137,7 @@ class PreemptionCheckpointTest(test.TestCase, parameterized.TestCase):
def __call__(self):
return self.v.read_value()
with mock.patch.object(gce_util, 'on_gcp', lambda: False):
with mock.patch.object(failure_handling_util, 'on_gcp', lambda: False):
with strategy.scope():
model = Model()

View File

@@ -31,7 +31,7 @@ from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute.failure_handling import gce_util
from tensorflow.python.distribute.failure_handling import failure_handling_util
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
@@ -172,11 +172,11 @@ class GcpGpuTerminationConfig(TerminationConfig):
termination_watcher_fn=None,
exit_fn=None,
grace_period=None):
self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
self.exit_fn = exit_fn or gce_util.gce_exit_fn
self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
self.grace_period = (
grace_period
if grace_period or grace_period == 0 else gce_util.GRACE_PERIOD_GCE)
grace_period if grace_period or grace_period == 0 else
failure_handling_util.GRACE_PERIOD_GCE)
class GcpCpuTerminationConfig(TerminationConfig):
@@ -187,8 +187,8 @@ class GcpCpuTerminationConfig(TerminationConfig):
termination_watcher_fn=None,
exit_fn=None,
grace_period=None):
self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
self.exit_fn = exit_fn or gce_util.gce_exit_fn
self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
self.grace_period = grace_period or 0
@@ -211,12 +211,12 @@ def _complete_config_for_environment(platform_device, termination_config):
if not termination_config:
termination_config = TerminationConfig()
if platform_device is gce_util.PlatformDevice.GCE_GPU:
if platform_device is failure_handling_util.PlatformDevice.GCE_GPU:
return GcpGpuTerminationConfig(termination_config.termination_watcher_fn,
termination_config.exit_fn,
termination_config.grace_period)
elif platform_device is gce_util.PlatformDevice.GCE_CPU:
elif platform_device is failure_handling_util.PlatformDevice.GCE_CPU:
return GcpCpuTerminationConfig(termination_config.termination_watcher_fn,
termination_config.exit_fn,
termination_config.grace_period)
@@ -433,17 +433,16 @@ class PreemptionCheckpointHandler(object):
self._checkpoint_or_checkpoint_manager = checkpoint_or_checkpoint_manager
self._checkpoint_dir = checkpoint_dir
self._platform_device = gce_util.detect_platform()
if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
gce_util.PlatformDevice.GCE_CPU):
self._platform_device = failure_handling_util.detect_platform()
if self._platform_device in (failure_handling_util.PlatformDevice.GCE_TPU,
failure_handling_util.PlatformDevice.GCE_CPU):
# While MultiWorkerMirroredStrategy training behaves the same with GPUs
# and CPUs on Borg, GCE CPU VMs and GPU VMs differ in terms of live
# migration, grace period, etc. We can make it work upon request.
raise NotImplementedError('PreemptionCheckpointHandler does not support '
'usage with TPU or CPU device on GCP.')
# TODO(wxinyi): update name of gce_util.
elif self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
elif self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
if ENABLE_TESTING_FOR_TPU:
self._initialize_for_tpu_strategy()
@@ -784,7 +783,7 @@ class PreemptionCheckpointHandler(object):
# the dominant use case for TPU user. Besides, passing in a multi-step
# `distributed_train_function` will require the user to track their own
# training steps.
if self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
return self._run_for_tpu(distributed_train_function, *args, **kwargs)
else:
return self._run_for_multi_worker_mirrored(distributed_train_function,
@@ -845,7 +844,7 @@ class PreemptionCheckpointHandler(object):
'PreemptionCheckpointHandler Saving Checkpoint').increase_by(1)
logging.info('PreemptionCheckpointHandler: Starting saving a checkpoint.')
if self._platform_device != gce_util.PlatformDevice.INTERNAL_TPU:
if self._platform_device != failure_handling_util.PlatformDevice.INTERNAL_TPU:
self._checkpointed_runs.assign(self.total_run_calls)
start_time = time.monotonic()

View File

@@ -33,7 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.failure_handling import failure_handling
from tensorflow.python.distribute.failure_handling import gce_util
from tensorflow.python.distribute.failure_handling import failure_handling_util
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.module import module
@@ -129,10 +129,10 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
return False
with mock.patch.object(
gce_util, 'termination_watcher_function_gce',
failure_handling_util, 'termination_watcher_function_gce',
mock_termination_watcher_function_gce), mock.patch.object(
gce_util, 'detect_platform',
lambda: gce_util.PlatformDevice.GCE_GPU):
failure_handling_util, 'detect_platform',
lambda: failure_handling_util.PlatformDevice.GCE_GPU):
class Model(module.Module):
@@ -235,9 +235,9 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
except urllib.error.URLError as e:
if 'Temporary failure in name resolution' in e.message:
# This is caused by a weird flakiness that mock.patch does not
# correctly patch gce_util.request_compute_metadata, a real request
# is attempted, and an error is hit in
# gce_util.request_compute_metadata
# correctly patch failure_handling_util.request_compute_metadata, a
# real request is attempted, and an error is hit in
# failure_handling_util.request_compute_metadata
logging.warning('Hit a mock issue.')
return

View File

@@ -133,7 +133,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute/failure_handling:check_preemption_py",
"//tensorflow/python/distribute/failure_handling:failure_handling_lib",
"//tensorflow/python/distribute/failure_handling:gce_util",
"//tensorflow/python/distribute/failure_handling:failure_handling_util",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/eager:eager_pip",
"//tensorflow/python/keras:combinations",