Fix application of JIT compiler plugins
1) Restore some key logic lost when landing cl/707770943, in compiled_model.cpp:122.
2) Don't abort CompiledModel creation if the runtime fails to apply compiler plugins; issue warnings instead.
3) Log the list of compiler plugins that were successfully applied.

PiperOrigin-RevId: 713420589
committed by TensorFlower Gardener
parent 805779b75a
commit df7525018e
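To illustrate item 1 of the commit message, the restored logic amounts to preferring a plugin-produced flatbuffer over the model's original buffer when backing the compiled model. The sketch below is a simplified reading of the compiled_model hunk further down, not the verbatim committed code; the buffer declarations and the final size assignment are filled in by assumption, since the diff is truncated at that point.

// Sketch only: choose which flatbuffer backs the compiled model.
const char* model_buffer = nullptr;   // assumed declaration
size_t model_buffer_size = 0;
if (new_flatbuffer) {
  // A compiler plugin rewrote the model; use the freshly serialized flatbuffer.
  model_buffer = reinterpret_cast<const char*>(new_flatbuffer->Data());
  model_buffer_size = new_flatbuffer->Size();
} else if (auto init_model_buffer = detail::GetTflInitFlatbuffer(*model);
           init_model_buffer.Size() != 0) {
  // Fall back to the original flatbuffer the LiteRtModel was created from.
  model_buffer = init_model_buffer.StrData();
  model_buffer_size = init_model_buffer.Size();  // assumed; not shown in the diff
}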
@@ -45,7 +45,7 @@ TEST(CompiledModelTest, Basic) {

   LiteRtCompiledModel compiled_model;
   ASSERT_EQ(
-      LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorNone, &compiled_model),
+      LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorCpu, &compiled_model),
       kLiteRtStatusOk);

   LiteRtSubgraph subgraph;
@@ -69,7 +69,7 @@ class CompiledModel
   // returned object.
   static Expected<CompiledModel> Create(
       litert::Model& model,
-      LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorNone) {
+      LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorCpu) {
     LiteRtCompiledModel compiled_model;
     if (auto status = LiteRtCreateCompiledModel(
             model.Get(), compilation_options, &compiled_model);
@@ -41,6 +41,8 @@ cc_library(
         "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin",
         "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
     ],
@@ -24,6 +24,9 @@
 #include <vector>

+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/lite/experimental/litert/c/litert_any.h"
@@ -291,6 +294,16 @@ CompilerPlugin::~CompilerPlugin() {
   }
 }

+std::string CompilerPlugin::DebugString() const {
+  std::string version_str = "?";
+  if (auto version = ApiVersion(); version) {
+    version_str = absl::StrFormat("%d.%d.%d", version->major, version->minor,
+                                  version->patch);
+  }
+  return absl::StrFormat("%s compiler plugin (ver %s)", SocManufacturer(),
+                         version_str);
+}
+
 Expected<LiteRtApiVersion> CompilerPlugin::ApiVersion() const {
   LiteRtApiVersion api_version;
   LITERT_EXPECT_OK(plugin_api_.get_compiler_plugin_version(&api_version));
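A hypothetical use of the new DebugString() helper, not part of this commit; it assumes a CompilerPlugin& named plugin in scope and reuses the LITERT_LOG macro that appears elsewhere in the diff.

// Hypothetical: tag a log line with the plugin's identity, producing e.g.
// "SomeVendor compiler plugin (ver 1.2.0)", or "(ver ?)" if the plugin's
// version query fails.
LITERT_LOG(LITERT_INFO, "Found %s", plugin.DebugString().c_str());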
@@ -426,7 +439,7 @@ Expected<void> ApplyPlugin(CompilerPlugin& compiler_plugin, LiteRtModelT& model,
   return {};
 }

-Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
+Expected<ApplyPluginsResult> ApplyPlugins(
     LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators) {
   auto environment = litert::internal::Environment::Instance();
   if (!environment) {
@@ -448,13 +461,25 @@ Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
   if (!compiler_plugins) {
     return compiler_plugins.Error();
   }
   if (compiler_plugins->empty()) {
     return litert::Error(kLiteRtStatusErrorRuntimeFailure,
                          "No compiler plugin found");
   }

-  std::optional<OwningBufferRef<uint8_t>> new_flatbuffer;
+  OwningBufferRef<uint8_t> new_flatbuffer;
+  std::vector<std::string> success_messages;
+  std::vector<std::string> error_messages;
+
+  ApplyPluginsResult result;
+  result.num_applied_plugins = 0;
   for (auto& compiler_plugin : *compiler_plugins) {
+    auto plugin_name = compiler_plugin.DebugString();
+
     auto plugin_supported_hardware = compiler_plugin.SupportedHardware();
     if (!plugin_supported_hardware) {
-      return plugin_supported_hardware.Error();
+      error_messages.push_back(absl::StrCat(
+          plugin_name, " ", plugin_supported_hardware.Error().Message()));
+      continue;
     }

     if (*plugin_supported_hardware & selected_hw_accelerators) {
@@ -462,28 +487,39 @@ Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
       // shouldn't be needing to serialize a model to then read it again from
       // the serialized buffer when applying a compiler plugin.
       if (auto status = ApplyPlugin(compiler_plugin, *model); !status) {
-        return status.Error();
+        error_messages.push_back(
+            absl::StrCat(plugin_name, " ", status.Error().Message()));
+        continue;
       }

       auto serialized_model =
           litert::internal::SerializeModel(std::move(*model));
       if (!serialized_model) {
-        return serialized_model.Error();
+        error_messages.push_back(
+            absl::StrCat(plugin_name, " ", serialized_model.Error().Message()));
+        continue;
       }

       auto new_model = litert::Model::CreateFromBuffer(*serialized_model);
       if (!new_model) {
-        return new_model.Error();
+        error_messages.push_back(
+            absl::StrCat(plugin_name, " ", new_model.Error().Message()));
+        continue;
       }

       new_flatbuffer = std::move(*serialized_model);
       *model = std::move(*new_model->Get());
+
+      success_messages.push_back(absl::StrCat(plugin_name));
+      result.num_applied_plugins++;
     }
   }

-  if (!new_flatbuffer.has_value()) {
-    return litert::Error(kLiteRtStatusErrorRuntimeFailure,
-                         "No applicable compiler plugin found");
-  }
+  result.new_flatbuffer = std::move(new_flatbuffer);
+  result.success_message = absl::StrJoin(success_messages, ", ");
+  result.error_message = absl::StrJoin(error_messages, ", ");

-  return *new_flatbuffer;
+  return result;
 }

 }  // namespace litert::internal
@@ -16,6 +16,7 @@
 #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_

+#include <string>
 #include <tuple>
 #include <vector>

 #include "absl/strings/string_view.h"
@@ -72,6 +73,8 @@ class CompiledResult {
 // Wraps vendor compiler plugin.
 class CompilerPlugin {
  public:
+  std::string DebugString() const;
+
   // Get the compiler plugin's API version.
   Expected<LiteRtApiVersion> ApiVersion() const;

@@ -150,9 +153,19 @@ Expected<void> ApplyPlugin(
     Serialization serialization = Serialization::kAppend);

 // Apply all available plugins providing the selected HW accelerators to the
-// given model, modify the model accordingly, and return a new flatbuffer
-// backing the modified model.
-Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
+// given model, modify the model accordingly, and return (1) the number of
+// compiler plugins successfully applied, (2) a new flatbuffer backing the
+// modified model, (3) a string listing the compiler plugins that were
+// successfully applied, and (4) a string listing the compiler plugins that
+// failed to apply with an associated error message.
+struct ApplyPluginsResult {
+  size_t num_applied_plugins;
+  OwningBufferRef<uint8_t> new_flatbuffer;
+  std::string success_message;
+  std::string error_message;
+};
+
+Expected<ApplyPluginsResult> ApplyPlugins(
     LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators);

 }  // namespace litert::internal
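A minimal caller-side sketch of the new contract, mirroring the CompiledModel change in the next hunk rather than reproducing it verbatim; it assumes a valid LiteRtModel named model, the CPU accelerator flag used in the tests, and the LITERT_LOG macro from the runtime.

// Sketch: per-plugin failures are reported through error_message, not treated
// as fatal; only a total failure of ApplyPlugins itself yields an error.
auto result = litert::internal::ApplyPlugins(model, kLiteRtHwAccelatorCpu);
if (!result) {
  LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s",
             result.Error().Message().data());
} else {
  if (result->num_applied_plugins > 0) {
    LITERT_LOG(LITERT_INFO, "Applied %d compiler plugins: %s",
               result->num_applied_plugins, result->success_message.c_str());
  }
  if (!result->error_message.empty()) {
    LITERT_LOG(LITERT_WARNING, "Some compiler plugins failed: %s",
               result->error_message.c_str());
  }
}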
@@ -94,14 +94,23 @@ Expected<LiteRtCompiledModelT::Ptr> LiteRtCompiledModelT::Create(
   std::optional<OwningBufferRef<uint8_t>> new_flatbuffer;
   // TODO: b/379317134 - Support other delegates with compilation options.
   if (compilation_options != kLiteRtHwAccelatorNone) {
-    LITERT_LOG(LITERT_INFO, "Applying compiler plugins");
-    if (auto flatbuffer =
+    LITERT_LOG(LITERT_INFO, "Applying compiler plugins...");
+    if (auto result =
             litert::internal::ApplyPlugins(model, compilation_options);
-        !flatbuffer) {
-      LITERT_LOG(LITERT_ERROR, "Failed to applying compiler plugins");
-      return flatbuffer.Error();
+        !result) {
+      LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s",
+                 result.Error().Message().data());
     } else {
-      new_flatbuffer = *flatbuffer;
+      if (result->num_applied_plugins > 0) {
+        LITERT_LOG(LITERT_INFO, "Successfully applied %d compiler plugins: %s",
+                   result->num_applied_plugins,
+                   result->success_message.c_str());
+        new_flatbuffer = std::move(result->new_flatbuffer);
+      }
+      if (!result->error_message.empty()) {
+        LITERT_LOG(LITERT_WARNING, "Some compiler plugins failed to apply: %s",
+                   result->error_message.c_str());
+      }
     }
   }
@@ -109,8 +118,11 @@ Expected<LiteRtCompiledModelT::Ptr> LiteRtCompiledModelT::Create(
   size_t model_buffer_size = 0;
   // The following code gets the original FB pointer from LiteRtModel.
   // TODO b/383120429 - Use a better way of getting the FB pointer.
-  auto init_model_buffer = detail::GetTflInitFlatbuffer(*model);
-  if (init_model_buffer.Size() != 0) {
+  if (new_flatbuffer) {
+    model_buffer = reinterpret_cast<const char*>(new_flatbuffer->Data());
+    model_buffer_size = new_flatbuffer->Size();
+  } else if (auto init_model_buffer = detail::GetTflInitFlatbuffer(*model);
+             init_model_buffer.Size() != 0) {
     // Use the saved original FB pointer when the LiteRtModel was created
     // from a buffer.
     model_buffer = init_model_buffer.StrData();
@@ -137,7 +137,7 @@ TEST(CompiledModelTest, Basic) {
   ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);

   auto res_compiled_model =
-      LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorNone);
+      LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu);
   ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel: "
                                   << res_compiled_model.Error().Message();
   auto& compiled_model = **res_compiled_model;
@@ -216,7 +216,7 @@ TEST(CompiledModelTest, UseAhwbBuffer) {
   ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);

   auto res_compiled_model =
-      LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorNone);
+      LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu);
   ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel";
   auto& compiled_model = **res_compiled_model;
@@ -82,7 +82,6 @@ Expected<std::pair<NeuronModelPtr, NeuronCompilationPtr>> LoadFromDlaBytecode(
     const litert::mediatek::NeuronAdapter& neuron_adapter,
     const void* bytecode_addr, size_t bytecode_size, int num_inputs,
     int num_outputs) {
-  LITERT_LOG(LITERT_INFO, "Creating model...");
   Expected<NeuronModelPtr> model = neuron_adapter.CreateModel();
   if (!model) {
     return model.Error();