Modify quantization to support add ops that occur after Conv2D

PiperOrigin-RevId: 174058697
This commit is contained in:
A. Unique TensorFlower
2017-10-31 10:51:01 -07:00
committed by TensorFlower Gardener
parent 938643b561
commit 4699702601
2 changed files with 70 additions and 6 deletions

View File

@@ -28,7 +28,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_util
# Operation types used to select oerations of interest.
# Operation types used to select operations of interest.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
# Custom key for storing and retrieving update ops used by quantizing nodes.
@@ -83,12 +83,17 @@ def Quantize(graph,
for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)):
if op.name.endswith('/depthwise'):
# Separable convolution may consist of 2 convolution nodes. If so,
# skip .../depthwise and only quantize the top one.
# Separable convolution may consist of 2 convolution nodes. If so, skip
# .../depthwise and only quantize the top one.
separable_conv = context.GetOperationByNameDontThrow(
op.name[:-len('/depthwise')])
if separable_conv and separable_conv.type == 'Conv2D':
continue
if op.type == 'Conv2D':
# Quantize add ops that come after Conv2D
add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
if add_context_re is not None:
context.add_contexts.add(add_context_re.group(1))
if not op.name.endswith('_Fold'):
folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold')
# Do nothing if found, it will be quantized when it is iterated over.
@@ -97,6 +102,8 @@ def Quantize(graph,
else:
context.QuantizeOpWithWeights(op, folded=True)
context.QuantizeAddContexts()
# Once all quantization ops have been inserted in the graph, collect update
# ops for their variables and modify the TF Slim update barrier (see
# https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py)
@@ -153,6 +160,22 @@ class _QuantizeContext(object):
self.is_training = is_training
self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
self.input_to_ops_map = input_to_ops.InputToOps(graph)
self.add_contexts = set()
def QuantizeAddContexts(self):
  """Quantizes all add ops in self.add_contexts."""
  for context_name in self.add_contexts:
    # The bypass add op may be named either '/Add' or '/add' depending on
    # how the graph was constructed; try both and take whichever exists.
    candidates = [context_name + suffix for suffix in ('/Add', '/add')]
    add_op = self.GetOperationByNamesDontThrow(candidates)
    if add_op is None:
      continue
    consumers = self.input_to_ops_map.ConsumerOperations(add_op)
    self._InsertQuantOp(
        context_name,
        add_op,
        consumers,
        name='add_quant',
        moving_avg=True,
        bits=self.activation_bits,
        narrow_range=False)
def QuantizeOpWithWeights(self, op, folded):
"""Quantizes around the specific operation with or without batch norm.
@@ -219,7 +242,6 @@ class _QuantizeContext(object):
# When a bypass connection was found, also quantize Add op input.
if add_op:
def _QuantizeAddInput(add_input):
if folded:
return add_input.op.name.endswith('/add_fold')
@@ -267,7 +289,8 @@ class _QuantizeContext(object):
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
return consumers[0], None, None
if add_context:
add_op = self.GetOperationByNameDontThrow(add_context + '/Add')
add_op = self.GetOperationByNamesDontThrow([
add_context + '/Add', add_context + '/add'])
return activation_op, add_op, add_context
else:
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
@@ -280,13 +303,29 @@ class _QuantizeContext(object):
Returns:
The Operation with the given name. None if the name does not correspond to
any operation in the graph
any operation in the graph.
"""
try:
return self.graph.get_operation_by_name(name)
except KeyError:
return None
def GetOperationByNamesDontThrow(self, names):
  """Returns an Operation with one of the given names.

  Args:
    names: Candidate operation names, tried in order.

  Returns:
    The Operation matching the first name found in the graph. None if none
    of the names corresponds to any operation in the graph.
  """
  lookups = (self.GetOperationByNameDontThrow(name) for name in names)
  return next((op for op in lookups if op is not None), None)
def _InsertQuantOp(
self,
context,

View File

@@ -23,7 +23,9 @@ from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -52,6 +54,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertEqual(
str(err.exception), 'Some inputs not quantized for ops: [Relu6]')
def testInsertQuantOpForAddAfterConv2d(self):
  """Verifies a fake-quant op is inserted for a Conv2D -> Add bypass."""
  graph = ops.Graph()
  with graph.as_default():
    batch, rows, cols, channels = 5, 128, 128, 3
    conv_input = array_ops.zeros((batch, rows, cols, channels))
    bypass_input = array_ops.zeros((batch, rows / 2, cols / 2, 32))
    conv_out = conv2d(
        conv_input, 32, [5, 5],
        stride=2,
        padding='SAME',
        weights_initializer=self._WeightInit(0.09),
        activation_fn=None,
        scope='test/test')
    summed = math_ops.add(conv_out, bypass_input, name='test/add')
    summed = array_ops.identity(summed, name='test/identity')
    barrier = control_flow_ops.no_op(name='update_barrier')
    with ops.control_dependencies([barrier]):
      array_ops.identity(summed, name='control_dependency')

  quantize.Quantize(
      graph=graph, weight_bits=8, weight_narrow_range=True, activation_bits=8)

  # The quantization pass should have wrapped the bypass add in a fake-quant.
  fake_quant_type = 'FakeQuantWithMinMaxVars'
  add_quant = graph.get_operation_by_name('test/add_quant/' + fake_quant_type)
  self.assertEqual(add_quant.type, fake_quant_type)
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.