Create option to allow tensorflow::Tensor objects to be imported as DenseResourceElementsAttr during TF V1/V2 saved models import to MLIR Module.

PiperOrigin-RevId: 713491382
This commit is contained in:
Vamsi Manchala
2025-01-08 18:46:09 -08:00
committed by TensorFlower Gardener
parent ccaef81b5c
commit 35fbbd0aa7
10 changed files with 91 additions and 172 deletions

View File

@@ -1,3 +1,4 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test=import-variables-as-dense-resources=true -split-input-file %s | FileCheck --check-prefix=CheckWithDense %s --dump-input=fail
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
@@ -15,11 +16,23 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: value = dense<0.000000e+00>
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: value = dense<0.000000e+00>
// CHECK: func @serving_default(
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/kernel"
// CheckWithDense: value = dense_resource<dense_elements_f32>
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/bias"
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
// CheckWithDense: func @serving_default(
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}
// -----
@@ -49,8 +62,10 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: value = dense<0.000000e+00>
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: value = dense<0.000000e+00>
// CHECK: func @f(
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
@@ -58,6 +73,20 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
// CHECK: func @f2(
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/kernel"
// CheckWithDense: value = dense_resource<dense_elements_f32>
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/bias"
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
// CheckWithDense: func @f(
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CheckWithDense: func @f2(
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}
// -----
@@ -75,9 +104,21 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: value = dense<0.000000e+00>
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: value = dense<0.000000e+00>
// CHECK: func @serving_default(
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/kernel"
// CheckWithDense: value = dense_resource<dense_elements_f32>
// CheckWithDense: "tf_saved_model.global_tensor"()
// CheckWithDense: sym_name = "dense/bias"
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
// CheckWithDense: func @serving_default(
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}

View File

@@ -64,7 +64,8 @@ constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
LogicalResult LiftVariablesFromSession(
ModuleOp module, Session* session,
const SmallSet<StringRef, 4>& resource_names) {
const SmallSet<StringRef, 4>& resource_names,
bool import_variables_as_dense_resources) {
OpBuilder builder(module.getBodyRegion());
if (!session) return module.emitOpError() << "no session provided";
@@ -127,11 +128,13 @@ LogicalResult LiftVariablesFromSession(
const Tensor& tensor = std::get<1>(iter);
// Create tensor attribute for this variable.
absl::StatusOr<ElementsAttr> tensor_attr_or =
ConvertTensor(tensor, &builder);
absl::StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(
tensor, &builder,
/*convert_to_dense_resource=*/import_variables_as_dense_resources);
if (!tensor_attr_or.ok()) {
return module.emitOpError()
<< "failed to convert tensor (name: " << name.str() << ")";
<< "failed to convert tensor (name: " << name.str() << ")- "
<< tensor_attr_or.status().ToString();
}
ElementsAttr tensor_attr = tensor_attr_or.value();
@@ -146,7 +149,8 @@ LogicalResult LiftVariablesFromSession(
} // namespace
LogicalResult LiftVariables(ModuleOp module, Session* session) {
LogicalResult LiftVariables(ModuleOp module, Session* session,
bool import_variables_as_dense_resources) {
MLIRContext* context = module.getContext();
mlir::Builder builder(context);
StringAttr resource_name_id = builder.getStringAttr(kResourceNameArgAttr);
@@ -175,7 +179,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) {
if (resource_names.empty()) return success();
if (failed(LiftVariablesFromSession(module, session, resource_names)))
if (failed(LiftVariablesFromSession(module, session, resource_names,
/*import_variables_as_dense_resources=*/
import_variables_as_dense_resources)))
return failure();
// Now that we have all global tensors created, we set the corresponding

View File

@@ -26,7 +26,8 @@ namespace tf_saved_model {
// Creates GlobalTensorOp for each variable from function arguments and converts
// them to the corresponding saved model arguments.
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session,
bool import_variables_as_dense_resources = false);
} // namespace tf_saved_model
} // namespace mlir

View File

@@ -16,13 +16,13 @@ limitations under the License.
#include <memory>
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/fake_session.h"
namespace mlir {
namespace tf_saved_model {
namespace tf_test {
namespace {
using ::tensorflow::Session;
@@ -39,7 +39,9 @@ class LiftVariablesTestPass
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(tf_saved_model::LiftVariables(module, session_)))
if (failed(tf_saved_model::LiftVariables(
module, session_, /*import_variables_as_dense_resources=*/
import_variables_as_dense_resources_)))
signalPassFailure();
}
@@ -64,18 +66,17 @@ class LiftVariablesInvalidSessionTestPass
};
} // namespace
} // namespace tf_saved_model
} // namespace tf_test
namespace tf_test {
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesTestPass() {
return std::make_unique<tf_saved_model::LiftVariablesTestPass>();
return std::make_unique<tf_test::LiftVariablesTestPass>();
}
std::unique_ptr<OperationPass<ModuleOp>>
CreateLiftVariablesInvalidSessionTestPass() {
return std::make_unique<
tf_saved_model::LiftVariablesInvalidSessionTestPass>();
return std::make_unique<tf_test::LiftVariablesInvalidSessionTestPass>();
}
} // namespace tf_test

View File

@@ -1,150 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
using ::tensorflow::DeviceMgr;
using ::tensorflow::Session;
using ::tensorflow::Status;
using ::tensorflow::Tensor;
// FakeSession is for testing only.
class FakeSession : public tensorflow::Session {
public:
FakeSession() {}
~FakeSession() override = default;
Status Create(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Extend(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Close() override {
return tensorflow::errors::Unimplemented("not available");
}
Status ListDevices(
std::vector<tensorflow::DeviceAttributes>* response) override {
return tensorflow::errors::Unimplemented("not available");
}
Status LocalDeviceManager(
const tensorflow::DeviceMgr** deviceMgrPtr) override {
// This method returns a null device manager without making an error.
// Users of this method will be notified since it will have a fake data.
*deviceMgrPtr = nullptr;
return OkStatus();
}
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs) override {
tensorflow::RunMetadata run_metadata;
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
outputs, &run_metadata);
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata) override {
return Run(run_options, inputs, output_names, target_nodes, outputs,
run_metadata, tensorflow::thread::ThreadPoolOptions());
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
override {
for (const std::string& output_name : output_names) {
Tensor output;
if (output_name == "dense/bias") {
Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50}));
t.flat<float>().setZero();
outputs->push_back(t);
} else if (output_name == "dense/kernel") {
Tensor t =
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50}));
t.flat<float>().setZero();
outputs->push_back(t);
} else {
// Create a scalar float tensor.
Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({}));
t.flat<float>()(0) = 1.0f;
outputs->push_back(t);
}
}
return OkStatus();
}
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesTestPass
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
public:
LiftVariablesTestPass() { session_ = new FakeSession(); }
~LiftVariablesTestPass() override { delete session_; }
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(LiftVariables(module, session_))) signalPassFailure();
}
private:
Session* session_;
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesInvalidSessionTestPass
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
OperationPass<ModuleOp>> {
public:
void runOnOperation() override {
ModuleOp module = getOperation();
// Pass an invalid session argument, which is a nullptr.
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
}
};
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_

View File

@@ -81,6 +81,11 @@ def LiftVariablesInvalidSessionTestPass : Pass<"tf-saved-model-lift-variables-in
def LiftVariablesTestPass : Pass<"tf-saved-model-lift-variables-test", "ModuleOp"> {
let summary = "Lift variables and save them as global tensors";
let constructor = "mlir::tf_test::CreateLiftVariablesTestPass()";
let options = [
Option<"import_variables_as_dense_resources_", "import-variables-as-dense-resources", "bool", /*default=*/"false",
"Import variables as dense resources">,
];
}
def InitializeVariablesInSessionInitializerPass : Pass<"tf-saved-model-initialize-variables-in-session-init", "ModuleOp"> {

View File

@@ -923,7 +923,11 @@ absl::Status CreateSavedModelIR(
saved_model->variable_reader()->Lookup(checkpoint_key, &value),
"Could not read checkpoint key from variables bundle: ",
checkpoint_key);
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
TF_ASSIGN_OR_RETURN(
auto value_attr,
ConvertTensor(value, &builder,
/*convert_to_dense_resource=*/
import_options.import_variables_as_dense_resources));
// A variable can have a partially known type, such as
// tensor<?x27x?xf32>, even if the initializer is a specific static
// shape.
@@ -1610,7 +1614,8 @@ class SavedModelSignatureDefImporter {
builder.getUnitAttr());
TF_RETURN_IF_ERROR(
LiftVariables(bundle, *module, options.lift_variables,
options.include_variables_in_initializers));
options.include_variables_in_initializers,
options.import_variables_as_dense_resources));
(*module)->removeAttr("tf_saved_model.under_construction");
return module;
@@ -1626,13 +1631,15 @@ class SavedModelSignatureDefImporter {
static absl::Status LiftVariables(const SavedModelBundle& bundle,
mlir::ModuleOp module,
bool lift_varhandle_ops_to_args,
bool include_variables_in_initializers);
bool include_variables_in_initializers,
bool import_variables_as_dense_resources);
};
absl::Status SavedModelSignatureDefImporter::LiftVariables(
const SavedModelBundle& bundle, mlir::ModuleOp module,
const bool lift_varhandle_ops_to_args,
const bool include_variables_in_initializers) {
const bool include_variables_in_initializers,
const bool import_variables_as_dense_resources) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
mlir::PassManager pm(module.getContext());
@@ -1662,8 +1669,8 @@ absl::Status SavedModelSignatureDefImporter::LiftVariables(
if (mlir::failed(pm.run(module)))
return diag_handler.Combine(
errors::Internal("Failed to promote var handles to args."));
if (failed(
mlir::tf_saved_model::LiftVariables(module, bundle.GetSession())))
if (failed(mlir::tf_saved_model::LiftVariables(
module, bundle.GetSession(), import_variables_as_dense_resources)))
return diag_handler.Combine(
errors::Internal("Failed to lift variables."));
} else {

View File

@@ -49,6 +49,10 @@ struct MLIRImportOptions {
// Load the model without restoring associated variables from disk. Enables
// loading raw programs without checkpoints.
bool allow_uninitialized_variables = false;
// If true, variables are imported as DenseResourceElementsAttr; else,
// variables are imported as DenseElementsAttr.
bool import_variables_as_dense_resources = false;
};
} // namespace tensorflow

View File

@@ -225,7 +225,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context,
bool unconditionally_use_set_output_shapes) {
bool unconditionally_use_set_output_shapes,
bool import_variables_as_dense_resources) {
tensorflow::SavedModelV2Bundle bundle;
auto load_status = tensorflow::SavedModelV2Bundle::Load(
std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
@@ -239,6 +240,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,
options.add_default_attributes = true;
options.unconditionally_use_set_output_shapes =
unconditionally_use_set_output_shapes;
options.import_variables_as_dense_resources =
import_variables_as_dense_resources;
auto module_or =
ConvertSavedModelToMlir(&bundle, context, exported_names, options);

View File

@@ -108,7 +108,8 @@ SavedModelObjectGraphToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
bool unconditionally_use_set_output_shapes = false);
bool unconditionally_use_set_output_shapes = false,
bool import_variables_as_dense_resources = false);
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the