mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Implement minimal selective quantization feature supporting denylisting.
This change provides a minimal config and implementation for denylisting quantizable units (== lifted functions).
With this change, users will be a able to denlylist quantization for specific quantizable units by specifying the config as;
```textpb
specs [
{
matcher { function_name { regex: "composite_dot_general.*" } }
method { no_quantization{} }
},
]
```
This change also includes an implementation of `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`, which is `LiftQuantizableSpotsAsFunctionsPass` with predefined `QuantizationSpecs`.
PiperOrigin-RevId: 607221399
This commit is contained in:
committed by
TensorFlower Gardener
parent
c8df4622b6
commit
1ee3d7cb39
@@ -124,6 +124,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"@local_tsl//tsl/platform:regexp",
|
||||
"@local_tsl//tsl/platform:str_util",
|
||||
"@local_tsl//tsl/protobuf:protos_all_cc",
|
||||
"@local_xla//xla/mlir_hlo",
|
||||
@@ -473,6 +474,7 @@ gentbl_cc_library(
|
||||
cc_library(
|
||||
name = "test_passes",
|
||||
srcs = [
|
||||
"passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc",
|
||||
"passes/testing/test_post_calibration_component.cc",
|
||||
"passes/testing/test_pre_calibration_component.cc",
|
||||
"passes/testing/test_tf_to_stablehlo_pass.cc",
|
||||
@@ -482,6 +484,7 @@ cc_library(
|
||||
],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
deps = [
|
||||
":passes",
|
||||
":quantization_config_proto_cc",
|
||||
":stablehlo_test_passes_inc_gen",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
|
||||
@@ -492,13 +495,17 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SparseTensorDialect",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@local_tsl//tsl/platform:protobuf",
|
||||
"@local_xla//xla/mlir_hlo",
|
||||
"@stablehlo//:chlo_ops",
|
||||
"@stablehlo//:stablehlo_ops",
|
||||
|
||||
@@ -28,12 +28,14 @@ limitations under the License.
|
||||
namespace mlir::quant::stablehlo {
|
||||
|
||||
using ::stablehlo::quantization::PipelineConfig;
|
||||
using ::stablehlo::quantization::QuantizationSpecs;
|
||||
using ::stablehlo::quantization::StaticRangePtqPreset;
|
||||
using ::tensorflow::quantization::CalibrationOptions;
|
||||
|
||||
void AddPreCalibrationPasses(OpPassManager& pm,
|
||||
const CalibrationOptions& calibration_options) {
|
||||
pm.addPass(createLiftQuantizableSpotsAsFunctionsPass());
|
||||
const CalibrationOptions& calibration_options,
|
||||
const QuantizationSpecs& quantization_specs) {
|
||||
pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs));
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
CreateInsertCustomAggregationOpsPass(calibration_options));
|
||||
pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass());
|
||||
|
||||
@@ -25,7 +25,8 @@ namespace mlir::quant::stablehlo {
|
||||
// required to collect tensor statistics.
|
||||
void AddPreCalibrationPasses(
|
||||
OpPassManager& pm,
|
||||
const ::tensorflow::quantization::CalibrationOptions& calibration_options);
|
||||
const ::tensorflow::quantization::CalibrationOptions& calibration_options,
|
||||
const ::stablehlo::quantization::QuantizationSpecs& specs);
|
||||
|
||||
// Adds passes for static-range quantization post-calibration. Utilizes tensor
|
||||
// statistics collected from the calibration step and performs quantization.
|
||||
|
||||
@@ -42,8 +42,8 @@ absl::StatusOr<ModuleOp> PreCalibrationComponent::Run(
|
||||
ModuleOp module_op, const QuantizationConfig& config) {
|
||||
TF_RETURN_IF_ERROR(RunPasses(
|
||||
kName, /*add_passes_func=*/
|
||||
[this](PassManager& pm) {
|
||||
AddPreCalibrationPasses(pm, calibration_options_);
|
||||
[&config, this](PassManager& pm) {
|
||||
AddPreCalibrationPasses(pm, calibration_options_, config.specs());
|
||||
},
|
||||
*ctx_, module_op));
|
||||
return module_op;
|
||||
|
||||
@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
@@ -31,6 +35,11 @@ limitations under the License.
|
||||
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
|
||||
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tsl/platform/regexp.h" // IWYU pragma: keep
|
||||
|
||||
#define DEBUG_TYPE "lift_quantizable_spots_as_functions"
|
||||
|
||||
namespace mlir::quant::stablehlo {
|
||||
|
||||
@@ -39,13 +48,16 @@ namespace mlir::quant::stablehlo {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::stablehlo::quantization::FunctionNameMatcherSpec;
|
||||
using ::stablehlo::quantization::Method;
|
||||
using ::stablehlo::quantization::QuantizationSpec;
|
||||
using ::stablehlo::quantization::QuantizationSpecs;
|
||||
|
||||
// TODO - b/303543789: Move the helper functions below to a separate util.
|
||||
// Fetches the default or null attribute, used for pattern matching.
|
||||
Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) {
|
||||
if (!attr) {
|
||||
return builder.getStringAttr(kNullAttributeValue);
|
||||
}
|
||||
return attr;
|
||||
if (attr) return attr;
|
||||
return builder.getStringAttr(kNullAttributeValue);
|
||||
}
|
||||
|
||||
// Checks whether the value of a constant equals the given float, regardless
|
||||
@@ -62,6 +74,12 @@ bool FloatValueEquals(const Attribute& attr, const double value) {
|
||||
});
|
||||
}
|
||||
|
||||
// Lifts quantizable units as separate functions, thereby identifying the
|
||||
// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how
|
||||
// quantizable units are lifted.
|
||||
//
|
||||
// FileCheck test cases using various `QuantizationSpecs` can be seen at
|
||||
// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`.
|
||||
class LiftQuantizableSpotsAsFunctionsPass
|
||||
: public impl::LiftQuantizableSpotsAsFunctionsPassBase<
|
||||
LiftQuantizableSpotsAsFunctionsPass> {
|
||||
@@ -69,10 +87,19 @@ class LiftQuantizableSpotsAsFunctionsPass
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
LiftQuantizableSpotsAsFunctionsPass)
|
||||
|
||||
explicit LiftQuantizableSpotsAsFunctionsPass() = default;
|
||||
LiftQuantizableSpotsAsFunctionsPass() = default;
|
||||
|
||||
// Constructor with explicit user-provided `QuantizationSpecs`.
|
||||
explicit LiftQuantizableSpotsAsFunctionsPass(
|
||||
QuantizationSpecs quantization_specs)
|
||||
: quantization_specs_(std::move(quantization_specs)) {}
|
||||
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
|
||||
// No explicit quantization spec is specified by default. Implicitly this
|
||||
// means that all quantizable units will be identified and lifted.
|
||||
QuantizationSpecs quantization_specs_{};
|
||||
};
|
||||
|
||||
namespace simple_patterns {
|
||||
@@ -83,6 +110,91 @@ namespace fusion_patterns {
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc"
|
||||
}
|
||||
|
||||
// Returns a `func::FuncOp` in `module_op` (not nested) whose name matches
|
||||
// `name`. Returns null if no such a function exists.
|
||||
// TODO: b/307620778 - Factor out "FindMainFuncOp" functionality.
|
||||
func::FuncOp FindFuncOp(ModuleOp module_op, const StringRef name) {
|
||||
auto func_ops = module_op.getOps<func::FuncOp>();
|
||||
auto func_itr = llvm::find_if(func_ops, [name](func::FuncOp func_op) {
|
||||
return func_op.getName() == name;
|
||||
});
|
||||
|
||||
if (func_itr == func_ops.end()) return {};
|
||||
return *func_itr;
|
||||
}
|
||||
|
||||
// Quantizable Unit matcher that uses lifted function's name for matching.
|
||||
class FunctionNameMatcher {
|
||||
public:
|
||||
explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec)
|
||||
: match_regex_(GetMatchRegex(spec)) {}
|
||||
|
||||
// Returns `true` when matched with the entry function of
|
||||
// `xla_call_module_op`.
|
||||
bool Match(TF::XlaCallModuleOp xla_call_module_op) const {
|
||||
if (match_regex_ == nullptr) return false;
|
||||
|
||||
const std::string lifted_func_name =
|
||||
xla_call_module_op->getAttrOfType<FlatSymbolRefAttr>("_entry_function")
|
||||
.getValue()
|
||||
.str();
|
||||
|
||||
return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns an owned `RE2` object that corresponds to the `spec`. Returns
|
||||
// `nullptr` if the `spec` is invalid.
|
||||
// NOLINTNEXTLINE - RE2 included via TSL regexp.h
|
||||
std::unique_ptr<RE2> GetMatchRegex(const FunctionNameMatcherSpec& spec) {
|
||||
const std::string& regex = spec.regex();
|
||||
if (regex.empty()) return nullptr;
|
||||
|
||||
return std::make_unique<RE2>(regex); // NOLINT
|
||||
}
|
||||
|
||||
// Regex object used for matching against a lifted function's name.
|
||||
std::unique_ptr<RE2> match_regex_; // NOLINT
|
||||
};
|
||||
|
||||
// Applies quantization spec to all matched lifted functions. At this point only
|
||||
// denylisting (`NoQuantization`) will be applied if specs is nonempty.
|
||||
// TODO: b/307620778 - Support more advanced selective quantization methods.
|
||||
LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec,
|
||||
ModuleOp module_op) {
|
||||
func::FuncOp main_func = FindFuncOp(module_op, "main");
|
||||
if (!main_func) return failure();
|
||||
|
||||
const Method& quantization_method = spec.method();
|
||||
if (!quantization_method.has_no_quantization()) {
|
||||
module_op->emitError() << "Unsupported quantization method: "
|
||||
<< quantization_method.DebugString() << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const FunctionNameMatcher matcher(spec.matcher().function_name());
|
||||
for (auto xla_call_module_op : main_func.getOps<TF::XlaCallModuleOp>()) {
|
||||
if (!matcher.Match(xla_call_module_op)) continue;
|
||||
|
||||
// Disable quantization when matched.
|
||||
const std::string lifted_func_name =
|
||||
xla_call_module_op->getAttrOfType<FlatSymbolRefAttr>("_entry_function")
|
||||
.getValue()
|
||||
.str();
|
||||
func::FuncOp lifted_func = FindFuncOp(module_op, lifted_func_name);
|
||||
|
||||
// Remove relevant attributes that enable quantization. This essentially
|
||||
// disables quantization for the matched `xla_call_module_op`.
|
||||
xla_call_module_op->removeAttr("_original_entry_function");
|
||||
xla_call_module_op->removeAttr("_tfl_quant_trait");
|
||||
lifted_func->removeAttr("tf_quant.composite_function");
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Disabled quantization for quantizable unit: "
|
||||
<< lifted_func_name << "\n");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
|
||||
MLIRContext* ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
@@ -101,8 +213,26 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
|
||||
|
||||
// Remove all attr_map attributes.
|
||||
module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); });
|
||||
|
||||
// Perform selective quantization. Iterates over the quantization specs and
|
||||
// applies quantization methods to each matched lifted function.
|
||||
for (const QuantizationSpec& spec : quantization_specs_.specs()) {
|
||||
if (failed(ApplyQuantizationSpec(spec, module_op))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined
|
||||
// `QuantizationSpecs`.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateLiftQuantizableSpotsAsFunctionsPass(
|
||||
const QuantizationSpecs& quantization_specs) {
|
||||
return std::make_unique<LiftQuantizableSpotsAsFunctionsPass>(
|
||||
quantization_specs);
|
||||
}
|
||||
|
||||
} // namespace mlir::quant::stablehlo
|
||||
|
||||
@@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h"
|
||||
|
||||
namespace mlir::quant::stablehlo {
|
||||
@@ -45,6 +46,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizeWeightPass(
|
||||
absl::StatusOr<std::string> ConvertSerializedStableHloModuleToBfloat16(
|
||||
StringRef serialized_stablehlo_module);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateLiftQuantizableSpotsAsFunctionsPass(
|
||||
const ::stablehlo::quantization::QuantizationSpecs& quantization_specs);
|
||||
|
||||
// Adds generated pass default constructors or options definitions.
|
||||
#define GEN_PASS_DECL
|
||||
// Adds generated pass registration functions.
|
||||
|
||||
@@ -34,6 +34,7 @@ def LiftQuantizableSpotsAsFunctionsPass : Pass<"stablehlo-lift-quantizable-spots
|
||||
that disperse values. (ex: convolution, dot_general)
|
||||
}];
|
||||
let dependentDialects = [
|
||||
"mlir::func::FuncDialect",
|
||||
"mlir::stablehlo::StablehloDialect",
|
||||
"TF::TensorFlowDialect",
|
||||
];
|
||||
|
||||
@@ -61,3 +61,17 @@ def TestTFToStablehloPass : Pass<"stablehlo-test-tf-to-stablehlo", "mlir::Module
|
||||
"mlir::sparse_tensor::SparseTensorDialect", "mlir::vhlo::VhloDialect",
|
||||
];
|
||||
}
|
||||
|
||||
def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass :
|
||||
Pass<"stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs", "mlir::ModuleOp"> {
|
||||
let summary = "Test-only pass for testing the LiftQuantizableSpotsAsFunctionsPass with a predefined QuantizationSpecs.";
|
||||
let description = [{
|
||||
This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but
|
||||
has predefined `QuantizationSpecs` to make FileCheck testing easier.
|
||||
}];
|
||||
let dependentDialects = [
|
||||
"mlir::func::FuncDialect",
|
||||
"mlir::stablehlo::StablehloDialect",
|
||||
"TF::TensorFlowDialect",
|
||||
];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Support/TypeID.h" // from @llvm-project
|
||||
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep
|
||||
#include "tsl/platform/protobuf.h" // IWYU pragma: keep
|
||||
|
||||
namespace mlir::quant::stablehlo::testing {
|
||||
|
||||
// NOLINTNEXTLINE - Automatically generated.
|
||||
#define GEN_PASS_DEF_TESTLIFTQUANTIZABLESPOTSASFUNCTIONSWITHQUANTIZATIONSPECSPASS
|
||||
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::stablehlo::quantization::QuantizationSpecs;
|
||||
using ::tsl::protobuf::TextFormat;
|
||||
// NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS.
|
||||
using ::tsl::protobuf::io::ArrayInputStream;
|
||||
|
||||
// Configure `QuantizationSpecs` to disable quantization for all dot_general
|
||||
// quantizable units.
|
||||
constexpr absl::string_view kSpecsDisableAllDotGeneralByFuncName =
|
||||
R"pb(specs
|
||||
[ {
|
||||
matcher { function_name { regex: "composite_dot_general_.*" } }
|
||||
method { no_quantization {} }
|
||||
}])pb";
|
||||
|
||||
class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass
|
||||
: public impl::
|
||||
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase<
|
||||
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass> {
|
||||
public:
|
||||
using impl::TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase<
|
||||
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass>::
|
||||
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase;
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass)
|
||||
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Parses a text proto into a `QuantizationSpecs` proto. Returns
|
||||
// `InvalidArgumentError` if `text_proto` is invalid.
|
||||
absl::StatusOr<QuantizationSpecs> ParseQuantizationSpecsTextProto(
|
||||
const absl::string_view text_proto) {
|
||||
QuantizationSpecs quantization_specs;
|
||||
TextFormat::Parser parser;
|
||||
ArrayInputStream input_stream(text_proto.data(), text_proto.size());
|
||||
if (parser.Parse(&input_stream, &quantization_specs)) {
|
||||
return quantization_specs;
|
||||
}
|
||||
return absl::InvalidArgumentError("Could not parse text proto.");
|
||||
}
|
||||
|
||||
void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass::
|
||||
runOnOperation() {
|
||||
PassManager pass_manager{&getContext()};
|
||||
|
||||
const absl::StatusOr<QuantizationSpecs> quantization_specs =
|
||||
ParseQuantizationSpecsTextProto(kSpecsDisableAllDotGeneralByFuncName);
|
||||
if (!quantization_specs.ok()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
pass_manager.addPass(
|
||||
CreateLiftQuantizableSpotsAsFunctionsPass(*quantization_specs));
|
||||
|
||||
if (failed(pass_manager.run(getOperation()))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mlir::quant::stablehlo::testing
|
||||
@@ -73,12 +73,19 @@ tf_py_strict_test(
|
||||
":quantize_model_test_base",
|
||||
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py",
|
||||
"//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/framework:test_lib",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops:math_ops",
|
||||
"//tensorflow/python/ops:nn_ops",
|
||||
"//tensorflow/python/platform:client_testlib",
|
||||
"//tensorflow/python/saved_model:load",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/saved_model:tag_constants",
|
||||
"//tensorflow/python/types:core",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import itertools
|
||||
from typing import Optional, Sequence
|
||||
from typing import Mapping, Optional, Sequence
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@@ -22,12 +22,19 @@ from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_
|
||||
from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization
|
||||
from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base
|
||||
from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.types import core
|
||||
|
||||
|
||||
def parameter_combinations(test_parameters):
|
||||
@@ -319,7 +326,10 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
|
||||
|
||||
@parameterized.parameters(
|
||||
parameter_combinations([{
|
||||
'equation': ('abc,cde->abde', 'abc,dce->abde',),
|
||||
'equation': (
|
||||
'abc,cde->abde',
|
||||
'abc,dce->abde',
|
||||
),
|
||||
'rng_seed': (82, 82732, 4444, 14),
|
||||
}])
|
||||
)
|
||||
@@ -406,6 +416,239 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
|
||||
config,
|
||||
)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_ptq_denylist_basic(self):
|
||||
"""Tests that the op is not quantized when no quantization is enabled."""
|
||||
input_shape = (1, 2)
|
||||
model = self._create_matmul_model(
|
||||
input_shape,
|
||||
weight_shape=(2, 3),
|
||||
saved_model_path=self._input_saved_model_path,
|
||||
)
|
||||
|
||||
rng = np.random.default_rng(1230)
|
||||
random_tensor_gen_fn = lambda: rng.uniform(
|
||||
low=0.0, high=1.0, size=input_shape
|
||||
).astype(np.float32)
|
||||
|
||||
def data_gen() -> repr_dataset.RepresentativeDataset:
|
||||
for _ in range(50):
|
||||
yield {'input_tensor': random_tensor_gen_fn()}
|
||||
|
||||
dataset_path = self.create_tempfile('tfrecord').full_path
|
||||
path_map = {'serving_default': dataset_path}
|
||||
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
|
||||
{'serving_default': data_gen()}
|
||||
)
|
||||
|
||||
config = qc.QuantizationConfig(
|
||||
static_range_ptq_preset=qc.StaticRangePtqPreset(
|
||||
representative_datasets=[
|
||||
qc.RepresentativeDatasetConfig(
|
||||
tf_record=qc.TfRecordFile(path=dataset_path)
|
||||
)
|
||||
]
|
||||
),
|
||||
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
|
||||
# Disable quantization for the quantizable unit (lifted function) whose
|
||||
# function name starts with "composite_dot_general".
|
||||
specs=qc.QuantizationSpecs(
|
||||
specs=[
|
||||
qc.QuantizationSpec(
|
||||
matcher=qc.MatcherSpec(
|
||||
function_name=qc.FunctionNameMatcherSpec(
|
||||
regex='composite_dot_general.*'
|
||||
)
|
||||
),
|
||||
method=qc.Method(no_quantization={}),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
quantization.quantize_saved_model(
|
||||
self._input_saved_model_path,
|
||||
self._output_saved_model_path,
|
||||
config,
|
||||
)
|
||||
|
||||
input_data = ops.convert_to_tensor(random_tensor_gen_fn())
|
||||
expected_outputs = model.matmul(input_data)
|
||||
|
||||
root = load.load(self._output_saved_model_path)
|
||||
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
|
||||
|
||||
new_outputs = root.signatures['serving_default'](
|
||||
input_tensor=ops.convert_to_tensor(input_data)
|
||||
)
|
||||
|
||||
# Indirectly tests that the model is not quantized by asserting that there
|
||||
# are negligible numeric difference.
|
||||
self.assertAllClose(new_outputs, expected_outputs, rtol=0.000001)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_ptq_selective_denylist(self):
|
||||
"""Tests that the op is not quantized when no quantization is enabled."""
|
||||
|
||||
rng = np.random.default_rng(1230)
|
||||
random_tensor_gen_fn = lambda shape: rng.uniform(
|
||||
low=-1.0, high=1.0, size=shape
|
||||
).astype(np.float32)
|
||||
|
||||
class TwoMatmulModel(module.Module):
|
||||
"""A model with two matmul ops."""
|
||||
|
||||
@def_function.function
|
||||
def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
|
||||
"""Performs a matrix multiplication.
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor to matmul with the filter.
|
||||
|
||||
Returns:
|
||||
A 'output' -> output tensor mapping
|
||||
"""
|
||||
out = math_ops.matmul(input_tensor, random_tensor_gen_fn((2, 3)))
|
||||
out = math_ops.matmul(out, random_tensor_gen_fn((3, 4)))
|
||||
return {'output': out}
|
||||
|
||||
model = TwoMatmulModel()
|
||||
input_shape = (1, 2)
|
||||
|
||||
save.save(
|
||||
model,
|
||||
self._input_saved_model_path,
|
||||
signatures=model.matmul.get_concrete_function(
|
||||
tensor_spec.TensorSpec(
|
||||
shape=input_shape, dtype=dtypes.float32, name='input_tensor'
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def data_gen() -> repr_dataset.RepresentativeDataset:
|
||||
for _ in range(50):
|
||||
yield {'input_tensor': random_tensor_gen_fn(input_shape)}
|
||||
|
||||
dataset_path = self.create_tempfile('tfrecord').full_path
|
||||
path_map = {'serving_default': dataset_path}
|
||||
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
|
||||
{'serving_default': data_gen()}
|
||||
)
|
||||
|
||||
config = qc.QuantizationConfig(
|
||||
static_range_ptq_preset=qc.StaticRangePtqPreset(
|
||||
representative_datasets=[
|
||||
qc.RepresentativeDatasetConfig(
|
||||
tf_record=qc.TfRecordFile(path=dataset_path)
|
||||
),
|
||||
],
|
||||
),
|
||||
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
|
||||
# Disable quantization for the quantizable unit (lifted function) whose
|
||||
# function name matches "composite_dot_general_fn_1".
|
||||
# "composite_dot_general_fn_2" will be quantized.
|
||||
specs=qc.QuantizationSpecs(
|
||||
specs=[
|
||||
qc.QuantizationSpec(
|
||||
matcher=qc.MatcherSpec(
|
||||
function_name=qc.FunctionNameMatcherSpec(
|
||||
regex='composite_dot_general_fn_1'
|
||||
)
|
||||
),
|
||||
method=qc.Method(no_quantization={}),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
quantization.quantize_saved_model(
|
||||
self._input_saved_model_path,
|
||||
self._output_saved_model_path,
|
||||
config,
|
||||
)
|
||||
|
||||
input_data = ops.convert_to_tensor(random_tensor_gen_fn(input_shape))
|
||||
expected_outputs = model.matmul(input_data)
|
||||
|
||||
root = load.load(self._output_saved_model_path)
|
||||
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
|
||||
|
||||
new_outputs = root.signatures['serving_default'](
|
||||
input_tensor=ops.convert_to_tensor(input_data)
|
||||
)
|
||||
|
||||
# Indirectly tests that the model is only partially quantized.
|
||||
self.assertAllClose(new_outputs, expected_outputs, rtol=0.011)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_ptq_quantization_method_not_applied_when_matcher_mismatch(self):
|
||||
"""Tests that quantization method is not applied to unmatched units."""
|
||||
input_shape = (1, 2)
|
||||
model = self._create_matmul_model(
|
||||
input_shape,
|
||||
weight_shape=(2, 3),
|
||||
saved_model_path=self._input_saved_model_path,
|
||||
)
|
||||
|
||||
rng = np.random.default_rng(1230)
|
||||
random_tensor_gen_fn = lambda: rng.uniform(
|
||||
low=0.0, high=1.0, size=input_shape
|
||||
).astype(np.float32)
|
||||
|
||||
def data_gen() -> repr_dataset.RepresentativeDataset:
|
||||
for _ in range(50):
|
||||
yield {'input_tensor': random_tensor_gen_fn()}
|
||||
|
||||
dataset_path = self.create_tempfile('tfrecord').full_path
|
||||
path_map = {'serving_default': dataset_path}
|
||||
repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
|
||||
{'serving_default': data_gen()}
|
||||
)
|
||||
|
||||
config = qc.QuantizationConfig(
|
||||
static_range_ptq_preset=qc.StaticRangePtqPreset(
|
||||
representative_datasets=[
|
||||
qc.RepresentativeDatasetConfig(
|
||||
tf_record=qc.TfRecordFile(path=dataset_path)
|
||||
)
|
||||
]
|
||||
),
|
||||
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
|
||||
specs=qc.QuantizationSpecs(
|
||||
specs=[
|
||||
qc.QuantizationSpec(
|
||||
# Provide a regex that wouldn't match any quantizable units.
|
||||
matcher=qc.MatcherSpec(
|
||||
function_name=qc.FunctionNameMatcherSpec(
|
||||
regex='.*invalid_function_name.*'
|
||||
),
|
||||
),
|
||||
method=qc.Method(no_quantization={}),
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
quantization.quantize_saved_model(
|
||||
self._input_saved_model_path,
|
||||
self._output_saved_model_path,
|
||||
config,
|
||||
)
|
||||
|
||||
input_data = ops.convert_to_tensor(random_tensor_gen_fn())
|
||||
expected_outputs = model.matmul(input_data)
|
||||
|
||||
root = load.load(self._output_saved_model_path)
|
||||
self.assertCountEqual(root.signatures.keys(), {'serving_default'})
|
||||
|
||||
new_outputs = root.signatures['serving_default'](
|
||||
input_tensor=ops.convert_to_tensor(input_data)
|
||||
)
|
||||
|
||||
# Tests that the quantized graph outputs similar values. They also shouldn't
|
||||
# be exactly the same. Indirectly proves that the `FunctionNameMatcherSpec`
|
||||
# with regex '.*invalid_function_name.*' did not match the quantizable unit.
|
||||
self.assertAllClose(new_outputs, expected_outputs, rtol=0.04)
|
||||
self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
||||
@@ -63,10 +63,68 @@ message PipelineConfig {
|
||||
optional bool unpack_quantized_types = 1;
|
||||
}
|
||||
|
||||
// A quantization method representing "do not quantize". Mostly used for
|
||||
// denylisting quantizable units from quantization.
|
||||
message NoQuantization {}
|
||||
|
||||
// Represents a matching method that matches quantizable units by lifted
|
||||
// functions' names.
|
||||
message FunctionNameMatcherSpec {
|
||||
// Regular expression to match lifted functions' names. Underlying regex
|
||||
// engine uses re2, which accepts a subset of PCRE. See
|
||||
// https://github.com/google/re2/wiki/Syntax for details.
|
||||
string regex = 1;
|
||||
}
|
||||
|
||||
// Matcher specification for identifying quantizable units.
|
||||
message MatcherSpec {
|
||||
// Matches lifted functions by their names.
|
||||
FunctionNameMatcherSpec function_name = 1;
|
||||
}
|
||||
|
||||
// Specifies how to quantize matched quantizable units.
|
||||
message Method {
|
||||
NoQuantization no_quantization = 1;
|
||||
}
|
||||
|
||||
// A QuantizationSpec is essentially a (matcher spec, quantization method) pair,
|
||||
// where the matcher spec is used to identify quantizable units and the
|
||||
// quantization method specifies what type of quantization to apply on the
|
||||
// matched quantizable units.
|
||||
// Next ID: 3
|
||||
message QuantizationSpec {
|
||||
// Configures matchers for identifying quantizable units. Matched quantizable
|
||||
// units will be quantized according to `method`.
|
||||
MatcherSpec matcher = 1;
|
||||
|
||||
// Specifies how to quantize the matched quantizable units.
|
||||
Method method = 2;
|
||||
}
|
||||
|
||||
// Quantization specifications. A simple wrapper around a sequence of
|
||||
// `QuantizationSpec`s so that specs can be easily passed around or represented
|
||||
// as a textproto.
|
||||
// Next ID: 2
|
||||
message QuantizationSpecs {
|
||||
// List of `QuantizationSpec`s. Later spec in the sequence takes precedence.
|
||||
//
|
||||
// NOTE: Tie-breaking mechanism is not yet supported. Providing multiple
|
||||
// `QuantizationSpec` with conflicting quantizable units may result in
|
||||
// undefined behavior.
|
||||
// TODO: b/307620778 - Support tie-breaking for conflicting specs.
|
||||
repeated QuantizationSpec specs = 1;
|
||||
}
|
||||
|
||||
// Quantization configuration for StableHLO Quantizer. This is the primary
|
||||
// message containing all configurable options.
|
||||
// Next ID: 4
|
||||
// Next ID: 5
|
||||
message QuantizationConfig {
|
||||
// Config presets provide predefined popular or common quantization specs.
|
||||
// Lightweight users may choose one of the presets for quick experiments. Each
|
||||
// preset is completely represented by `QuantizationSpecs`. When extra entries
|
||||
// in `QuantizationSpecs` are provided along with a preset, then the preset
|
||||
// will be overridden for the quantizable units matched by those additional
|
||||
// `QuantizationSpec`s.
|
||||
oneof preset {
|
||||
// Performs best-effort static-range post-training quantization (PTQ).
|
||||
StaticRangePtqPreset static_range_ptq_preset = 1;
|
||||
@@ -77,4 +135,6 @@ message QuantizationConfig {
|
||||
|
||||
// Configures the graph transformation pipeline for quantization.
|
||||
PipelineConfig pipeline_config = 3;
|
||||
|
||||
QuantizationSpecs specs = 4;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs \
|
||||
// RUN: -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: @main
|
||||
func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
|
||||
%0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
|
||||
%1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
|
||||
return %1 : tensor<1x1x64xf32>
|
||||
}
|
||||
// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp
|
||||
// is missing attributes required for quantization.
|
||||
|
||||
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
|
||||
// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
|
||||
// CHECK-SAME: {_entry_function = @composite_dot_general_fn_1, {{.*}}}
|
||||
// CHECK-NOT: _original_entry_function
|
||||
// CHECK-NOT: _tfl_quant_trait
|
||||
// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: private @composite_dot_general_fn_1
|
||||
// CHECK-NOT: tf_quant.composite_function
|
||||
// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1
|
||||
// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32>
|
||||
// CHECK: }
|
||||
Reference in New Issue
Block a user