Modify quantization to support add ops that occur after Conv2D
PiperOrigin-RevId: 174058697
commit 4699702601
parent 938643b561
committed by TensorFlower Gardener
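
Not part of the commit, but for orientation: a minimal sketch of the graph pattern the change targets, a Conv2D whose output feeds an Add (e.g. a bypass/residual connection). TF 1.x contrib.slim is assumed here, and the scope names mirror the new test added below.

import tensorflow as tf
from tensorflow.contrib import slim

inputs = tf.zeros((5, 128, 128, 3))
shortcut = tf.zeros((5, 64, 64, 32))
conv = slim.conv2d(inputs, 32, [5, 5], stride=2, padding='SAME',
                   activation_fn=None, scope='test/test')
# The add sits directly under the conv's parent scope ('test'); that scope is
# the "add context" the modified Quantize() pass records and later quantizes.
out = tf.add(conv, shortcut, name='test/add')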
@@ -28,7 +28,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_util

# Operation types used to select oerations of interest.
# Operation types used to select operations of interest.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Custom key for storing and retrieving update ops used by quantizing nodes.
@@ -83,12 +83,17 @@ def Quantize(graph,

  for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)):
    if op.name.endswith('/depthwise'):
      # Separable convolution may consist of 2 convolution nodes. If so,
      # skip .../depthwise and only quantize the top one.
      # Separable convolution may consist of 2 convolution nodes. If so, skip
      # .../depthwise and only quantize the top one.
      separable_conv = context.GetOperationByNameDontThrow(
          op.name[:-len('/depthwise')])
      if separable_conv and separable_conv.type == 'Conv2D':
        continue
    if op.type == 'Conv2D':
      # Quantize add ops that come after Conv2D
      add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
      if add_context_re is not None:
        context.add_contexts.add(add_context_re.group(1))
    if not op.name.endswith('_Fold'):
      folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold')
      # Do nothing if found, it will be quantized when it is iterated over.
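
Not part of the commit: the regex above derives the add context by stripping the Conv2D op's own name and its immediate enclosing scope from its full name. A worked example, using the op name produced by the new test below:

import re

op_name = 'test/test/Conv2D'  # conv built with scope='test/test'
match = re.search(r'^(.*)/[^/]+/', op_name)
print(match.group(1))  # -> 'test'
# QuantizeAddContexts() then looks up 'test/Add' or 'test/add' in the graph.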
@@ -97,6 +102,8 @@ def Quantize(graph,
    else:
      context.QuantizeOpWithWeights(op, folded=True)

  context.QuantizeAddContexts()

  # Once all quantization ops have been inserted in the graph, collect update
  # ops for their variables and modify the TF Slim update barrier (see
  # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py)
@@ -153,6 +160,22 @@ class _QuantizeContext(object):
    self.is_training = is_training
    self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
    self.input_to_ops_map = input_to_ops.InputToOps(graph)
    self.add_contexts = set()

  def QuantizeAddContexts(self):
    """Quantizes all add ops in self.add_contexts."""
    for add_context in self.add_contexts:
      add_op = self.GetOperationByNamesDontThrow([
          add_context + '/Add', add_context + '/add'])
      if add_op is not None:
        self._InsertQuantOp(
            add_context,
            add_op,
            self.input_to_ops_map.ConsumerOperations(add_op),
            name='add_quant',
            moving_avg=True,
            bits=self.activation_bits,
            narrow_range=False)

  def QuantizeOpWithWeights(self, op, folded):
    """Quantizes around the specific operation with or without batch norm.
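
Not part of the commit: because _InsertQuantOp is invoked with name='add_quant', the activation quantizer for an add context ends up under '<add_context>/add_quant/...', and the new test below retrieves it by exactly that name. A tiny illustration of how the lookup name is composed:

add_context = 'test'  # the context derived in the regex example above
quantization_node_name = 'FakeQuantWithMinMaxVars'
print(add_context + '/add_quant/' + quantization_node_name)
# -> 'test/add_quant/FakeQuantWithMinMaxVars', the op name the test asserts on.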
@@ -219,7 +242,6 @@ class _QuantizeContext(object):

    # When a bypass connection was found, also quantize Add op input.
    if add_op:

      def _QuantizeAddInput(add_input):
        if folded:
          return add_input.op.name.endswith('/add_fold')
@@ -267,7 +289,8 @@ class _QuantizeContext(object):
        raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
      return consumers[0], None, None
    if add_context:
      add_op = self.GetOperationByNameDontThrow(add_context + '/Add')
      add_op = self.GetOperationByNamesDontThrow([
          add_context + '/Add', add_context + '/add'])
      return activation_op, add_op, add_context
    else:
      raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
@@ -280,13 +303,29 @@ class _QuantizeContext(object):

    Returns:
      The Operation with the given name. None if the name does not correspond to
      any operation in the graph
      any operation in the graph.
    """
    try:
      return self.graph.get_operation_by_name(name)
    except KeyError:
      return None

  def GetOperationByNamesDontThrow(self, names):
    """Returns an Operation with one of the given names.

    Args:
      names: Names of Operation to return.

    Returns:
      The Operation with one of the given names. None if none of the names
      corresponds to any operation in the graph.
    """
    for name in names:
      op = self.GetOperationByNameDontThrow(name)
      if op is not None:
        return op
    return None

  def _InsertQuantOp(
      self,
      context,
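
Not part of the commit: a standalone illustration of the first-match semantics of the new GetOperationByNamesDontThrow helper, here re-implemented as a free function over a throwaway graph ('tower_0' is a hypothetical scope; TF 1.x graph mode assumed).

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  tf.add(tf.zeros(()), tf.zeros(()), name='tower_0/add')

def get_operation_by_names_dont_throw(graph, names):
  # Same logic as the GetOperationByNamesDontThrow method above.
  for name in names:
    try:
      return graph.get_operation_by_name(name)
    except KeyError:
      continue
  return None

add_op = get_operation_by_names_dont_throw(graph, ['tower_0/Add', 'tower_0/add'])
print(add_op.name)  # -> 'tower_0/add'; 'tower_0/Add' is tried first and misses.

The hunks that follow are from the accompanying quantize test.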
@@ -23,7 +23,9 @@ from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -52,6 +54,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self.assertEqual(
        str(err.exception), 'Some inputs not quantized for ops: [Relu6]')

  def testInsertQuantOpForAddAfterConv2d(self):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
      conv = conv2d(input1, 32, [5, 5], stride=2, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None, scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = array_ops.identity(node, name='test/identity')
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

    quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
                      activation_bits=8)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    add_quant = graph.get_operation_by_name('test/add_quant/' +
                                            quantization_node_name)
    self.assertEqual(add_quant.type, quantization_node_name)

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.