diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index a3e3756007b..8e90fd4ec61 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -259,6 +259,10 @@ class DNNClassifier(estimator.Estimator): whose `value` is a `Tensor`. Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, @@ -392,6 +396,10 @@ class DNNRegressor(estimator.Estimator): whose `value` is a `Tensor`. Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index ff4ecee5c02..3c61bd5b07b 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -278,6 +278,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): whose `value` is a `Tensor`. Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, @@ -438,6 +442,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator): whose `value` is a `Tensor`. Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 3338f8ee2c6..8658ee38e99 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -184,6 +184,10 @@ class LinearClassifier(estimator.Estimator): whose `value` is a `Tensor`. Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, @@ -300,6 +304,10 @@ class LinearRegressor(estimator.Estimator): key=column.name, value=a `Tensor` Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index f198b051cfb..6243cfc118b 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -29,6 +29,7 @@ import six from tensorflow.core.framework import summary_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util @@ -87,6 +88,10 @@ class Estimator(object): None of `Estimator`'s methods can be overridden in subclasses (its constructor enforces this). Subclasses should use `model_fn` to configure the base class, and may add methods implementing specialized functionality. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility """ def __init__(self, model_fn, model_dir=None, config=None, params=None): @@ -129,10 +134,15 @@ class Estimator(object): Keys are names of parameters, values are basic python types. Raises: + RuntimeError: If eager execution is enabled. ValueError: parameters of `model_fn` don't match `params`. ValueError: if this is called via a subclass and if that class overrides a member of `Estimator`. """ + if context.in_eager_mode(): + raise RuntimeError( + 'Estimators are not supported when eager execution is enabled.') + Estimator._assert_members_are_not_overridden(self) if config is None: @@ -1016,4 +1026,3 @@ def _has_dataset_or_queue_runner(maybe_tensor): # Now, check queue. return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) -