+ """Ensure the 'reuse' keyword argument functions as intended. + + Args:
+ lc = layer_collection.LayerCollection() + register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) + with self.assertRaises(ValueError): + register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) + + # Fails if reuse requested but no FisherBlock exists. + lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + register_fn(lc, reuse=True) + def testRegisterFullyConnectedReuse(self): - """Ensure the 'reuse' keyword argument function as intended.""" + """Ensure the 'reuse' works with register_fully_connected.""" with ops.Graph().as_default(): - inputs = [ - array_ops.ones([2, 10]), # - array_ops.zeros([5, 10]) - ] - outputs = [ - array_ops.zeros([2, 5]), # - array_ops.ones([5, 5]) - ] + inputs = array_ops.ones([2, 10]) + outputs = array_ops.zeros([2, 5]) params = ( variable_scope.get_variable('w', [10, 5]), # variable_scope.get_variable('b', [5])) - # Fails on second if reuse=False. - lc = layer_collection.LayerCollection() - lc.register_fully_connected(params, inputs[0], outputs[0]) - with self.assertRaises(ValueError): - lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False) - - # Succeeds on second if reuse=True. - lc = layer_collection.LayerCollection() - lc.register_fully_connected(params, inputs[0], outputs[0]) - lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True) - - # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. - lc = layer_collection.LayerCollection() - lc.register_fully_connected(params, inputs[0], outputs[0]) - with self.assertRaises(ValueError): + def register_fn(lc, **kwargs): lc.register_fully_connected( - params, - inputs[1], - outputs[1], - reuse=layer_collection.VARIABLE_SCOPE) + params=params, inputs=inputs, outputs=outputs, **kwargs) - # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. 
+ """Ensure the 'reuse' argument works with register_conv2d."""
+ padding: string. See tf.nn.conv2d for valid values.
One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, + } + + if approx not in approx_to_block_types: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = approx_to_block_types[approx] + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + block = self.fisher_blocks.get(params, None) + if block is None: + raise KeyError( + "Reuse requested but no FisherBlock found for params {}.".format( + params)) + if not isinstance(block, block_type): + raise ValueError( + "Requested block of type {} but block of type {} already exists " + "for params {}.".format(block_type, type(block), params)) + + else: + block = block_type(self, params, strides, padding) self.register_block(params, block) + block.register_additional_minibatch(inputs, outputs) + def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME): params = params if isinstance(params, (tuple, list)) else (params,) self._generic_registrations |= set(params)