mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Create option to allow tensorflow::Tensor objects to be imported as DenseResourceElementsAttr during TF V1/V2 saved models import to MLIR Module.
PiperOrigin-RevId: 713491382
This commit is contained in:
committed by
TensorFlower Gardener
parent
ccaef81b5c
commit
35fbbd0aa7
@@ -1,3 +1,4 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test=import-variables-as-dense-resources=true -split-input-file %s | FileCheck --check-prefix=CheckWithDense %s --dump-input=fail
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
|
||||
@@ -15,11 +16,23 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: func @serving_default(
|
||||
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/kernel"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32>
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/bias"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
|
||||
// CheckWithDense: func @serving_default(
|
||||
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -49,8 +62,10 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: func @f(
|
||||
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
@@ -58,6 +73,20 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
// CHECK: func @f2(
|
||||
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/kernel"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32>
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/bias"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
|
||||
// CheckWithDense: func @f(
|
||||
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CheckWithDense: func @f2(
|
||||
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -75,9 +104,21 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: value = dense<0.000000e+00>
|
||||
// CHECK: func @serving_default(
|
||||
// CHECK: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/kernel"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32>
|
||||
// CheckWithDense: "tf_saved_model.global_tensor"()
|
||||
// CheckWithDense: sym_name = "dense/bias"
|
||||
// CheckWithDense: value = dense_resource<dense_elements_f32_1>
|
||||
// CheckWithDense: func @serving_default(
|
||||
// CheckWithDense: %arg0: tensor<!tf_type.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CheckWithDense: %arg1: tensor<!tf_type.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
@@ -64,7 +64,8 @@ constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
|
||||
|
||||
LogicalResult LiftVariablesFromSession(
|
||||
ModuleOp module, Session* session,
|
||||
const SmallSet<StringRef, 4>& resource_names) {
|
||||
const SmallSet<StringRef, 4>& resource_names,
|
||||
bool import_variables_as_dense_resources) {
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
|
||||
if (!session) return module.emitOpError() << "no session provided";
|
||||
@@ -127,11 +128,13 @@ LogicalResult LiftVariablesFromSession(
|
||||
const Tensor& tensor = std::get<1>(iter);
|
||||
|
||||
// Create tensor attribute for this variable.
|
||||
absl::StatusOr<ElementsAttr> tensor_attr_or =
|
||||
ConvertTensor(tensor, &builder);
|
||||
absl::StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(
|
||||
tensor, &builder,
|
||||
/*convert_to_dense_resource=*/import_variables_as_dense_resources);
|
||||
if (!tensor_attr_or.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to convert tensor (name: " << name.str() << ")";
|
||||
<< "failed to convert tensor (name: " << name.str() << ")- "
|
||||
<< tensor_attr_or.status().ToString();
|
||||
}
|
||||
ElementsAttr tensor_attr = tensor_attr_or.value();
|
||||
|
||||
@@ -146,7 +149,8 @@ LogicalResult LiftVariablesFromSession(
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
LogicalResult LiftVariables(ModuleOp module, Session* session,
|
||||
bool import_variables_as_dense_resources) {
|
||||
MLIRContext* context = module.getContext();
|
||||
mlir::Builder builder(context);
|
||||
StringAttr resource_name_id = builder.getStringAttr(kResourceNameArgAttr);
|
||||
@@ -175,7 +179,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
|
||||
if (resource_names.empty()) return success();
|
||||
|
||||
if (failed(LiftVariablesFromSession(module, session, resource_names)))
|
||||
if (failed(LiftVariablesFromSession(module, session, resource_names,
|
||||
/*import_variables_as_dense_resources=*/
|
||||
import_variables_as_dense_resources)))
|
||||
return failure();
|
||||
|
||||
// Now that we have all global tensors created, we set the corresponding
|
||||
|
||||
@@ -26,7 +26,8 @@ namespace tf_saved_model {
|
||||
|
||||
// Creates GlobalTensorOp for each variable from function arguments and converts
|
||||
// them to the corresponding saved model arguments.
|
||||
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
|
||||
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session,
|
||||
bool import_variables_as_dense_resources = false);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
@@ -16,13 +16,13 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/fake_session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
namespace tf_test {
|
||||
namespace {
|
||||
using ::tensorflow::Session;
|
||||
|
||||
@@ -39,7 +39,9 @@ class LiftVariablesTestPass
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(tf_saved_model::LiftVariables(module, session_)))
|
||||
if (failed(tf_saved_model::LiftVariables(
|
||||
module, session_, /*import_variables_as_dense_resources=*/
|
||||
import_variables_as_dense_resources_)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
@@ -64,18 +66,17 @@ class LiftVariablesInvalidSessionTestPass
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace tf_saved_model
|
||||
} // namespace tf_test
|
||||
|
||||
namespace tf_test {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesTestPass() {
|
||||
return std::make_unique<tf_saved_model::LiftVariablesTestPass>();
|
||||
return std::make_unique<tf_test::LiftVariablesTestPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateLiftVariablesInvalidSessionTestPass() {
|
||||
return std::make_unique<
|
||||
tf_saved_model::LiftVariablesInvalidSessionTestPass>();
|
||||
return std::make_unique<tf_test::LiftVariablesInvalidSessionTestPass>();
|
||||
}
|
||||
|
||||
} // namespace tf_test
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::Tensor;
|
||||
|
||||
// FakeSession is for testing only.
|
||||
class FakeSession : public tensorflow::Session {
|
||||
public:
|
||||
FakeSession() {}
|
||||
~FakeSession() override = default;
|
||||
|
||||
Status Create(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
Status Extend(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status Close() override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* response) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status LocalDeviceManager(
|
||||
const tensorflow::DeviceMgr** deviceMgrPtr) override {
|
||||
// This method returns a null device manager without making an error.
|
||||
// Users of this method will be notified since it will have a fake data.
|
||||
*deviceMgrPtr = nullptr;
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) override {
|
||||
tensorflow::RunMetadata run_metadata;
|
||||
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
|
||||
outputs, &run_metadata);
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata) override {
|
||||
return Run(run_options, inputs, output_names, target_nodes, outputs,
|
||||
run_metadata, tensorflow::thread::ThreadPoolOptions());
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata,
|
||||
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
|
||||
override {
|
||||
for (const std::string& output_name : output_names) {
|
||||
Tensor output;
|
||||
if (output_name == "dense/bias") {
|
||||
Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50}));
|
||||
t.flat<float>().setZero();
|
||||
outputs->push_back(t);
|
||||
} else if (output_name == "dense/kernel") {
|
||||
Tensor t =
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50}));
|
||||
t.flat<float>().setZero();
|
||||
outputs->push_back(t);
|
||||
} else {
|
||||
// Create a scalar float tensor.
|
||||
Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({}));
|
||||
t.flat<float>()(0) = 1.0f;
|
||||
outputs->push_back(t);
|
||||
}
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesTestPass
|
||||
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
LiftVariablesTestPass() { session_ = new FakeSession(); }
|
||||
|
||||
~LiftVariablesTestPass() override { delete session_; }
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Session* session_;
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesInvalidSessionTestPass
|
||||
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
// Pass an invalid session argument, which is a nullptr.
|
||||
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
@@ -81,6 +81,11 @@ def LiftVariablesInvalidSessionTestPass : Pass<"tf-saved-model-lift-variables-in
|
||||
def LiftVariablesTestPass : Pass<"tf-saved-model-lift-variables-test", "ModuleOp"> {
|
||||
let summary = "Lift variables and save them as global tensors";
|
||||
let constructor = "mlir::tf_test::CreateLiftVariablesTestPass()";
|
||||
|
||||
let options = [
|
||||
Option<"import_variables_as_dense_resources_", "import-variables-as-dense-resources", "bool", /*default=*/"false",
|
||||
"Import variables as dense resources">,
|
||||
];
|
||||
}
|
||||
|
||||
def InitializeVariablesInSessionInitializerPass : Pass<"tf-saved-model-initialize-variables-in-session-init", "ModuleOp"> {
|
||||
|
||||
@@ -923,7 +923,11 @@ absl::Status CreateSavedModelIR(
|
||||
saved_model->variable_reader()->Lookup(checkpoint_key, &value),
|
||||
"Could not read checkpoint key from variables bundle: ",
|
||||
checkpoint_key);
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto value_attr,
|
||||
ConvertTensor(value, &builder,
|
||||
/*convert_to_dense_resource=*/
|
||||
import_options.import_variables_as_dense_resources));
|
||||
// A variable can have a partially known type, such as
|
||||
// tensor<?x27x?xf32>, even if the initializer is a specific static
|
||||
// shape.
|
||||
@@ -1610,7 +1614,8 @@ class SavedModelSignatureDefImporter {
|
||||
builder.getUnitAttr());
|
||||
TF_RETURN_IF_ERROR(
|
||||
LiftVariables(bundle, *module, options.lift_variables,
|
||||
options.include_variables_in_initializers));
|
||||
options.include_variables_in_initializers,
|
||||
options.import_variables_as_dense_resources));
|
||||
(*module)->removeAttr("tf_saved_model.under_construction");
|
||||
|
||||
return module;
|
||||
@@ -1626,13 +1631,15 @@ class SavedModelSignatureDefImporter {
|
||||
static absl::Status LiftVariables(const SavedModelBundle& bundle,
|
||||
mlir::ModuleOp module,
|
||||
bool lift_varhandle_ops_to_args,
|
||||
bool include_variables_in_initializers);
|
||||
bool include_variables_in_initializers,
|
||||
bool import_variables_as_dense_resources);
|
||||
};
|
||||
|
||||
absl::Status SavedModelSignatureDefImporter::LiftVariables(
|
||||
const SavedModelBundle& bundle, mlir::ModuleOp module,
|
||||
const bool lift_varhandle_ops_to_args,
|
||||
const bool include_variables_in_initializers) {
|
||||
const bool include_variables_in_initializers,
|
||||
const bool import_variables_as_dense_resources) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
|
||||
mlir::PassManager pm(module.getContext());
|
||||
@@ -1662,8 +1669,8 @@ absl::Status SavedModelSignatureDefImporter::LiftVariables(
|
||||
if (mlir::failed(pm.run(module)))
|
||||
return diag_handler.Combine(
|
||||
errors::Internal("Failed to promote var handles to args."));
|
||||
if (failed(
|
||||
mlir::tf_saved_model::LiftVariables(module, bundle.GetSession())))
|
||||
if (failed(mlir::tf_saved_model::LiftVariables(
|
||||
module, bundle.GetSession(), import_variables_as_dense_resources)))
|
||||
return diag_handler.Combine(
|
||||
errors::Internal("Failed to lift variables."));
|
||||
} else {
|
||||
|
||||
@@ -49,6 +49,10 @@ struct MLIRImportOptions {
|
||||
// Load the model without restoring associated variables from disk. Enables
|
||||
// loading raw programs without checkpoints.
|
||||
bool allow_uninitialized_variables = false;
|
||||
|
||||
// If true, variables are imported as DenseResourceElementsAttr; else,
|
||||
// variables are imported as DenseElementsAttr.
|
||||
bool import_variables_as_dense_resources = false;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@@ -225,7 +225,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context,
|
||||
bool unconditionally_use_set_output_shapes) {
|
||||
bool unconditionally_use_set_output_shapes,
|
||||
bool import_variables_as_dense_resources) {
|
||||
tensorflow::SavedModelV2Bundle bundle;
|
||||
auto load_status = tensorflow::SavedModelV2Bundle::Load(
|
||||
std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
|
||||
@@ -239,6 +240,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,
|
||||
options.add_default_attributes = true;
|
||||
options.unconditionally_use_set_output_shapes =
|
||||
unconditionally_use_set_output_shapes;
|
||||
options.import_variables_as_dense_resources =
|
||||
import_variables_as_dense_resources;
|
||||
|
||||
auto module_or =
|
||||
ConvertSavedModelToMlir(&bundle, context, exported_names, options);
|
||||
|
||||
@@ -108,7 +108,8 @@ SavedModelObjectGraphToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
bool unconditionally_use_set_output_shapes = false);
|
||||
bool unconditionally_use_set_output_shapes = false,
|
||||
bool import_variables_as_dense_resources = false);
|
||||
|
||||
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
|
||||
Reference in New Issue
Block a user