diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index c524da43093..c328b037078 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -92,8 +92,7 @@ def _count_condition(values, or tuple. """ check_ops.assert_type(values, dtypes.bool) - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') values = math_ops.to_float(values) if weights is not None: @@ -916,8 +915,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tp' in includes: true_positives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_positives') + [num_thresholds], dtypes.float32, name='true_positives') is_true_positive = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: @@ -929,8 +927,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fn' in includes: false_negatives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_negatives') + [num_thresholds], dtypes.float32, name='false_negatives') is_false_negative = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: @@ -942,8 +939,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tn' in includes: true_negatives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_negatives') + [num_thresholds], dtypes.float32, name='true_negatives') is_true_negative = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: @@ -955,8 +951,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fp' in includes: false_positives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_positives') + [num_thresholds], dtypes.float32, name='false_positives') is_false_positive = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: @@ -1317,9 +1312,9 @@ def streaming_precision_recall_at_equal_thresholds(predictions, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtype), name='tp_buckets') + [num_thresholds], dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtype), name='fp_buckets') + [num_thresholds], dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -2582,15 +2577,13 @@ def streaming_covariance(predictions, predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') mean_prediction = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='mean_prediction') + [], dtypes.float32, name='mean_prediction') mean_label = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='mean_label') + [], dtypes.float32, name='mean_label') comoment = metrics_impl.metric_variable( # C_A in update equation - array_ops.zeros([], dtype=dtypes.float32), - name='comoment') + [], dtypes.float32, name='comoment') if weights is None: batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn @@ -3011,11 +3004,8 @@ def streaming_concat(values, init_size = 0 if max_size is None else max_size init_shape = [init_size] + fixed_shape array = metrics_impl.metric_variable( - array_ops.zeros(init_shape, dtype=values.dtype), - validate_shape=False, - name='array') - size = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.int32), name='size') + init_shape, values.dtype, validate_shape=False, name='array') + size = metrics_impl.metric_variable([], dtypes.int32, name='size') perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] valid_array = array[:size] @@ -3149,8 +3139,7 @@ def count(values, """ with variable_scope.variable_scope(name, 'count', (values, weights)): - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: num_values = math_ops.to_float(array_ops.size(values)) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index ce7fbe33315..b9965dba87f 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -35,18 +35,11 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops -def metric_variable(initial_value, validate_shape=True, name=None): - """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections. +def metric_variable(shape, dtype, validate_shape=True, name=None): + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" - Args: - initial_value: See variables.Variable.__init__. - validate_shape: See variables.Variable.__init__. - name: See variables.Variable.__init__. - Returns: - New variable. - """ return variable_scope.variable( - initial_value, + lambda: array_ops.zeros(shape, dtype), trainable=False, collections=[ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES @@ -244,8 +237,7 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None): """ # Local variable to accumulate the predictions in the confusion matrix. total_cm = metric_variable( - array_ops.zeros([num_classes, num_classes], dtype=dtypes.float64), - name='total_confusion_matrix') + [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix') # Cast the type to int64 required by confusion_matrix_ops. predictions = math_ops.to_int64(predictions) @@ -315,10 +307,8 @@ def mean(values, weights=None, metrics_collections=None, with variable_scope.variable_scope(name, 'mean', (values, weights)): values = math_ops.to_float(values) - total = metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='total') - count = metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + total = metric_variable([], dtypes.float32, name='total') + count = metric_variable([], dtypes.float32, name='count') if weights is None: num_values = math_ops.to_float(array_ops.size(values)) @@ -516,8 +506,7 @@ def _confusion_matrix_at_thresholds( if 'tp' in includes: true_p = metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_positives') + [num_thresholds], dtypes.float32, name='true_positives') is_true_positive = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: @@ -528,8 +517,7 @@ def _confusion_matrix_at_thresholds( if 'fn' in includes: false_n = metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_negatives') + [num_thresholds], dtypes.float32, name='false_negatives') is_false_negative = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: @@ -540,8 +528,7 @@ def _confusion_matrix_at_thresholds( if 'tn' in includes: true_n = metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_negatives') + [num_thresholds], dtypes.float32, name='true_negatives') is_true_negative = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: @@ -552,8 +539,7 @@ def _confusion_matrix_at_thresholds( if 'fp' in includes: false_p = metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_positives') + [num_thresholds], dtypes.float32, name='false_positives') is_false_positive = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: @@ -1183,11 +1169,9 @@ def mean_tensor(values, weights=None, metrics_collections=None, with variable_scope.variable_scope(name, 'mean', (values, weights)): values = math_ops.to_float(values) total = metric_variable( - array_ops.zeros(values.get_shape(), dtype=dtypes.float32), - name='total_tensor') + values.get_shape(), dtypes.float32, name='total_tensor') count = metric_variable( - array_ops.zeros(values.get_shape(), dtype=dtypes.float32), - name='count_tensor') + values.get_shape(), dtypes.float32, name='count_tensor') num_values = array_ops.ones_like(values) if weights is not None: @@ -1300,8 +1284,7 @@ def _count_condition(values, weights=None, metrics_collections=None, or tuple. """ check_ops.assert_type(values, dtypes.bool) - count = metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count = metric_variable([], dtypes.float32, name='count') values = math_ops.to_float(values) if weights is not None: @@ -2082,7 +2065,7 @@ def _streaming_sparse_true_positive_at_k(labels, weights=weights) batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp)) - var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope) + var = metric_variable([], dtypes.float64, name=scope) return var, state_ops.assign_add(var, batch_total_tp, name='update') @@ -2178,7 +2161,7 @@ def _streaming_sparse_false_negative_at_k(labels, weights=weights) batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) - var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope) + var = metric_variable([], dtypes.float64, name=scope) return var, state_ops.assign_add(var, batch_total_fn, name='update') @@ -2829,8 +2812,7 @@ def _streaming_sparse_average_precision_at_top_k(labels, # - For the unweighted case, this is just the number of rows. # - For the weighted case, it's the sum of the weights broadcast across # `average_precision` rows. - max_var = metric_variable( - array_ops.zeros([], dtype=dtypes.float64), name=max_scope) + max_var = metric_variable([], dtypes.float64, name=max_scope) if weights is None: batch_max = math_ops.to_double( array_ops.size(average_precision, name='batch_max')) @@ -2838,8 +2820,7 @@ def _streaming_sparse_average_precision_at_top_k(labels, batch_max = math_ops.reduce_sum(weights, name='batch_max') max_update = state_ops.assign_add(max_var, batch_max, name='update') with ops.name_scope(None, 'total', (average_precision,)) as total_scope: - total_var = metric_variable( - array_ops.zeros([], dtype=dtypes.float64), name=total_scope) + total_var = metric_variable([], dtypes.float64, name=total_scope) batch_total = math_ops.reduce_sum(average_precision, name='batch_total') total_update = state_ops.assign_add(total_var, batch_total, name='update') @@ -3025,7 +3006,7 @@ def _streaming_sparse_false_positive_at_k(labels, weights=weights) batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp)) - var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope) + var = metric_variable([], dtypes.float64, name=scope) return var, state_ops.assign_add(var, batch_total_fp, name='update')