From b3f33ad46613834922f381f816715fdcfee54b87 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Tue, 20 Jun 2017 20:01:39 -0700
Subject: [PATCH] Make changes to prepare for the fused option of batch norm
 to be set to None (None means using fused batch norm if possible).

PiperOrigin-RevId: 159649743
---
 tensorflow/contrib/keras/BUILD                 |  6 +-
 .../python/keras/layers/normalization_test.py  | 29 ++++----
 .../contrib/layers/python/layers/layers.py     | 70 ++++++++++++-------
 .../layers/python/layers/layers_test.py        | 41 +++++------
 .../contrib/slim/python/slim/learning_test.py  |  8 ++-
 tensorflow/python/layers/normalization.py      | 11 ++-
 .../python/layers/normalization_test.py        | 69 +++++++++---------
 7 files changed, 135 insertions(+), 99 deletions(-)

diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 71ce6540d62..619ebb7ce07 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -7,6 +7,7 @@ exports_files(["LICENSE"])
 
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 py_library(
@@ -393,12 +394,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "normalization_test",
     size = "small",
     srcs = ["python/keras/layers/normalization_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
        ":keras",
        ":testing_utils",
        "//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
index dc410f84d85..1a0686800eb 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
@@ -94,22 +94,23 @@ class NoiseLayersTest(test.TestCase):
     np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
 
   def test_batchnorm_convnet(self):
-    with self.test_session():
-      model = keras.models.Sequential()
-      norm = keras.layers.BatchNormalization(
-          axis=1, input_shape=(3, 4, 4), momentum=0.8)
-      model.add(norm)
-      model.compile(loss='mse', optimizer='sgd')
+    if test.is_gpu_available(cuda_only=True):
+      with self.test_session(use_gpu=True):
+        model = keras.models.Sequential()
+        norm = keras.layers.BatchNormalization(
+            axis=1, input_shape=(3, 4, 4), momentum=0.8)
+        model.add(norm)
+        model.compile(loss='mse', optimizer='sgd')
 
-      # centered on 5.0, variance 10.0
-      x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
-      model.fit(x, x, epochs=4, verbose=0)
-      out = model.predict(x)
-      out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
-      out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
+        # centered on 5.0, variance 10.0
+        x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
+        model.fit(x, x, epochs=4, verbose=0)
+        out = model.predict(x)
+        out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
+        out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))
 
-      np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
-      np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
+        np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
+        np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
 
   def test_shared_batchnorm(self):
     """Test that a BN layer can be shared across different data streams.
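
The `fused=None` default that this patch prepares for can be pictured with a small
caller-side sketch (a hypothetical usage example against the `tf.contrib.layers.batch_norm`
API modified below, not code taken from the patch): when `fused` is left as `None`,
the layer is expected to resolve to the fused kernel whenever the input supports it
(rank 2 or 4, no `batch_weights`, no renorm) and to fall back to the non-fused
implementation otherwise.

# Hypothetical usage sketch (not part of the patch): with fused=None the call
# resolves to nn.fused_batch_norm when the input is supported and falls back
# to the plain implementation otherwise.
import tensorflow as tf

images = tf.placeholder(tf.float32, shape=(8, 32, 32, 3))  # rank-4 NHWC input
normalized = tf.contrib.layers.batch_norm(
    images,
    decay=0.9,
    is_training=True,
    fused=None)  # None: use the fused implementation if possible
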
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 46c355f08a9..f2a904b5211 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -257,27 +257,33 @@ def _fused_batch_norm(
                                                       'beta')
     if not param_initializers:
       param_initializers = {}
-    beta_initializer = param_initializers.get('beta',
-                                              init_ops.zeros_initializer())
-    beta = variables.model_variable(
-        'beta',
-        shape=params_shape,
-        dtype=dtype,
-        initializer=beta_initializer,
-        collections=beta_collections,
-        trainable=trainable_beta)
-    trainable_gamma = trainable and scale
-    gamma_collections = utils.get_variable_collections(variables_collections,
-                                                       'gamma')
-    gamma_initializer = param_initializers.get('gamma',
-                                               init_ops.ones_initializer())
-    gamma = variables.model_variable(
-        'gamma',
-        shape=params_shape,
-        dtype=dtype,
-        initializer=gamma_initializer,
-        collections=gamma_collections,
-        trainable=trainable_gamma)
+    if center:
+      beta_initializer = param_initializers.get('beta',
+                                                init_ops.zeros_initializer())
+      beta = variables.model_variable(
+          'beta',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=beta_initializer,
+          collections=beta_collections,
+          trainable=trainable_beta)
+    else:
+      beta = array_ops.constant(0.0, shape=params_shape)
+
+    if scale:
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
+      gamma_initializer = param_initializers.get('gamma',
+                                                 init_ops.ones_initializer())
+      gamma = variables.model_variable(
+          'gamma',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=gamma_initializer,
+          collections=gamma_collections,
+          trainable=trainable)
+    else:
+      gamma = array_ops.constant(1.0, shape=params_shape)
 
     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections.
@@ -449,7 +455,8 @@ def batch_norm(inputs,
       then the batch normalization uses weighted mean and
       variance. (This can be used to correct for bias in training
       example selection.)
-    fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
     data_format: A string. `NHWC` (default) and `NCHW` are supported.
     zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
       pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -473,7 +480,6 @@ def batch_norm(inputs,
 
   Raises:
     ValueError: If `batch_weights` is not None and `fused` is True.
-    ValueError: If `param_regularizers` is not None and `fused` is True.
     ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
     ValueError: If the rank of `inputs` is undefined.
     ValueError: If rank or channels dimension of `inputs` is undefined.
@@ -487,6 +493,21 @@ def batch_norm(inputs,
                        'supported for fused batch norm.')
     if renorm:
       raise ValueError('Renorm is not supported for fused batch norm.')
+
+  # Only use _fused_batch_norm (1) if fused is set True or if it is
+  # possible to use (currently it doesn't support batch weights,
+  # renorm, and the case when rank is neither 2 nor 4),
+  # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
+  # or non-default updates_collections (not implemented in
+  # normalization_layers.BatchNormalization yet); otherwise use the fused
+  # implementation in normalization_layers.BatchNormalization.
+  inputs = ops.convert_to_tensor(inputs)
+  rank = inputs.get_shape().ndims
+  feature_supported = batch_weights is None and not renorm and rank in [2, 4]
+  possible_to_fuse = fused is None and feature_supported
+  if (fused or possible_to_fuse) and (
+      zero_debias_moving_mean or rank == 2 or
+      updates_collections is not ops.GraphKeys.UPDATE_OPS):
     return _fused_batch_norm(
         inputs,
         decay=decay,
@@ -552,7 +573,8 @@ def batch_norm(inputs,
         renorm_momentum=renorm_decay,
         name=sc.name,
         _scope=sc,
-        _reuse=reuse)
+        _reuse=reuse,
+        fused=fused)
     outputs = layer.apply(inputs, training=is_training)
 
     # Add variables to collections.
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 67f45473e84..d4ee85b5501 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1703,13 +1703,6 @@ class BatchNormTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
       _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
 
-  def testParamRegularizersFused(self):
-    with ops.Graph().as_default() as g, self.test_session(g):
-      inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
-      with self.assertRaisesRegexp(ValueError,
-                                   'Regularizers are not currently'):
-        _layers.batch_norm(inputs, param_regularizers={}, fused=True)
-
   def _testCreateOp(self, fused):
     height, width = 3, 3
     with self.test_session():
@@ -1780,7 +1773,8 @@ class BatchNormTest(test.TestCase):
     height, width = 3, 3
     with self.test_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
-      _layers.batch_norm(images, scale=True, zero_debias_moving_mean=True)
+      _layers.batch_norm(
+          images, scale=True, zero_debias_moving_mean=True, fused=False)
       self.assertEqual(len(variables.get_model_variables()), 6)
       moving_mean = variables.get_variables_by_name('moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
@@ -1874,7 +1868,8 @@ class BatchNormTest(test.TestCase):
           images,
           decay=0.1,
           updates_collections=None,
-          zero_debias_moving_mean=True)
+          zero_debias_moving_mean=True,
+          fused=False)
       moving_mean = variables.get_variables_by_name('BatchNorm/moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
       biased = variables.get_variables_by_name('biased')[0]
@@ -2523,7 +2518,7 @@ class BatchNormTest(test.TestCase):
 
   def _runBatchNormalizationWithFormat(self, shape, data_format, is_training):
     channels = shape[-1]
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       images = np.arange(np.product(shape), dtype=np.float32).reshape(shape)
       beta = init_ops.constant_initializer(
           np.arange(
@@ -2561,20 +2556,22 @@ class BatchNormTest(test.TestCase):
     return sess.run(output)
 
   def testNHWCAndNCHWInferenceProduceSameOutput(self):
-    for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
-      nhwc = self._runBatchNormalizationWithFormat(
-          data_format='NHWC', shape=shape, is_training=False)
-      nchw = self._runBatchNormalizationWithFormat(
-          data_format='NCHW', shape=shape, is_training=False)
-      self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
+    if test.is_gpu_available(cuda_only=True):
+      for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
+        nhwc = self._runBatchNormalizationWithFormat(
+            data_format='NHWC', shape=shape, is_training=False)
+        nchw = self._runBatchNormalizationWithFormat(
+            data_format='NCHW', shape=shape, is_training=False)
+        self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
 
   def testNHWCAndNCHWTrainingProduceSameOutput(self):
-    for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
-      nhwc = self._runBatchNormalizationWithFormat(
-          data_format='NHWC', shape=shape, is_training=True)
-      nchw = self._runBatchNormalizationWithFormat(
-          data_format='NCHW', shape=shape, is_training=True)
-      self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
+    if test.is_gpu_available(cuda_only=True):
+      for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]:
+        nhwc = self._runBatchNormalizationWithFormat(
+            data_format='NHWC', shape=shape, is_training=True)
+        nchw = self._runBatchNormalizationWithFormat(
+            data_format='NCHW', shape=shape, is_training=True)
+        self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4)
 
 
 class LayerNormTest(test.TestCase):
diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py
index 83d45f6f5ad..69061460eb6 100644
--- a/tensorflow/contrib/slim/python/slim/learning_test.py
+++ b/tensorflow/contrib/slim/python/slim/learning_test.py
@@ -220,7 +220,7 @@ def LogisticClassifier(inputs):
 
 
 def BatchNormClassifier(inputs):
-  inputs = layers.batch_norm(inputs, decay=0.1)
+  inputs = layers.batch_norm(inputs, decay=0.1, fused=None)
   return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)
 
 
@@ -267,6 +267,11 @@ class CreateTrainOpTest(test.TestCase):
     self._inputs = np.random.rand(16, 4).astype(np.float32)
     self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
 
+  def _addBesselsCorrection(self, sample_size, expected_var):
+    correction_factor = sample_size / (sample_size - 1)
+    expected_var *= correction_factor
+    return expected_var
+
   def testUseUpdateOps(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(0)
@@ -275,6 +280,7 @@ class CreateTrainOpTest(test.TestCase):
 
       expected_mean = np.mean(self._inputs, axis=(0))
       expected_var = np.var(self._inputs, axis=(0))
+      expected_var = self._addBesselsCorrection(16, expected_var)
 
       tf_predictions = BatchNormClassifier(tf_inputs)
       loss_ops.log_loss(tf_predictions, tf_labels)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 780d1c2b8e0..ad0f202f959 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -123,6 +123,10 @@ class BatchNormalization(base.Layer):
     if self.fused and renorm:
       raise ValueError(
           'Batch renorm is currently not supported with fused batch norm.')
+    if self.fused and (beta_regularizer is not None or
+                       gamma_regularizer is not None):
+      raise ValueError('Regularizers are not currently '
+                       'supported for fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -153,7 +157,12 @@ class BatchNormalization(base.Layer):
                        ' is out of range for input with rank ' + str(ndim))
 
     if self.fused is None:
-      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+      # Currently fused batch norm doesn't support renorm and beta/gamma
+      # regularizer; and only supports an input tensor of rank 4 and a channel
+      # dimension on axis 1 and 3.
+      self.fused = not self.renorm and ndim == 4 and axis in [
+          1, 3
+      ] and self.beta_regularizer is None and self.gamma_regularizer is None
 
     if self.fused:
       if axis == 1:
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index fa6c9c4a5db..64bebb1021c 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -143,45 +143,46 @@ class BNTest(test.TestCase):
     self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
   def test4DInputAxis1(self):
-    epsilon = 1e-3
-    bn = normalization_layers.BatchNormalization(
-        axis=1, epsilon=epsilon, momentum=0.9)
-    inputs = variables.Variable(
-        np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
-    training = array_ops.placeholder(dtype='bool')
-    outputs = bn.apply(inputs, training=training)
+    if test.is_gpu_available(cuda_only=True):
+      epsilon = 1e-3
+      bn = normalization_layers.BatchNormalization(
+          axis=1, epsilon=epsilon, momentum=0.9)
+      inputs = variables.Variable(
+          np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+      training = array_ops.placeholder(dtype='bool')
+      outputs = bn.apply(inputs, training=training)
 
-    with self.test_session() as sess:
-      # Test training with placeholder learning phase.
-      sess.run(variables.global_variables_initializer())
-      np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
-      np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
-      np_beta = np.reshape(np_beta, (1, 4, 1, 1))
-      for _ in range(100):
-        np_output, _, _ = sess.run([outputs] + bn.updates,
-                                   feed_dict={training: True})
-      # Verify that the axis is normalized during training.
+      with self.test_session(use_gpu=True) as sess:
+        # Test training with placeholder learning phase.
+        sess.run(variables.global_variables_initializer())
+        np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+        np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+        np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+        for _ in range(100):
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
+        # Verify that the axis is normalized during training.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+        # Verify that the statistics are updated during training.
+        moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+        np_inputs = sess.run(inputs)
+        mean = np.mean(np_inputs, axis=(0, 2, 3))
+        std = np.std(np_inputs, axis=(0, 2, 3))
+        variance = np.square(std)
+        self.assertAllClose(mean, moving_mean, atol=1e-2)
+        self.assertAllClose(variance, moving_var, atol=1e-2)
+
+        # Test inference with placeholder learning phase.
+        np_output = sess.run(outputs, feed_dict={training: False})
+
+        # Verify that the axis is normalized during inference.
         normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
         self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
         self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
-      # Verify that the statistics are updated during training.
-      moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
-      np_inputs = sess.run(inputs)
-      mean = np.mean(np_inputs, axis=(0, 2, 3))
-      std = np.std(np_inputs, axis=(0, 2, 3))
-      variance = np.square(std)
-      self.assertAllClose(mean, moving_mean, atol=1e-2)
-      self.assertAllClose(variance, moving_var, atol=1e-2)
-
-      # Test inference with placeholder learning phase.
-      np_output = sess.run(outputs, feed_dict={training: False})
-
-      # Verify that the axis is normalized during inference.
-      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
-      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
-      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
-
   def test4DInputAxis2(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(
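
For reference, the automatic choice added to `BatchNormalization.build` above boils
down to the predicate sketched below (a paraphrase for illustration, not code from
the patch; the argument names mirror the layer attributes read by the patched method).

# Paraphrased sketch of the fused auto-selection rule from build() above;
# illustrative only, the real check reads the attributes of the layer object.
def can_use_fused(renorm, ndim, axis, beta_regularizer, gamma_regularizer):
  """True when fused=None may resolve to the fused batch norm kernel."""
  return (not renorm and ndim == 4 and axis in (1, 3) and
          beta_regularizer is None and gamma_regularizer is None)

# A rank-4 input normalized over axis 1 or 3, with no renorm and no
# beta/gamma regularizers, gets the fused implementation.
assert can_use_fused(False, 4, 1, None, None)
# A rank-3 input does not, so fused=None falls back to the plain path.
assert not can_use_fused(False, 3, 1, None, None)
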