mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Rename file to reflect what it does.
PiperOrigin-RevId: 484815575
This commit is contained in:
committed by
TensorFlower Gardener
parent
244da5285a
commit
35e982d715
@@ -14,7 +14,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":gce_util",
|
||||
":failure_handling_util",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/checkpoint",
|
||||
@@ -34,9 +34,9 @@ py_library(
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gce_util",
|
||||
name = "failure_handling_util",
|
||||
srcs = [
|
||||
"gce_util.py",
|
||||
"failure_handling_util.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
@@ -82,7 +82,7 @@ tf_py_test(
|
||||
],
|
||||
deps = [
|
||||
":failure_handling_lib",
|
||||
":gce_util",
|
||||
":failure_handling_util",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
|
||||
@@ -34,7 +34,7 @@ from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute.failure_handling import failure_handling
|
||||
from tensorflow.python.distribute.failure_handling import gce_util
|
||||
from tensorflow.python.distribute.failure_handling import failure_handling_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@@ -137,7 +137,7 @@ class PreemptionCheckpointTest(test.TestCase, parameterized.TestCase):
|
||||
def __call__(self):
|
||||
return self.v.read_value()
|
||||
|
||||
with mock.patch.object(gce_util, 'on_gcp', lambda: False):
|
||||
with mock.patch.object(failure_handling_util, 'on_gcp', lambda: False):
|
||||
|
||||
with strategy.scope():
|
||||
model = Model()
|
||||
|
||||
@@ -31,7 +31,7 @@ from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_
|
||||
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib
|
||||
from tensorflow.python.checkpoint import checkpoint_management
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute.failure_handling import gce_util
|
||||
from tensorflow.python.distribute.failure_handling import failure_handling_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import constant_op
|
||||
@@ -172,11 +172,11 @@ class GcpGpuTerminationConfig(TerminationConfig):
|
||||
termination_watcher_fn=None,
|
||||
exit_fn=None,
|
||||
grace_period=None):
|
||||
self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
|
||||
self.exit_fn = exit_fn or gce_util.gce_exit_fn
|
||||
self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
|
||||
self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
|
||||
self.grace_period = (
|
||||
grace_period
|
||||
if grace_period or grace_period == 0 else gce_util.GRACE_PERIOD_GCE)
|
||||
grace_period if grace_period or grace_period == 0 else
|
||||
failure_handling_util.GRACE_PERIOD_GCE)
|
||||
|
||||
|
||||
class GcpCpuTerminationConfig(TerminationConfig):
|
||||
@@ -187,8 +187,8 @@ class GcpCpuTerminationConfig(TerminationConfig):
|
||||
termination_watcher_fn=None,
|
||||
exit_fn=None,
|
||||
grace_period=None):
|
||||
self.termination_watcher_fn = termination_watcher_fn or gce_util.termination_watcher_function_gce
|
||||
self.exit_fn = exit_fn or gce_util.gce_exit_fn
|
||||
self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
|
||||
self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
|
||||
self.grace_period = grace_period or 0
|
||||
|
||||
|
||||
@@ -211,12 +211,12 @@ def _complete_config_for_environment(platform_device, termination_config):
|
||||
if not termination_config:
|
||||
termination_config = TerminationConfig()
|
||||
|
||||
if platform_device is gce_util.PlatformDevice.GCE_GPU:
|
||||
if platform_device is failure_handling_util.PlatformDevice.GCE_GPU:
|
||||
return GcpGpuTerminationConfig(termination_config.termination_watcher_fn,
|
||||
termination_config.exit_fn,
|
||||
termination_config.grace_period)
|
||||
|
||||
elif platform_device is gce_util.PlatformDevice.GCE_CPU:
|
||||
elif platform_device is failure_handling_util.PlatformDevice.GCE_CPU:
|
||||
return GcpCpuTerminationConfig(termination_config.termination_watcher_fn,
|
||||
termination_config.exit_fn,
|
||||
termination_config.grace_period)
|
||||
@@ -433,17 +433,16 @@ class PreemptionCheckpointHandler(object):
|
||||
self._checkpoint_or_checkpoint_manager = checkpoint_or_checkpoint_manager
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
|
||||
self._platform_device = gce_util.detect_platform()
|
||||
if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
|
||||
gce_util.PlatformDevice.GCE_CPU):
|
||||
self._platform_device = failure_handling_util.detect_platform()
|
||||
if self._platform_device in (failure_handling_util.PlatformDevice.GCE_TPU,
|
||||
failure_handling_util.PlatformDevice.GCE_CPU):
|
||||
# While running MultiWorkerMirroredStrategy training with GPUs and CPUs
|
||||
# are the same on Borg, GCE CPU VM and GPU VM are different in terms
|
||||
# of live migration, grace period, etc. We can make it work upon request.
|
||||
raise NotImplementedError('PreemptionCheckpointHandler does not support '
|
||||
'usage with TPU or CPU device on GCP.')
|
||||
|
||||
# TODO(wxinyi): update name of gce_util.
|
||||
elif self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
|
||||
elif self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
|
||||
if ENABLE_TESTING_FOR_TPU:
|
||||
self._initialize_for_tpu_strategy()
|
||||
|
||||
@@ -784,7 +783,7 @@ class PreemptionCheckpointHandler(object):
|
||||
# the dominant use case for TPU user. Besides, passing in a multi-step
|
||||
# `distributed_train_function` will require the user to track their own
|
||||
# training steps.
|
||||
if self._platform_device == gce_util.PlatformDevice.INTERNAL_TPU:
|
||||
if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
|
||||
return self._run_for_tpu(distributed_train_function, *args, **kwargs)
|
||||
else:
|
||||
return self._run_for_multi_worker_mirrored(distributed_train_function,
|
||||
@@ -845,7 +844,7 @@ class PreemptionCheckpointHandler(object):
|
||||
'PreemptionCheckpointHandler Saving Checkpoint').increase_by(1)
|
||||
logging.info('PreemptionCheckpointHandler: Starting saving a checkpoint.')
|
||||
|
||||
if self._platform_device != gce_util.PlatformDevice.INTERNAL_TPU:
|
||||
if self._platform_device != failure_handling_util.PlatformDevice.INTERNAL_TPU:
|
||||
self._checkpointed_runs.assign(self.total_run_calls)
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
@@ -33,7 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute.failure_handling import failure_handling
|
||||
from tensorflow.python.distribute.failure_handling import gce_util
|
||||
from tensorflow.python.distribute.failure_handling import failure_handling_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.module import module
|
||||
@@ -129,10 +129,10 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
|
||||
return False
|
||||
|
||||
with mock.patch.object(
|
||||
gce_util, 'termination_watcher_function_gce',
|
||||
failure_handling_util, 'termination_watcher_function_gce',
|
||||
mock_termination_watcher_function_gce), mock.patch.object(
|
||||
gce_util, 'detect_platform',
|
||||
lambda: gce_util.PlatformDevice.GCE_GPU):
|
||||
failure_handling_util, 'detect_platform',
|
||||
lambda: failure_handling_util.PlatformDevice.GCE_GPU):
|
||||
|
||||
class Model(module.Module):
|
||||
|
||||
@@ -235,9 +235,9 @@ class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
|
||||
except urllib.error.URLError as e:
|
||||
if 'Temporary failure in name resolution' in e.message:
|
||||
# This is caused by a weird flakiness that mock.patch does not
|
||||
# correctly patch gce_util.request_compute_metadata, a real request
|
||||
# is attempted, and an error is hit in
|
||||
# gce_util.request_compute_metadata
|
||||
# correctly patch failure_handling_util.request_compute_metadata, a
|
||||
# real request is attempted, and an error is hit in
|
||||
# failure_handling_util.request_compute_metadata
|
||||
logging.warning('Hit a mock issue.')
|
||||
return
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute/failure_handling:check_preemption_py",
|
||||
"//tensorflow/python/distribute/failure_handling:failure_handling_lib",
|
||||
"//tensorflow/python/distribute/failure_handling:gce_util",
|
||||
"//tensorflow/python/distribute/failure_handling:failure_handling_util",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/eager:eager_pip",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
|
||||
Reference in New Issue
Block a user