mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
K-FAC: Multiple minibatches support for LayerCollection.register_conv2d()
PiperOrigin-RevId: 173941279
This commit is contained in:
committed by
TensorFlower Gardener
parent
32f3c3a431
commit
35cc8bb0a2
@@ -329,72 +329,82 @@ class LayerCollectionTest(test.TestCase):
|
||||
single_loss = sess.run(lc.total_loss())
|
||||
self.assertAlmostEqual(7.6983433, single_loss)
|
||||
|
||||
def ensureLayerReuseWorks(self, register_fn):
|
||||
"""Ensure the 'reuse' keyword argument functions as intended.
|
||||
|
||||
Args:
|
||||
register_fn: function for registering a layer. Arguments are
|
||||
layer_collection, reuse, and approx.
|
||||
"""
|
||||
# Fails on second if reuse=False.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, reuse=False)
|
||||
|
||||
# Succeeds on second if reuse=True.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
register_fn(lc, reuse=True)
|
||||
|
||||
# Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
# Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with variable_scope.variable_scope(
|
||||
variable_scope.get_variable_scope(), reuse=True):
|
||||
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
# Fails if block type changes.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
|
||||
|
||||
# Fails if reuse requested but no FisherBlock exists.
|
||||
lc = layer_collection.LayerCollection()
|
||||
with self.assertRaises(KeyError):
|
||||
register_fn(lc, reuse=True)
|
||||
|
||||
def testRegisterFullyConnectedReuse(self):
|
||||
"""Ensure the 'reuse' keyword argument function as intended."""
|
||||
"""Ensure the 'reuse' argument works with register_fully_connected."""
|
||||
with ops.Graph().as_default():
|
||||
inputs = [
|
||||
array_ops.ones([2, 10]), #
|
||||
array_ops.zeros([5, 10])
|
||||
]
|
||||
outputs = [
|
||||
array_ops.zeros([2, 5]), #
|
||||
array_ops.ones([5, 5])
|
||||
]
|
||||
inputs = array_ops.ones([2, 10])
|
||||
outputs = array_ops.zeros([2, 5])
|
||||
params = (
|
||||
variable_scope.get_variable('w', [10, 5]), #
|
||||
variable_scope.get_variable('b', [5]))
|
||||
|
||||
# Fails on second if reuse=False.
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(params, inputs[0], outputs[0])
|
||||
with self.assertRaises(ValueError):
|
||||
lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False)
|
||||
|
||||
# Succeeds on second if reuse=True.
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(params, inputs[0], outputs[0])
|
||||
lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True)
|
||||
|
||||
# Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(params, inputs[0], outputs[0])
|
||||
with self.assertRaises(ValueError):
|
||||
def register_fn(lc, **kwargs):
|
||||
lc.register_fully_connected(
|
||||
params,
|
||||
inputs[1],
|
||||
outputs[1],
|
||||
reuse=layer_collection.VARIABLE_SCOPE)
|
||||
params=params, inputs=inputs, outputs=outputs, **kwargs)
|
||||
|
||||
# Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(params, inputs[0], outputs[0])
|
||||
with variable_scope.variable_scope(
|
||||
variable_scope.get_variable_scope(), reuse=True):
|
||||
lc.register_fully_connected(
|
||||
params,
|
||||
inputs[1],
|
||||
outputs[1],
|
||||
reuse=layer_collection.VARIABLE_SCOPE)
|
||||
self.ensureLayerReuseWorks(register_fn)
|
||||
|
||||
# Fails if block type changes.
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(
|
||||
params,
|
||||
inputs[0],
|
||||
outputs[0],
|
||||
approx=layer_collection.APPROX_KRONECKER_NAME)
|
||||
with self.assertRaises(ValueError):
|
||||
lc.register_fully_connected(
|
||||
params,
|
||||
inputs[1],
|
||||
outputs[1],
|
||||
approx=layer_collection.APPROX_DIAGONAL_NAME,
|
||||
reuse=True)
|
||||
def testRegisterConv2dReuse(self):
|
||||
"""Ensure the 'reuse' argument works with register_conv2d."""
|
||||
with ops.Graph().as_default():
|
||||
inputs = array_ops.ones([2, 5, 5, 10])
|
||||
outputs = array_ops.zeros([2, 5, 5, 3])
|
||||
params = (
|
||||
variable_scope.get_variable('w', [1, 1, 10, 3]), #
|
||||
variable_scope.get_variable('b', [3]))
|
||||
|
||||
# Fails if reuse requested but no FisherBlock exists.
|
||||
lc = layer_collection.LayerCollection()
|
||||
with self.assertRaises(KeyError):
|
||||
lc.register_fully_connected(params, inputs[0], outputs[0], reuse=True)
|
||||
def register_fn(lc, **kwargs):
|
||||
lc.register_conv2d(
|
||||
params=params,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME',
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
**kwargs)
|
||||
|
||||
self.ensureLayerReuseWorks(register_fn)
|
||||
|
||||
def testMakeOrGetFactor(self):
|
||||
with ops.Graph().as_default():
|
||||
|
||||
@@ -311,18 +311,67 @@ class LayerCollection(object):
|
||||
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
def register_conv2d(self, params, strides, padding, inputs, outputs,
|
||||
approx=APPROX_KRONECKER_NAME):
|
||||
def register_conv2d(self,
|
||||
params,
|
||||
strides,
|
||||
padding,
|
||||
inputs,
|
||||
outputs,
|
||||
approx=APPROX_KRONECKER_NAME,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a convolutional layer.
|
||||
|
||||
if approx == APPROX_KRONECKER_NAME:
|
||||
block = fb.ConvKFCBasicFB(self, params, strides, padding)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.register_block(params, block)
|
||||
elif approx == APPROX_DIAGONAL_NAME:
|
||||
block = fb.ConvDiagonalFB(self, params, strides, padding)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
Args:
|
||||
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
|
||||
this layer. Weight matrix should have shape [kernel_height,
|
||||
kernel_width, in_channels, out_channels]. Bias should have shape
|
||||
[out_channels].
|
||||
strides: 1-D Tensor of length 4. Strides for convolution kernel.
|
||||
padding: string. See tf.nn.conv2d for valid values.
|
||||
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
|
||||
to layer.
|
||||
outputs: Tensor of shape [batch_size, height, width, out_channels].
|
||||
Preactivations produced by layer.
|
||||
approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.
|
||||
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
|
||||
create a new FisherBlock. If VARIABLE_SCOPE, use
|
||||
tf.get_variable_scope().reuse.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'approx' is not one of the supported approximation names.
|
||||
KeyError: If reuse == True but no FisherBlock found for 'params'.
|
||||
ValueError: If reuse == True and FisherBlock found but of the wrong type.
|
||||
"""
|
||||
approx_to_block_types = {
|
||||
APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
|
||||
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
|
||||
}
|
||||
|
||||
if approx not in approx_to_block_types:
|
||||
raise ValueError("Bad value {} for approx.".format(approx))
|
||||
|
||||
block_type = approx_to_block_types[approx]
|
||||
|
||||
if reuse == VARIABLE_SCOPE:
|
||||
reuse = variable_scope.get_variable_scope().reuse
|
||||
|
||||
if reuse:
|
||||
block = self.fisher_blocks.get(params, None)
|
||||
if block is None:
|
||||
raise KeyError(
|
||||
"Reuse requested but no FisherBlock found for params {}.".format(
|
||||
params))
|
||||
if not isinstance(block, block_type):
|
||||
raise ValueError(
|
||||
"Requested block of type {} but block of type {} already exists "
|
||||
"for params {}.".format(block_type, type(block), params))
|
||||
|
||||
else:
|
||||
block = block_type(self, params, strides, padding)
|
||||
self.register_block(params, block)
|
||||
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
|
||||
params = params if isinstance(params, (tuple, list)) else (params,)
|
||||
self._generic_registrations |= set(params)
|
||||
|
||||
Reference in New Issue
Block a user