Fix application of JIT compiler plugins

1) Restore some key logic lost when landing cl/707770943, in compiled_model.cpp:122
2) Don't abort CompiledModel creation if the runtime fails to apply compiler plugins; instead, issue warnings
3) Log the list of compiler plugins that were successfully applied

PiperOrigin-RevId: 713420589
This commit is contained in:
A. Unique TensorFlower
2025-01-08 14:36:27 -08:00
committed by TensorFlower Gardener
parent 805779b75a
commit df7525018e
8 changed files with 89 additions and 27 deletions

View File

@@ -45,7 +45,7 @@ TEST(CompiledModelTest, Basic) {
LiteRtCompiledModel compiled_model;
ASSERT_EQ(
LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorNone, &compiled_model),
LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorCpu, &compiled_model),
kLiteRtStatusOk);
LiteRtSubgraph subgraph;

View File

@@ -69,7 +69,7 @@ class CompiledModel
// returned object.
static Expected<CompiledModel> Create(
litert::Model& model,
LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorNone) {
LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorCpu) {
LiteRtCompiledModel compiled_model;
if (auto status = LiteRtCreateCompiledModel(
model.Get(), compilation_options, &compiled_model);

View File

@@ -41,6 +41,8 @@ cc_library(
"//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin",
"//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],

View File

@@ -24,6 +24,9 @@
#include <vector>
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/lite/experimental/litert/c/litert_any.h"
@@ -291,6 +294,16 @@ CompilerPlugin::~CompilerPlugin() {
}
}
std::string CompilerPlugin::DebugString() const {
std::string version_str = "?";
if (auto version = ApiVersion(); version) {
version_str = absl::StrFormat("%d.%d.%d", version->major, version->minor,
version->patch);
}
return absl::StrFormat("%s compiler plugin (ver %s)", SocManufacturer(),
version_str);
}
Expected<LiteRtApiVersion> CompilerPlugin::ApiVersion() const {
LiteRtApiVersion api_version;
LITERT_EXPECT_OK(plugin_api_.get_compiler_plugin_version(&api_version));
@@ -426,7 +439,7 @@ Expected<void> ApplyPlugin(CompilerPlugin& compiler_plugin, LiteRtModelT& model,
return {};
}
Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
Expected<ApplyPluginsResult> ApplyPlugins(
LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators) {
auto environment = litert::internal::Environment::Instance();
if (!environment) {
@@ -448,13 +461,25 @@ Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
if (!compiler_plugins) {
return compiler_plugins.Error();
}
if (compiler_plugins->empty()) {
return litert::Error(kLiteRtStatusErrorRuntimeFailure,
"No compiler plugin found");
}
std::optional<OwningBufferRef<uint8_t>> new_flatbuffer;
OwningBufferRef<uint8_t> new_flatbuffer;
std::vector<std::string> success_messages;
std::vector<std::string> error_messages;
ApplyPluginsResult result;
result.num_applied_plugins = 0;
for (auto& compiler_plugin : *compiler_plugins) {
auto plugin_name = compiler_plugin.DebugString();
auto plugin_supported_hardware = compiler_plugin.SupportedHardware();
if (!plugin_supported_hardware) {
return plugin_supported_hardware.Error();
error_messages.push_back(absl::StrCat(
plugin_name, " ", plugin_supported_hardware.Error().Message()));
continue;
}
if (*plugin_supported_hardware & selected_hw_accelerators) {
@@ -462,28 +487,39 @@ Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
// shouldn't be needing to serialize a model to then read it again from
// the serialized buffer when applying a compiler plugin.
if (auto status = ApplyPlugin(compiler_plugin, *model); !status) {
return status.Error();
error_messages.push_back(
absl::StrCat(plugin_name, " ", status.Error().Message()));
continue;
}
auto serialized_model =
litert::internal::SerializeModel(std::move(*model));
if (!serialized_model) {
return serialized_model.Error();
error_messages.push_back(
absl::StrCat(plugin_name, " ", serialized_model.Error().Message()));
continue;
}
auto new_model = litert::Model::CreateFromBuffer(*serialized_model);
if (!new_model) {
return new_model.Error();
error_messages.push_back(
absl::StrCat(plugin_name, " ", new_model.Error().Message()));
continue;
}
new_flatbuffer = std::move(*serialized_model);
*model = std::move(*new_model->Get());
success_messages.push_back(absl::StrCat(plugin_name));
result.num_applied_plugins++;
}
}
if (!new_flatbuffer.has_value()) {
return litert::Error(kLiteRtStatusErrorRuntimeFailure,
"No applicable compiler plugin found");
}
result.new_flatbuffer = std::move(new_flatbuffer);
result.success_message = absl::StrJoin(success_messages, ", ");
result.error_message = absl::StrJoin(error_messages, ", ");
return *new_flatbuffer;
return result;
}
} // namespace litert::internal

View File

@@ -16,6 +16,7 @@
#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
#include <string>
#include <tuple>
#include <vector>
#include "absl/strings/string_view.h"
@@ -72,6 +73,8 @@ class CompiledResult {
// Wraps vendor compiler plugin.
class CompilerPlugin {
public:
std::string DebugString() const;
// Get the compiler plugin's API version.
Expected<LiteRtApiVersion> ApiVersion() const;
@@ -150,9 +153,19 @@ Expected<void> ApplyPlugin(
Serialization serialization = Serialization::kAppend);
// Apply all available plugins providing the selected HW accelerators to the
// given model, modify the model accordingly, and return a new flatbuffer
// backing the modified model.
Expected<OwningBufferRef<uint8_t>> ApplyPlugins(
// given model, modify the model accordingly, and return (1) the number of
// compiler plugins successfully applied, (2) a new flatbuffer backing the
// modified model, (3) a string listing the compiler plugins that were
// successfully applied, and (4) a string listing the compiler plugins that
// failed to apply with an associated error message.
struct ApplyPluginsResult {
size_t num_applied_plugins;
OwningBufferRef<uint8_t> new_flatbuffer;
std::string success_message;
std::string error_message;
};
Expected<ApplyPluginsResult> ApplyPlugins(
LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators);
} // namespace litert::internal

View File

@@ -94,14 +94,23 @@ Expected<LiteRtCompiledModelT::Ptr> LiteRtCompiledModelT::Create(
std::optional<OwningBufferRef<uint8_t>> new_flatbuffer;
// TODO: b/379317134 - Support other delegates with compilation options.
if (compilation_options != kLiteRtHwAccelatorNone) {
LITERT_LOG(LITERT_INFO, "Applying compiler plugins");
if (auto flatbuffer =
LITERT_LOG(LITERT_INFO, "Applying compiler plugins...");
if (auto result =
litert::internal::ApplyPlugins(model, compilation_options);
!flatbuffer) {
LITERT_LOG(LITERT_ERROR, "Failed to applying compiler plugins");
return flatbuffer.Error();
!result) {
LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s",
result.Error().Message().data());
} else {
new_flatbuffer = *flatbuffer;
if (result->num_applied_plugins > 0) {
LITERT_LOG(LITERT_INFO, "Successfully applied %d compiler plugins: %s",
result->num_applied_plugins,
result->success_message.c_str());
new_flatbuffer = std::move(result->new_flatbuffer);
}
if (!result->error_message.empty()) {
LITERT_LOG(LITERT_WARNING, "Some compiler plugins failed to apply: %s",
result->error_message.c_str());
}
}
}
@@ -109,8 +118,11 @@ Expected<LiteRtCompiledModelT::Ptr> LiteRtCompiledModelT::Create(
size_t model_buffer_size = 0;
// The following code gets the original FB pointer from LiteRtModel.
// TODO b/383120429 - Use a better way of getting the FB pointer.
auto init_model_buffer = detail::GetTflInitFlatbuffer(*model);
if (init_model_buffer.Size() != 0) {
if (new_flatbuffer) {
model_buffer = reinterpret_cast<const char*>(new_flatbuffer->Data());
model_buffer_size = new_flatbuffer->Size();
} else if (auto init_model_buffer = detail::GetTflInitFlatbuffer(*model);
init_model_buffer.Size() != 0) {
// Use the saved the original FB pointer when the LiteRtModel was created
// from a buffer.
model_buffer = init_model_buffer.StrData();

View File

@@ -137,7 +137,7 @@ TEST(CompiledModelTest, Basic) {
ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
auto res_compiled_model =
LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorNone);
LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu);
ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel: "
<< res_compiled_model.Error().Message();
auto& compiled_model = **res_compiled_model;
@@ -216,7 +216,7 @@ TEST(CompiledModelTest, UseAhwbBuffer) {
ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
auto res_compiled_model =
LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorNone);
LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu);
ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel";
auto& compiled_model = **res_compiled_model;

View File

@@ -82,7 +82,6 @@ Expected<std::pair<NeuronModelPtr, NeuronCompilationPtr>> LoadFromDlaBytecode(
const litert::mediatek::NeuronAdapter& neuron_adapter,
const void* bytecode_addr, size_t bytecode_size, int num_inputs,
int num_outputs) {
LITERT_LOG(LITERT_INFO, "Creating model...");
Expected<NeuronModelPtr> model = neuron_adapter.CreateModel();
if (!model) {
return model.Error();