Implement minimal selective quantization feature supporting denylisting.

This change provides a minimal config and implementation for denylisting quantizable units (i.e., lifted functions).
With this change, users will be able to denylist quantization for specific quantizable units by specifying the config as:

```textpb
specs [
  {
    matcher { function_name { regex: "composite_dot_general.*" } }
    method { no_quantization{} }
  },
]
```
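
The same denylist can be built through the generated Python proto API, mirroring how the integration tests in this change construct their configs. A minimal sketch, assuming `quantization_config_pb2` is imported as `qc` as in those tests:

```python
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as qc

# Denylist every quantizable unit whose lifted function name starts with
# "composite_dot_general"; matched units are left unquantized.
specs = qc.QuantizationSpecs(
    specs=[
        qc.QuantizationSpec(
            matcher=qc.MatcherSpec(
                function_name=qc.FunctionNameMatcherSpec(
                    regex='composite_dot_general.*'
                )
            ),
            method=qc.Method(no_quantization={}),
        )
    ]
)
```

The resulting message is passed as the `specs` field of `QuantizationConfig`.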

This change also includes an implementation of `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`, a test-only variant of `LiftQuantizableSpotsAsFunctionsPass` with a predefined `QuantizationSpecs`, which makes FileCheck testing easier.
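
The test pass embeds its predefined spec as a textproto constant and parses it into a `QuantizationSpecs` message before running the pass. For reference, a sketch of the same round trip in Python, using the standard protobuf `text_format` helper:

```python
from google.protobuf import text_format
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as qc

# Mirrors `kSpecsDisableAllDotGeneralByFuncName` in the test pass: disables
# quantization for all dot_general quantizable units, matched by function name.
_SPECS_TEXTPB = """
specs [
  {
    matcher { function_name { regex: "composite_dot_general_.*" } }
    method { no_quantization {} }
  }
]
"""

quantization_specs = text_format.Parse(_SPECS_TEXTPB, qc.QuantizationSpecs())
```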

PiperOrigin-RevId: 607221399
Author: Dan Suh
Date: 2024-02-14 23:02:38 -08:00
Committed by: TensorFlower Gardener
Parent: c8df4622b6
Commit: 1ee3d7cb39

13 changed files with 609 additions and 13 deletions


@@ -124,6 +124,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@local_tsl//tsl/platform:regexp",
"@local_tsl//tsl/platform:str_util",
"@local_tsl//tsl/protobuf:protos_all_cc",
"@local_xla//xla/mlir_hlo",
@@ -473,6 +474,7 @@ gentbl_cc_library(
cc_library(
name = "test_passes",
srcs = [
"passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc",
"passes/testing/test_post_calibration_component.cc",
"passes/testing/test_pre_calibration_component.cc",
"passes/testing/test_tf_to_stablehlo_pass.cc",
@@ -482,6 +484,7 @@ cc_library(
],
compatible_with = get_compatible_with_portable(),
deps = [
":passes",
":quantization_config_proto_cc",
":stablehlo_test_passes_inc_gen",
"//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
@@ -492,13 +495,17 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:protobuf",
"@local_xla//xla/mlir_hlo",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",


@@ -28,12 +28,14 @@ limitations under the License.
namespace mlir::quant::stablehlo {
using ::stablehlo::quantization::PipelineConfig;
using ::stablehlo::quantization::QuantizationSpecs;
using ::stablehlo::quantization::StaticRangePtqPreset;
using ::tensorflow::quantization::CalibrationOptions;
void AddPreCalibrationPasses(OpPassManager& pm,
-                             const CalibrationOptions& calibration_options) {
-  pm.addPass(createLiftQuantizableSpotsAsFunctionsPass());
+                             const CalibrationOptions& calibration_options,
+                             const QuantizationSpecs& quantization_specs) {
+  pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs));
pm.addNestedPass<func::FuncOp>(
CreateInsertCustomAggregationOpsPass(calibration_options));
pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass());


@@ -25,7 +25,8 @@ namespace mlir::quant::stablehlo {
// required to collect tensor statistics.
void AddPreCalibrationPasses(
OpPassManager& pm,
-    const ::tensorflow::quantization::CalibrationOptions& calibration_options);
+    const ::tensorflow::quantization::CalibrationOptions& calibration_options,
+    const ::stablehlo::quantization::QuantizationSpecs& specs);
// Adds passes for static-range quantization post-calibration. Utilizes tensor
// statistics collected from the calibration step and performs quantization.


@@ -42,8 +42,8 @@ absl::StatusOr<ModuleOp> PreCalibrationComponent::Run(
ModuleOp module_op, const QuantizationConfig& config) {
TF_RETURN_IF_ERROR(RunPasses(
kName, /*add_passes_func=*/
-      [this](PassManager& pm) {
-        AddPreCalibrationPasses(pm, calibration_options_);
+      [&config, this](PassManager& pm) {
+        AddPreCalibrationPasses(pm, calibration_options_, config.specs());
},
*ctx_, module_op));
return module_op;


@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -31,6 +35,11 @@ limitations under the License.
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tsl/platform/regexp.h" // IWYU pragma: keep
#define DEBUG_TYPE "lift_quantizable_spots_as_functions"
namespace mlir::quant::stablehlo {
@@ -39,13 +48,16 @@ namespace mlir::quant::stablehlo {
namespace {
using ::stablehlo::quantization::FunctionNameMatcherSpec;
using ::stablehlo::quantization::Method;
using ::stablehlo::quantization::QuantizationSpec;
using ::stablehlo::quantization::QuantizationSpecs;
// TODO: b/303543789 - Move the helper functions below to a separate util.
// Fetches the default or null attribute, used for pattern matching.
Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) {
-  if (!attr) {
-    return builder.getStringAttr(kNullAttributeValue);
-  }
-  return attr;
+  if (attr) return attr;
+  return builder.getStringAttr(kNullAttributeValue);
}
// Checks whether the value of a constant equals the given float, regardless
@@ -62,6 +74,12 @@ bool FloatValueEquals(const Attribute& attr, const double value) {
});
}
// Lifts quantizable units as separate functions, thereby identifying the
// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how
// quantizable units are lifted.
//
// FileCheck test cases using various `QuantizationSpecs` can be seen at
// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`.
class LiftQuantizableSpotsAsFunctionsPass
: public impl::LiftQuantizableSpotsAsFunctionsPassBase<
LiftQuantizableSpotsAsFunctionsPass> {
@@ -69,10 +87,19 @@ class LiftQuantizableSpotsAsFunctionsPass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
LiftQuantizableSpotsAsFunctionsPass)
-  explicit LiftQuantizableSpotsAsFunctionsPass() = default;
+  LiftQuantizableSpotsAsFunctionsPass() = default;
// Constructor with explicit user-provided `QuantizationSpecs`.
explicit LiftQuantizableSpotsAsFunctionsPass(
QuantizationSpecs quantization_specs)
: quantization_specs_(std::move(quantization_specs)) {}
private:
void runOnOperation() override;
// No explicit quantization spec is specified by default. Implicitly this
// means that all quantizable units will be identified and lifted.
QuantizationSpecs quantization_specs_{};
};
namespace simple_patterns {
@@ -83,6 +110,91 @@ namespace fusion_patterns {
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc"
}
// Returns a `func::FuncOp` in `module_op` (not nested) whose name matches
// `name`. Returns null if no such function exists.
// TODO: b/307620778 - Factor out "FindMainFuncOp" functionality.
func::FuncOp FindFuncOp(ModuleOp module_op, const StringRef name) {
auto func_ops = module_op.getOps<func::FuncOp>();
auto func_itr = llvm::find_if(func_ops, [name](func::FuncOp func_op) {
return func_op.getName() == name;
});
if (func_itr == func_ops.end()) return {};
return *func_itr;
}
// Quantizable Unit matcher that uses lifted function's name for matching.
class FunctionNameMatcher {
public:
explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec)
: match_regex_(GetMatchRegex(spec)) {}
// Returns `true` when matched with the entry function of
// `xla_call_module_op`.
bool Match(TF::XlaCallModuleOp xla_call_module_op) const {
if (match_regex_ == nullptr) return false;
const std::string lifted_func_name =
xla_call_module_op->getAttrOfType<FlatSymbolRefAttr>("_entry_function")
.getValue()
.str();
return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT
}
private:
// Returns an owned `RE2` object that corresponds to the `spec`. Returns
// `nullptr` if the `spec` is invalid.
// NOLINTNEXTLINE - RE2 included via TSL regexp.h
std::unique_ptr<RE2> GetMatchRegex(const FunctionNameMatcherSpec& spec) {
const std::string& regex = spec.regex();
if (regex.empty()) return nullptr;
return std::make_unique<RE2>(regex); // NOLINT
}
// Regex object used for matching against a lifted function's name.
std::unique_ptr<RE2> match_regex_; // NOLINT
};
// Applies the quantization spec to all matched lifted functions. At this
// point only denylisting (`NoQuantization`) is supported.
// TODO: b/307620778 - Support more advanced selective quantization methods.
LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec,
ModuleOp module_op) {
func::FuncOp main_func = FindFuncOp(module_op, "main");
if (!main_func) return failure();
const Method& quantization_method = spec.method();
if (!quantization_method.has_no_quantization()) {
module_op->emitError() << "Unsupported quantization method: "
<< quantization_method.DebugString() << "\n";
return failure();
}
const FunctionNameMatcher matcher(spec.matcher().function_name());
for (auto xla_call_module_op : main_func.getOps<TF::XlaCallModuleOp>()) {
if (!matcher.Match(xla_call_module_op)) continue;
// Disable quantization when matched.
const std::string lifted_func_name =
xla_call_module_op->getAttrOfType<FlatSymbolRefAttr>("_entry_function")
.getValue()
.str();
func::FuncOp lifted_func = FindFuncOp(module_op, lifted_func_name);
// Remove relevant attributes that enable quantization. This essentially
// disables quantization for the matched `xla_call_module_op`.
xla_call_module_op->removeAttr("_original_entry_function");
xla_call_module_op->removeAttr("_tfl_quant_trait");
lifted_func->removeAttr("tf_quant.composite_function");
LLVM_DEBUG(llvm::dbgs() << "Disabled quantization for quantizable unit: "
<< lifted_func_name << "\n");
}
return success();
}
void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
MLIRContext* ctx = &getContext();
RewritePatternSet patterns(ctx);
@@ -101,8 +213,26 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
// Remove all attr_map attributes.
module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); });
// Perform selective quantization. Iterates over the quantization specs and
// applies quantization methods to each matched lifted function.
for (const QuantizationSpec& spec : quantization_specs_.specs()) {
if (failed(ApplyQuantizationSpec(spec, module_op))) {
signalPassFailure();
return;
}
}
}
} // namespace
// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined
// `QuantizationSpecs`.
std::unique_ptr<OperationPass<ModuleOp>>
CreateLiftQuantizableSpotsAsFunctionsPass(
const QuantizationSpecs& quantization_specs) {
return std::make_unique<LiftQuantizableSpotsAsFunctionsPass>(
quantization_specs);
}
} // namespace mlir::quant::stablehlo
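
Note that the matcher above uses `RE2::FullMatch`, so the regex must consume the lifted function's entire name; a pattern without wildcards only matches that exact name. A quick sketch of the equivalent semantics (Python's `re.fullmatch` approximates RE2's full-match behavior):

```python
import re

# A trailing wildcard matches all suffixed variants of the lifted function name.
assert re.fullmatch('composite_dot_general.*', 'composite_dot_general_fn_1')

# Without a wildcard, only the exact name matches.
assert re.fullmatch('composite_dot_general_fn_1', 'composite_dot_general_fn_1')
assert not re.fullmatch('composite_dot_general', 'composite_dot_general_fn_1')
```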


@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h"
namespace mlir::quant::stablehlo {
@@ -45,6 +46,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizeWeightPass(
absl::StatusOr<std::string> ConvertSerializedStableHloModuleToBfloat16(
StringRef serialized_stablehlo_module);
std::unique_ptr<OperationPass<ModuleOp>>
CreateLiftQuantizableSpotsAsFunctionsPass(
const ::stablehlo::quantization::QuantizationSpecs& quantization_specs);
// Adds generated pass default constructors or options definitions.
#define GEN_PASS_DECL
// Adds generated pass registration functions.


@@ -34,6 +34,7 @@ def LiftQuantizableSpotsAsFunctionsPass : Pass<"stablehlo-lift-quantizable-spots
that disperse values. (ex: convolution, dot_general)
}];
let dependentDialects = [
"mlir::func::FuncDialect",
"mlir::stablehlo::StablehloDialect",
"TF::TensorFlowDialect",
];


@@ -61,3 +61,17 @@ def TestTFToStablehloPass : Pass<"stablehlo-test-tf-to-stablehlo", "mlir::Module
"mlir::sparse_tensor::SparseTensorDialect", "mlir::vhlo::VhloDialect",
];
}
def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass :
Pass<"stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs", "mlir::ModuleOp"> {
let summary = "Test-only pass for testing the LiftQuantizableSpotsAsFunctionsPass with a predefined QuantizationSpecs.";
let description = [{
This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but
has predefined `QuantizationSpecs` to make FileCheck testing easier.
}];
let dependentDialects = [
"mlir::func::FuncDialect",
"mlir::stablehlo::StablehloDialect",
"TF::TensorFlowDialect",
];
}


@@ -0,0 +1,101 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep
#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep
#include "tsl/platform/protobuf.h" // IWYU pragma: keep
namespace mlir::quant::stablehlo::testing {
// NOLINTNEXTLINE - Automatically generated.
#define GEN_PASS_DEF_TESTLIFTQUANTIZABLESPOTSASFUNCTIONSWITHQUANTIZATIONSPECSPASS
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc"
namespace {
using ::stablehlo::quantization::QuantizationSpecs;
using ::tsl::protobuf::TextFormat;
// NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS.
using ::tsl::protobuf::io::ArrayInputStream;
// Configure `QuantizationSpecs` to disable quantization for all dot_general
// quantizable units.
constexpr absl::string_view kSpecsDisableAllDotGeneralByFuncName =
R"pb(specs
[ {
matcher { function_name { regex: "composite_dot_general_.*" } }
method { no_quantization {} }
}])pb";
class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass
: public impl::
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase<
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass> {
public:
using impl::TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase<
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass>::
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase;
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass)
private:
void runOnOperation() override;
};
// Parses a text proto into a `QuantizationSpecs` proto. Returns
// `InvalidArgumentError` if `text_proto` is invalid.
absl::StatusOr<QuantizationSpecs> ParseQuantizationSpecsTextProto(
const absl::string_view text_proto) {
QuantizationSpecs quantization_specs;
TextFormat::Parser parser;
ArrayInputStream input_stream(text_proto.data(), text_proto.size());
if (parser.Parse(&input_stream, &quantization_specs)) {
return quantization_specs;
}
return absl::InvalidArgumentError("Could not parse text proto.");
}
void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass::
runOnOperation() {
PassManager pass_manager{&getContext()};
const absl::StatusOr<QuantizationSpecs> quantization_specs =
ParseQuantizationSpecsTextProto(kSpecsDisableAllDotGeneralByFuncName);
if (!quantization_specs.ok()) {
signalPassFailure();
return;
}
pass_manager.addPass(
CreateLiftQuantizableSpotsAsFunctionsPass(*quantization_specs));
if (failed(pass_manager.run(getOperation()))) {
signalPassFailure();
return;
}
}
} // namespace
} // namespace mlir::quant::stablehlo::testing


@@ -73,12 +73,19 @@ tf_py_strict_test(
":quantize_model_test_base",
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/module",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:nn_ops",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:save",
"//tensorflow/python/saved_model:tag_constants",
"//tensorflow/python/types:core",
"@absl_py//absl/testing:parameterized",
],
)


@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import itertools
-from typing import Optional, Sequence
+from typing import Mapping, Optional, Sequence
from absl.testing import parameterized
import numpy as np
@@ -22,12 +22,19 @@ from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_
from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization
from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base
from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.types import core
def parameter_combinations(test_parameters):
@@ -319,7 +326,10 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
@parameterized.parameters(
parameter_combinations([{
-          'equation': ('abc,cde->abde', 'abc,dce->abde',),
+          'equation': (
+              'abc,cde->abde',
+              'abc,dce->abde',
+          ),
'rng_seed': (82, 82732, 4444, 14),
}])
)
@@ -406,6 +416,239 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
config,
)
@test_util.run_in_graph_and_eager_modes
def test_ptq_denylist_basic(self):
"""Tests that the op is not quantized when no quantization is enabled."""
input_shape = (1, 2)
model = self._create_matmul_model(
input_shape,
weight_shape=(2, 3),
saved_model_path=self._input_saved_model_path,
)
rng = np.random.default_rng(1230)
random_tensor_gen_fn = lambda: rng.uniform(
low=0.0, high=1.0, size=input_shape
).astype(np.float32)
def data_gen() -> repr_dataset.RepresentativeDataset:
for _ in range(50):
yield {'input_tensor': random_tensor_gen_fn()}
dataset_path = self.create_tempfile('tfrecord').full_path
path_map = {'serving_default': dataset_path}
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
{'serving_default': data_gen()}
)
config = qc.QuantizationConfig(
static_range_ptq_preset=qc.StaticRangePtqPreset(
representative_datasets=[
qc.RepresentativeDatasetConfig(
tf_record=qc.TfRecordFile(path=dataset_path)
)
]
),
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
# Disable quantization for the quantizable unit (lifted function) whose
# function name starts with "composite_dot_general".
specs=qc.QuantizationSpecs(
specs=[
qc.QuantizationSpec(
matcher=qc.MatcherSpec(
function_name=qc.FunctionNameMatcherSpec(
regex='composite_dot_general.*'
)
),
method=qc.Method(no_quantization={}),
)
]
),
)
quantization.quantize_saved_model(
self._input_saved_model_path,
self._output_saved_model_path,
config,
)
input_data = ops.convert_to_tensor(random_tensor_gen_fn())
expected_outputs = model.matmul(input_data)
root = load.load(self._output_saved_model_path)
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
new_outputs = root.signatures['serving_default'](
input_tensor=ops.convert_to_tensor(input_data)
)
# Indirectly tests that the model is not quantized by asserting that the
# numeric differences are negligible.
self.assertAllClose(new_outputs, expected_outputs, rtol=0.000001)
@test_util.run_in_graph_and_eager_modes
def test_ptq_selective_denylist(self):
"""Tests that the op is not quantized when no quantization is enabled."""
rng = np.random.default_rng(1230)
random_tensor_gen_fn = lambda shape: rng.uniform(
low=-1.0, high=1.0, size=shape
).astype(np.float32)
class TwoMatmulModel(module.Module):
"""A model with two matmul ops."""
@def_function.function
def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
"""Performs a matrix multiplication.
Args:
input_tensor: Input tensor to matmul with the filter.
Returns:
A 'output' -> output tensor mapping
"""
out = math_ops.matmul(input_tensor, random_tensor_gen_fn((2, 3)))
out = math_ops.matmul(out, random_tensor_gen_fn((3, 4)))
return {'output': out}
model = TwoMatmulModel()
input_shape = (1, 2)
save.save(
model,
self._input_saved_model_path,
signatures=model.matmul.get_concrete_function(
tensor_spec.TensorSpec(
shape=input_shape, dtype=dtypes.float32, name='input_tensor'
)
),
)
def data_gen() -> repr_dataset.RepresentativeDataset:
for _ in range(50):
yield {'input_tensor': random_tensor_gen_fn(input_shape)}
dataset_path = self.create_tempfile('tfrecord').full_path
path_map = {'serving_default': dataset_path}
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
{'serving_default': data_gen()}
)
config = qc.QuantizationConfig(
static_range_ptq_preset=qc.StaticRangePtqPreset(
representative_datasets=[
qc.RepresentativeDatasetConfig(
tf_record=qc.TfRecordFile(path=dataset_path)
),
],
),
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
# Disable quantization for the quantizable unit (lifted function) whose
# function name matches "composite_dot_general_fn_1".
# "composite_dot_general_fn_2" will be quantized.
specs=qc.QuantizationSpecs(
specs=[
qc.QuantizationSpec(
matcher=qc.MatcherSpec(
function_name=qc.FunctionNameMatcherSpec(
regex='composite_dot_general_fn_1'
)
),
method=qc.Method(no_quantization={}),
)
]
),
)
quantization.quantize_saved_model(
self._input_saved_model_path,
self._output_saved_model_path,
config,
)
input_data = ops.convert_to_tensor(random_tensor_gen_fn(input_shape))
expected_outputs = model.matmul(input_data)
root = load.load(self._output_saved_model_path)
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
new_outputs = root.signatures['serving_default'](
input_tensor=ops.convert_to_tensor(input_data)
)
# Indirectly tests that the model is only partially quantized.
self.assertAllClose(new_outputs, expected_outputs, rtol=0.011)
@test_util.run_in_graph_and_eager_modes
def test_ptq_quantization_method_not_applied_when_matcher_mismatch(self):
"""Tests that quantization method is not applied to unmatched units."""
input_shape = (1, 2)
model = self._create_matmul_model(
input_shape,
weight_shape=(2, 3),
saved_model_path=self._input_saved_model_path,
)
rng = np.random.default_rng(1230)
random_tensor_gen_fn = lambda: rng.uniform(
low=0.0, high=1.0, size=input_shape
).astype(np.float32)
def data_gen() -> repr_dataset.RepresentativeDataset:
for _ in range(50):
yield {'input_tensor': random_tensor_gen_fn()}
dataset_path = self.create_tempfile('tfrecord').full_path
path_map = {'serving_default': dataset_path}
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
{'serving_default': data_gen()}
)
config = qc.QuantizationConfig(
static_range_ptq_preset=qc.StaticRangePtqPreset(
representative_datasets=[
qc.RepresentativeDatasetConfig(
tf_record=qc.TfRecordFile(path=dataset_path)
)
]
),
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
specs=qc.QuantizationSpecs(
specs=[
qc.QuantizationSpec(
# Provide a regex that wouldn't match any quantizable units.
matcher=qc.MatcherSpec(
function_name=qc.FunctionNameMatcherSpec(
regex='.*invalid_function_name.*'
),
),
method=qc.Method(no_quantization={}),
),
],
),
)
quantization.quantize_saved_model(
self._input_saved_model_path,
self._output_saved_model_path,
config,
)
input_data = ops.convert_to_tensor(random_tensor_gen_fn())
expected_outputs = model.matmul(input_data)
root = load.load(self._output_saved_model_path)
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
new_outputs = root.signatures['serving_default'](
input_tensor=ops.convert_to_tensor(input_data)
)
# Tests that the quantized graph outputs similar values. They also shouldn't
# be exactly the same. Indirectly proves that the `FunctionNameMatcherSpec`
# with regex '.*invalid_function_name.*' did not match the quantizable unit.
self.assertAllClose(new_outputs, expected_outputs, rtol=0.04)
self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001)
if __name__ == '__main__':
test.main()


@@ -63,10 +63,68 @@ message PipelineConfig {
optional bool unpack_quantized_types = 1;
}
// A quantization method representing "do not quantize". Mostly used for
// denylisting quantizable units from quantization.
message NoQuantization {}
// Represents a matching method that matches quantizable units by lifted
// functions' names.
message FunctionNameMatcherSpec {
// Regular expression to match lifted functions' names. Underlying regex
// engine uses re2, which accepts a subset of PCRE. See
// https://github.com/google/re2/wiki/Syntax for details.
string regex = 1;
}
// Matcher specification for identifying quantizable units.
message MatcherSpec {
// Matches lifted functions by their names.
FunctionNameMatcherSpec function_name = 1;
}
// Specifies how to quantize matched quantizable units.
message Method {
NoQuantization no_quantization = 1;
}
// A QuantizationSpec is essentially a (matcher spec, quantization method) pair,
// where the matcher spec is used to identify quantizable units and the
// quantization method specifies what type of quantization to apply on the
// matched quantizable units.
// Next ID: 3
message QuantizationSpec {
// Configures matchers for identifying quantizable units. Matched quantizable
// units will be quantized according to `method`.
MatcherSpec matcher = 1;
// Specifies how to quantize the matched quantizable units.
Method method = 2;
}
// Quantization specifications. A simple wrapper around a sequence of
// `QuantizationSpec`s so that specs can be easily passed around or represented
// as a textproto.
// Next ID: 2
message QuantizationSpecs {
// List of `QuantizationSpec`s. A later spec in the sequence takes precedence.
//
// NOTE: Tie-breaking mechanism is not yet supported. Providing multiple
// `QuantizationSpec` with conflicting quantizable units may result in
// undefined behavior.
// TODO: b/307620778 - Support tie-breaking for conflicting specs.
repeated QuantizationSpec specs = 1;
}
// Quantization configuration for StableHLO Quantizer. This is the primary
// message containing all configurable options.
-// Next ID: 4
+// Next ID: 5
message QuantizationConfig {
// Config presets provide predefined popular or common quantization specs.
// Lightweight users may choose one of the presets for quick experiments. Each
// preset is completely represented by `QuantizationSpecs`. When extra entries
// in `QuantizationSpecs` are provided along with a preset, then the preset
// will be overridden for the quantizable units matched by those additional
// `QuantizationSpec`s.
oneof preset {
// Performs best-effort static-range post-training quantization (PTQ).
StaticRangePtqPreset static_range_ptq_preset = 1;
@@ -77,4 +135,6 @@ message QuantizationConfig {
// Configures the graph transformation pipeline for quantization.
PipelineConfig pipeline_config = 3;
QuantizationSpecs specs = 4;
}


@@ -0,0 +1,25 @@
// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs \
// RUN: -split-input-file | FileCheck %s
// CHECK: @main
func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
%0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
%1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
return %1 : tensor<1x1x64xf32>
}
// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp
// are missing the attributes required for quantization.
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
// CHECK-SAME: {_entry_function = @composite_dot_general_fn_1, {{.*}}}
// CHECK-NOT: _original_entry_function
// CHECK-NOT: _tfl_quant_trait
// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
// CHECK: }
// CHECK-LABEL: private @composite_dot_general_fn_1
// CHECK-NOT: tf_quant.composite_function
// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1
// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32>
// CHECK: }