From 453dd5848f5652f520eb0faf17a732f20779cdb1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 31 Oct 2017 14:26:27 -0700
Subject: [PATCH] K-FAC: Support for tf.AUTO_REUSE when re-using registrations.
 Multi-tower support for FullFB, NaiveDiagonalFB. Removal of
 LayerCollection.generic_registrations.

PiperOrigin-RevId: 174092003
---
 .../python/kernel_tests/fisher_blocks_test.py |  36 ++--
 .../kernel_tests/layer_collection_test.py     |  28 ++-
 .../contrib/kfac/python/ops/fisher_blocks.py  |  38 +++-
 .../kfac/python/ops/layer_collection.py       | 178 +++++++++++-------
 4 files changed, 182 insertions(+), 98 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index dbf40fccc82..5f2b5c6cace 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -46,7 +46,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       self.assertAllEqual(params, block.tensors_to_compute_grads())

@@ -54,7 +55,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       self.assertAllEqual(params, block.tensors_to_compute_grads())

@@ -62,7 +64,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors(grads, 0.5)
@@ -71,7 +74,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors((grads,), 0.5)
@@ -88,7 +92,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = array_ops.constant([[1.], [2.]])
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = params**2
       block.instantiate_factors((grads,), 0.5)
@@ -105,7 +110,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
       damping = 0.5
       block.instantiate_factors((grads,), damping)
@@ -131,7 +137,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       self.assertAllEqual(params, block.tensors_to_compute_grads())

@@ -139,7 +146,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       self.assertAllEqual(params, block.tensors_to_compute_grads())

@@ -147,7 +155,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)

       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors(grads, 0.5)
@@ -156,7 +165,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors((grads,), 0.5)
@@ -173,7 +183,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = array_ops.constant([[1.], [2.]])
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = params**2
       block.instantiate_factors((grads,), 0.5)
@@ -189,7 +200,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       damping = 0.5
       block.instantiate_factors((grads,), damping)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index db7ab63c7d1..524e8338fde 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -30,6 +30,21 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test

+class MockFisherBlock(object):
+  """A fake FisherBlock."""
+
+  num_registered_minibatches = 2
+
+  def __init__(self, name='MockFisherBlock'):
+    self.name = name
+
+  def __eq__(self, other):
+    return isinstance(other, MockFisherBlock) and other.name == self.name
+
+  def __hash__(self):
+    return hash(self.name)
+
+
 class LayerParametersDictTest(test.TestCase):

   def testSetItem(self):
@@ -172,10 +187,12 @@ class LayerCollectionTest(test.TestCase):
     y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
     z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
     lc = layer_collection.LayerCollection()
-    lc.fisher_blocks = {x: '1', z: '2'}
+    lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}

-    lc.register_block((x, y), 'foo')
-    self.assertEqual(set(['2', 'foo']), set(lc.get_blocks()))
+    lc.register_block((x, y), MockFisherBlock('foo'))
+    self.assertEqual(
+        set([MockFisherBlock('2'), MockFisherBlock('foo')]),
+        set(lc.get_blocks()))

   def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
     x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
@@ -438,11 +455,6 @@ class LayerCollectionTest(test.TestCase):

   def testGetUseCountMap(self):
     """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
-
-    class MockFisherBlock(object):
-
-      num_registered_minibatches = 2
-
     lc = layer_collection.LayerCollection()
     lc.fisher_blocks = {
         'a': MockFisherBlock(),
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index efffaaef8d5..a6fdf01fe7d 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -133,16 +133,15 @@ class FullFB(FisherBlock):
   to any type of parameter in principle, but has very high variance.
   """

-  def __init__(self, layer_collection, params, batch_size):
+  def __init__(self, layer_collection, params):
     """Creates a FullFB block.

     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
           Fisher information matrix to which this FisherBlock belongs.
       params: The parameters of this layer (Tensor or tuple of Tensors).
-      batch_size: The batch size, used in the covariance estimator.
     """
-    self._batch_size = batch_size
+    self._batch_sizes = []
     self._params = params

     super(FullFB, self).__init__(layer_collection)
@@ -172,9 +171,21 @@
   def tensors_to_compute_grads(self):
     return self._params

+  def register_additional_minibatch(self, batch_size):
+    """Register an additional minibatch.
+
+    Args:
+      batch_size: The batch size, used in the covariance estimator.
+    """
+    self._batch_sizes.append(batch_size)
+
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._batch_sizes)
+
+  @property
+  def _batch_size(self):
+    return math_ops.reduce_sum(self._batch_sizes)


 class NaiveDiagonalFB(FisherBlock):
@@ -186,17 +197,16 @@
   to any type of parameter in principle, but has very high variance.
   """

-  def __init__(self, layer_collection, params, batch_size):
+  def __init__(self, layer_collection, params):
     """Creates a NaiveDiagonalFB block.

     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
           Fisher information matrix to which this FisherBlock belongs.
       params: The parameters of this layer (Tensor or tuple of Tensors).
-      batch_size: The batch size, used in the covariance estimator.
     """
     self._params = params
-    self._batch_size = batch_size
+    self._batch_sizes = []

     super(NaiveDiagonalFB, self).__init__(layer_collection)
@@ -221,9 +231,21 @@
   def tensors_to_compute_grads(self):
     return self._params

+  def register_additional_minibatch(self, batch_size):
+    """Register an additional minibatch.
+
+    Args:
+      batch_size: The batch size, used in the covariance estimator.
+    """
+    self._batch_sizes.append(batch_size)
+
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._batch_sizes)
+
+  @property
+  def _batch_size(self):
+    return math_ops.reduce_sum(self._batch_sizes)


 class FullyConnectedDiagonalFB(FisherBlock):
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 1806f5d8651..4eabb59b3e4 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -103,10 +103,6 @@ class LayerCollection(object):
     fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
         parameters (Tensors or tuples of Tensors) to FisherBlock instances.
     fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
-    generic_registrations: a list of variables registered via a generic layer
-        registration. Generic registrations handle any and all of the ways a
-        variable is used in the graph, which means we don't need to check
-        their registration when verifying the correctness of the graph.
     losses: a list of LossFunction objects. The loss to be optimized is their
         sum.
   """
@@ -114,7 +110,6 @@
   def __init__(self, graph=None, name="LayerCollection"):
     self.fisher_blocks = LayerParametersDict()
     self.fisher_factors = OrderedDict()
-    self._generic_registrations = set()
     self._graph = graph or ops.get_default_graph()
     self._loss_dict = {}  # {str: LossFunction}
     self._subgraph = None
@@ -127,7 +122,7 @@
     """LossFunctions registered with this LayerCollection."""
     return list(self._loss_dict.values())

-  def register_block(self, layer_key, fisher_block):
+  def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
     """Validates and registers the layer_key associated with the fisher_block.

     Validation consists of checking whether the key was already registered or
@@ -153,20 +148,43 @@
       layer_key: The key to check for in existing registrations and to register
         if valid.
       fisher_block: The associated fisher block.
+      reuse: Method to use for inserting new FisherBlocks. One of True, False,
+        or VARIABLE_SCOPE.

     Raises:
       ValueError: If the layer_key was already registered, or if a subset of
         the layer_key has already been registered as part of a different tuple.
+
+    Returns:
+      FisherBlock registered under 'layer_key'. May or may not be the same as
+      'fisher_block'.
     """
+    if reuse is VARIABLE_SCOPE:
+      reuse = variable_scope.get_variable_scope().reuse
+
+    if reuse is True or (reuse is variable_scope.AUTO_REUSE and
+                         layer_key in self.fisher_blocks):
+      result = self.fisher_blocks[layer_key]
+      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
+        raise ValueError(
+            "Attempted to register FisherBlock of type %s when existing "
+            "FisherBlock has type %s." % (type(fisher_block), type(result)))
+      return result
+    if reuse is False and layer_key in self.fisher_blocks:
+      raise ValueError("FisherBlock for %s is already in LayerCollection." %
+                       (layer_key,))
+
+    # Insert fisher_block into self.fisher_blocks.
     if layer_key in self.fisher_blocks:
       raise ValueError("Duplicate registration: {}".format(layer_key))
     if isinstance(layer_key, (tuple, list)):
-      self._register_block_with_sequence_key(layer_key, fisher_block)
+      return self._register_block_with_sequence_key(layer_key, fisher_block)
     else:
-      self._register_block_with_nonsequence_key(layer_key, fisher_block)
+      return self._register_block_with_nonsequence_key(layer_key, fisher_block)

   def _register_block_with_sequence_key(self, layer_key, fisher_block):
     """Validates and registers the layer_key if it's a sequence."""
+    # Find all keys that are either supersets or subsets of 'layer_key'.
     inclusions = {
         fisher_elt
         for layer_elt in layer_key for fisher_elt in self.fisher_blocks
@@ -175,24 +193,60 @@
     if not inclusions:
       self.fisher_blocks[layer_key] = fisher_block
-      return
+      return fisher_block

+    result_key = None
     for key in inclusions:
       fisher_block_key = key if isinstance(key, (tuple, list)) else (key,)
-      if set(layer_key).issubset(fisher_block_key):
-        logging.warning("Graph Registration Warning: tried to register "
-                        "a subset ({}) of an already registered tuple "
-                        "({}), skipping".format(layer_key, fisher_block_key))
-        return
-      if not set(fisher_block_key).issubset(layer_key):
+      in_existing_only = set(fisher_block_key) - set(layer_key)
+      in_new_only = set(layer_key) - set(fisher_block_key)
+
+      if in_existing_only and in_new_only:
+        # Existing and new key have an intersection but neither is a subset of
+        # the other. This is an error.
         raise ValueError(
             "Inconsistent registration, expected new key to be a subset or "
             "superset of the existing key: existing is {}, new is {}".format(
                 key, layer_key))
-      else:
+      elif in_existing_only and not in_new_only:
+        # Existing key is strict superset of new key. Return existing
+        # FisherBlock.
+        logging.warning("Graph Registration Warning: tried to register "
+                        "a subset ({}) of an already registered tuple "
+                        "({}), skipping".format(layer_key, fisher_block_key))
+        assert result_key is None
+        result_key = key
+      elif in_new_only and not in_existing_only:
+        # Existing key is a strict subset of new key. Replace existing
+        # FisherBlock with new one.
+        #
+        # TODO(b/68715045): This is dangerous. If there are existing
+        # registrations for a minibatch from elsewhere in the graph, they won't
+        # be re-registered with this new FisherBlock. The type of FisherBlock
+        # could also change here.
+        logging.warning(
+            "Replacing existing FisherBlock for key {} with new FisherBlock "
+            "for key {}. {} registered minibatches from the existing "
+            "FisherBlock will not be migrated.".format(
+                key, layer_key,
+                self.fisher_blocks[key].num_registered_minibatches))
         self.fisher_blocks.pop(key)
+        self.fisher_blocks[layer_key] = fisher_block
+        assert result_key is None
+        result_key = layer_key
+      elif not in_new_only and not in_existing_only:
+        # Existing and new are identical. Reuse the old FisherBlock.
+        #
+        # TODO(b/68715045): This is dangerous. If the new FisherBlock has
+        # existing registered minibatches, they will not be migrated to the
+        # existing FisherBlock.
+        assert result_key is None
+        result_key = key
+      else:
+        raise ValueError("Unexpected layer key conflict: {} vs. {}".format(
{}".format( + layer_key, key)) - self.fisher_blocks[layer_key] = fisher_block + return self.fisher_blocks[result_key] def _register_block_with_nonsequence_key(self, layer_key, fisher_block): """Validates and registers the layer_key if it's not a sequence.""" @@ -209,6 +263,8 @@ class LayerCollection(object): "variable ({}) but a containing tuple was already " "registered ({}), skipping".format(layer_key, inclusions)) + return fisher_block + def _equal_or_subset(self, elt1, elt2): """Checks if the elements are equal or one is contained in the other.""" return (elt1 == elt2 or (isinstance(elt1, @@ -230,10 +286,6 @@ class LayerCollection(object): def get_factors(self): return self.fisher_factors.values() - @property - def generic_registrations(self): - return self._generic_registrations - @property def graph(self): return self._graph @@ -291,24 +343,7 @@ class LayerCollection(object): block_type = approx_to_block_types[approx] has_bias = isinstance(params, (tuple, list)) - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - block = self.fisher_blocks.get(params, None) - if block is None: - raise KeyError( - "Reuse requested but no FisherBlock found for params {}.".format( - params)) - if not isinstance(block, block_type): - raise ValueError( - "Requested block of type {} but block of type {} already exists " - "for params {}.".format(block_type, type(block), params)) - - else: - block = block_type(self, has_bias) - self.register_block(params, block) - + block = self.register_block(params, block_type(self, has_bias), reuse=reuse) block.register_additional_minibatch(inputs, outputs) def register_conv2d(self, @@ -351,42 +386,45 @@ class LayerCollection(object): raise ValueError("Bad value {} for approx.".format(approx)) block_type = approx_to_block_types[approx] - - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - block = self.fisher_blocks.get(params, None) - if block is None: - raise KeyError( - "Reuse requested but no FisherBlock found for params {}.".format( - params)) - if not isinstance(block, block_type): - raise ValueError( - "Requested block of type {} but block of type {} already exists " - "for params {}.".format(block_type, type(block), params)) - - else: - block = block_type(self, params, strides, padding) - self.register_block(params, block) - + block = self.register_block( + params, block_type(self, params, strides, padding), reuse=reuse) block.register_additional_minibatch(inputs, outputs) - def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME): - params = params if isinstance(params, (tuple, list)) else (params,) - self._generic_registrations |= set(params) + def register_generic(self, + params, + batch_size, + approx=APPROX_DIAGONAL_NAME, + reuse=VARIABLE_SCOPE): + """Registers a generic layer. - # Generic registrations do not need special registration rules because we do - # not care about multiple generic registrations. Add them to the - # fisher_block dictionary manually rather than going through the logic in - # self.register_block. - if approx == APPROX_FULL_NAME: - self.fisher_blocks[params] = fb.FullFB(self, params, batch_size) - elif approx == APPROX_DIAGONAL_NAME: - self.fisher_blocks[params] = fb.NaiveDiagonalFB(self, params, batch_size) - else: + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. 
+        [out_channels].
+      batch_size: 0-D Tensor. Size of the minibatch.
+      approx: str. One of APPROX_FULL_NAME or APPROX_DIAGONAL_NAME.
+      reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock. If VARIABLE_SCOPE, use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    approx_to_block_types = {
+        APPROX_FULL_NAME: fb.FullFB,
+        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
+    }
+
+    if approx not in approx_to_block_types:
       raise ValueError("Bad value {} for approx.".format(approx))

+    block_type = approx_to_block_types[approx]
+    block = self.register_block(params, block_type(self, params), reuse=reuse)
+    block.register_additional_minibatch(batch_size)
+
   def register_categorical_predictive_distribution(self,
                                                    logits,
                                                    seed=None,
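
Illustrative usage of the new FisherBlock API (editor's sketch, not part of the patch): after this change, FullFB and NaiveDiagonalFB no longer take a batch size in their constructors; each tower instead calls register_additional_minibatch(), and the block's effective batch size is the sum of the registered sizes. The snippet below mirrors the updated tests; the module aliases fb and lc follow fisher_blocks_test.py, and the batch sizes are hypothetical.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

params = (tf.constant([1., 2.]), tf.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)

# One call per tower/minibatch; batch_size is no longer a constructor argument.
block.register_additional_minibatch(32)
block.register_additional_minibatch(32)

# num_registered_minibatches counts the registrations; the private _batch_size
# property now sums the registered sizes for use in the covariance estimator.
assert block.num_registered_minibatches == 2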
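
A second sketch (also an editor's addition, not part of the patch) of the headline feature, assuming a TensorFlow 1.4-era build where tf.AUTO_REUSE and contrib.kfac are available: when the surrounding variable scope uses tf.AUTO_REUSE, repeated registrations of the same parameters reuse the existing FisherBlock instead of raising, and each call registers one more minibatch. The scope name, variable name, and batch size are hypothetical.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

layer_collection = lc.LayerCollection()
weights = tf.get_variable('w', shape=[10, 10])

for tower in range(2):
  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    # register_generic defaults to reuse=VARIABLE_SCOPE, so it picks up
    # tf.AUTO_REUSE from the enclosing scope: the first call creates a
    # NaiveDiagonalFB for `weights`, the second call reuses it and adds
    # another minibatch.
    layer_collection.register_generic(weights, batch_size=128)

assert layer_collection.fisher_blocks[weights].num_registered_minibatches == 2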