From 35e982d715860f6c4bc22e9b264cfe63cd1c4fbb Mon Sep 17 00:00:00 2001
From: Xinyi Wang
Date: Sat, 29 Oct 2022 21:37:51 -0700
Subject: [PATCH] Rename file to reflect what it does.

PiperOrigin-RevId: 484815575
---
 .../python/distribute/failure_handling/BUILD  |  8 ++---
 .../failure_handling/failure_handler_test.py  |  4 +--
 .../failure_handling/failure_handling.py      | 31 +++++++++----------
 .../{gce_util.py => failure_handling_util.py} |  0
 .../gce_failure_handler_test.py               | 14 ++++-----
 tensorflow/tools/pip_package/BUILD            |  2 +-
 6 files changed, 29 insertions(+), 30 deletions(-)
 rename tensorflow/python/distribute/failure_handling/{gce_util.py => failure_handling_util.py} (100%)

diff --git a/tensorflow/python/distribute/failure_handling/BUILD b/tensorflow/python/distribute/failure_handling/BUILD
index b91d134d3aa..cf31f307908 100644
--- a/tensorflow/python/distribute/failure_handling/BUILD
+++ b/tensorflow/python/distribute/failure_handling/BUILD
@@ -14,7 +14,7 @@ py_library(
     ],
     srcs_version = "PY3",
     deps = [
-        ":gce_util",
+        ":failure_handling_util",
         "//tensorflow/python:lib",
         "//tensorflow/python:variables",
         "//tensorflow/python/checkpoint",
@@ -34,9 +34,9 @@ py_library(
 )
 
 py_library(
-    name = "gce_util",
+    name = "failure_handling_util",
     srcs = [
-        "gce_util.py",
+        "failure_handling_util.py",
     ],
     srcs_version = "PY3",
     deps = [
@@ -82,7 +82,7 @@ tf_py_test(
     ],
     deps = [
         ":failure_handling_lib",
-        ":gce_util",
+        ":failure_handling_util",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:multi_process_runner",
         "//tensorflow/python/distribute:multi_worker_test_base",
diff --git a/tensorflow/python/distribute/failure_handling/failure_handler_test.py b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
index ba33d656d37..17bcd9d4063 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute.failure_handling import failure_handling
-from tensorflow.python.distribute.failure_handling import gce_util
+from tensorflow.python.distribute.failure_handling import failure_handling_util
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
@@ -137,7 +137,7 @@ class PreemptionCheckpointTest(test.TestCase, parameterized.TestCase):
       def __call__(self):
         return self.v.read_value()
 
-    with mock.patch.object(gce_util, 'on_gcp', lambda: False):
+    with mock.patch.object(failure_handling_util, 'on_gcp', lambda: False):
 
       with strategy.scope():
         model = Model()
diff --git a/tensorflow/python/distribute/failure_handling/failure_handling.py b/tensorflow/python/distribute/failure_handling/failure_handling.py
index 33c4037c323..2a76fdbcb72 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handling.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handling.py
@@ -31,7 +31,7 @@ from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_
 from tensorflow.python.checkpoint import checkpoint as checkpoint_lib
 from tensorflow.python.checkpoint import checkpoint_management
 from tensorflow.python.distribute import multi_worker_util
-from tensorflow.python.distribute.failure_handling import gce_util
+from tensorflow.python.distribute.failure_handling import failure_handling_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import constant_op
@@ -172,11 +172,11 @@ class GcpGpuTerminationConfig(TerminationConfig):
       termination_watcher_fn=None,
       exit_fn=None,
       grace_period=None):
-    self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
-    self.exit_fn = exit_fn or gce_util.gce_exit_fn
+    self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
+    self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
     self.grace_period = (
-        grace_period
-        if grace_period or grace_period == 0 else gce_util.GRACE_PERIOD_GCE)
+        grace_period if grace_period or grace_period == 0 else
+        failure_handling_util.GRACE_PERIOD_GCE)
 
 
 class GcpCpuTerminationConfig(TerminationConfig):
@@ -187,8 +187,8 @@ class GcpCpuTerminationConfig(TerminationConfig):
       termination_watcher_fn=None,
       exit_fn=None,
       grace_period=None):
-    self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
-    self.exit_fn = exit_fn or gce_util.gce_exit_fn
+    self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
+    self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
     self.grace_period = grace_period or 0
 
 
@@ -211,12 +211,12 @@ def _complete_config_for_environment(platform_device, termination_config):
   if not termination_config:
     termination_config = TerminationConfig()
 
-  if platform_device is gce_util.PlatformDevice.GCE_GPU:
+  if platform_device is failure_handling_util.PlatformDevice.GCE_GPU:
     return GcpGpuTerminationConfig(termination_config.termination_watcher_fn,
                                    termination_config.exit_fn,
                                    termination_config.grace_period)
 
-  elif platform_device is gce_util.PlatformDevice.GCE_CPU:
+  elif platform_device is failure_handling_util.PlatformDevice.GCE_CPU:
     return GcpCpuTerminationConfig(termination_config.termination_watcher_fn,
                                    termination_config.exit_fn,
                                    termination_config.grace_period)
@@ -433,17 +433,16 @@ class PreemptionCheckpointHandler(object):
     self._checkpoint_or_checkpoint_manager = checkpoint_or_checkpoint_manager
     self._checkpoint_dir = checkpoint_dir
 
-    self._platform_device = gce_util.detect_platform()
-    if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
-                                 gce_util.PlatformDevice.GCE_CPU):
+    self._platform_device = failure_handling_util.detect_platform()
+    if self._platform_device in (failure_handling_util.PlatformDevice.GCE_TPU,
+                                 failure_handling_util.PlatformDevice.GCE_CPU):
       # While running MultiWorkerMirroredStrategy training with GPUs and CPUs
       # are the same on Borg, GCE CPU VM and GPU VM are different in terms
       # of live migration, grace period, etc. We can make it work upon request.
       raise NotImplementedError('PreemptionCheckpointHandler does not support '
                                 'usage with TPU or CPU device on GCP.')
-    # TODO(wxinyi): update name of gce_util.
-    elif self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
+    elif self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
       if ENABLE_TESTING_FOR_TPU:
         self._initialize_for_tpu_strategy()
 
 
@@ -784,7 +783,7 @@ class PreemptionCheckpointHandler(object):
     # the dominant use case for TPU user. Besides, passing in a multi-step
     # `distributed_train_function` will require the user to track their own
     # training steps.
-    if self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
+    if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
       return self._run_for_tpu(distributed_train_function, *args, **kwargs)
     else:
       return self._run_for_multi_worker_mirrored(distributed_train_function,
                                                  *args, **kwargs)
@@ -845,7 +844,7 @@ class PreemptionCheckpointHandler(object):
         'PreemptionCheckpointHandler Saving Checkpoint').increase_by(1)
     logging.info('PreemptionCheckpointHandler: Starting saving a checkpoint.')
 
-    if self._platform_device != gce_util.PlatformDevice.INTERNAL_TPU:
+    if self._platform_device != failure_handling_util.PlatformDevice.INTERNAL_TPU:
       self._checkpointed_runs.assign(self.total_run_calls)
 
     start_time = time.monotonic()
diff --git a/tensorflow/python/distribute/failure_handling/gce_util.py b/tensorflow/python/distribute/failure_handling/failure_handling_util.py
similarity index 100%
rename from tensorflow/python/distribute/failure_handling/gce_util.py
rename to tensorflow/python/distribute/failure_handling/failure_handling_util.py
diff --git a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
index 8b8467ec52a..30936c3ba5d 100644
--- a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute.failure_handling import failure_handling
-from tensorflow.python.distribute.failure_handling import gce_util
+from tensorflow.python.distribute.failure_handling import failure_handling_util
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.module import module
@@ -129,10 +129,10 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
       return False
 
     with mock.patch.object(
-        gce_util, 'termination_watcher_function_gce',
+        failure_handling_util, 'termination_watcher_function_gce',
         mock_termination_watcher_function_gce), mock.patch.object(
-        gce_util, 'detect_platform',
-        lambda: gce_util.PlatformDevice.GCE_GPU):
+        failure_handling_util, 'detect_platform',
+        lambda: failure_handling_util.PlatformDevice.GCE_GPU):
 
       class Model(module.Module):
 
@@ -235,9 +235,9 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
     except urllib.error.URLError as e:
      if 'Temporary failure in name resolution' in e.message:
         # This is caused by a weird flakiness that mock.patch does not
-        # correctly patch gce_util.request_compute_metadata, a real request
-        # is attempted, and an error is hit in
-        # gce_util.request_compute_metadata
+        # correctly patch failure_handling_util.request_compute_metadata, a
+        # real request is attempted, and an error is hit in
+        # failure_handling_util.request_compute_metadata
         logging.warning('Hit a mock issue.')
         return
 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 79719f4f049..b0f0bc95e83 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -133,7 +133,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/distribute:combinations",
     "//tensorflow/python/distribute/failure_handling:check_preemption_py",
     "//tensorflow/python/distribute/failure_handling:failure_handling_lib",
-    "//tensorflow/python/distribute/failure_handling:gce_util",
"//tensorflow/python/distribute/failure_handling:gce_util", + "//tensorflow/python/distribute/failure_handling:failure_handling_util", "//tensorflow/python/distribute:multi_process_runner", "//tensorflow/python/eager:eager_pip", "//tensorflow/python/keras:combinations",