From c9181ba90e8e9bbfa65785ab064216c8e16796bd Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Mon, 3 May 2021 09:07:42 -0700 Subject: [PATCH] Cherry pick 2.2 TFLite: Error out when the graph has a recurion. --- tensorflow/lite/BUILD | 1 + tensorflow/lite/core/subgraph.cc | 46 ++++++++++++++++++ tensorflow/lite/core/subgraph.h | 4 ++ tensorflow/lite/kernels/while.cc | 2 - tensorflow/lite/model_test.cc | 18 +++++++ .../lite/testdata/unsupported_recursion.bin | Bin 0 -> 600 bytes 6 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 tensorflow/lite/testdata/unsupported_recursion.bin diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index e9539d42f75..85a22251602 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -357,6 +357,7 @@ cc_test( "testdata/test_min_runtime.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", + "testdata/unsupported_recursion.bin", ], tags = [ "tflite_not_portable", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 67c8517aa45..fdd4d0adab4 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -139,6 +139,42 @@ const char* GetTFLiteOpName(const TfLiteRegistration& op_reg) { return tflite::EnumNamesBuiltinOperator()[op_reg.builtin_code]; } +// An utility test to detect if the subgraph is abused: +// 1. Detects if recursion exists in the graph (recursion is not currently +// supported. +// 2. Detects if the interpreter / subgraph is used in multiple subgraphs. +// Note: It's clearly documented that the interpreter / subgraph are not +// thread-safe. This serves as a check with possible false negatives +// unless we switch to atomic boolean flags. +class SubgraphGuard { + public: + SubgraphGuard(TfLiteContext* context, bool* is_subgraph_in_use) + : is_subgraph_in_use_(is_subgraph_in_use) { + if (*is_subgraph_in_use_) { + TF_LITE_KERNEL_LOG( + context, + "Subgraph is already in use. Using an interpreter or a subgraph in " + "multiple threads is not supported. Recursion in the graph is not " + "supported."); + status_ = kTfLiteError; + } else { + *is_subgraph_in_use_ = true; + } + } + ~SubgraphGuard() { + // If tht original status was OK, recover the boolean flag. + if (status_ == kTfLiteOk) { + *is_subgraph_in_use_ = false; + } + } + + TfLiteStatus status() const { return status_; } + + private: + TfLiteStatus status_ = kTfLiteOk; + bool* is_subgraph_in_use_; +}; + } // namespace // A trivial implementation of GraphInfo around the Interpreter. @@ -630,6 +666,7 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims, TfLiteStatus Subgraph::AllocateTensors() { TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(), "AllocateTensors"); + if (!consistent_) { ReportError("AllocateTensors() called on inconsistent model."); return kTfLiteError; @@ -653,6 +690,12 @@ TfLiteStatus Subgraph::AllocateTensors() { return kTfLiteOk; } + // Note `AllocateTensors` sometimes calls itself recursively above + // for delegates. Therefore only the logic below need to be guarded + // by `SubgraphGuard`. + SubgraphGuard guard(&context_, &is_subgraph_in_use_); + TF_LITE_ENSURE_OK(&context_, guard.status()); + next_execution_plan_index_to_prepare_ = 0; next_execution_plan_index_to_plan_allocation_ = 0; if (memory_planner_) { @@ -880,6 +923,9 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() { } TfLiteStatus Subgraph::Invoke() { + SubgraphGuard guard(&context_, &is_subgraph_in_use_); + TF_LITE_ENSURE_OK(&context_, guard.status()); + if (!consistent_) { ReportError("Invoke called on model that is not consistent."); return kTfLiteError; diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 4040121228a..b2091ed9e31 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -682,6 +682,10 @@ class Subgraph { // A map of resources. Owned by interpreter and shared by multiple subgraphs. resource::ResourceMap* resources_ = nullptr; + + // Whether the subgraph is currently in use (e.g. running the `Invoke` + // or `AllocateTensors` functions). + bool is_subgraph_in_use_ = false; }; } // namespace tflite diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index e9bc3388693..81b6c0c6634 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -132,8 +132,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* subgraphs = this_subgraph->GetSubgraphs(); TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size()); TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size()); - TF_LITE_ENSURE(context, - op_data->cond_subgraph_index != op_data->body_subgraph_index); Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get(); Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get(); diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index b9efdf676a8..9efe76d15b4 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -442,6 +442,24 @@ TEST(BasicFlatBufferModel, TestParseModelWithSparseTensor) { } // TODO(b/150072943): Add malformed model with sparse tensor tests. +// Recursion & reentrant are not supported in TFLite. +// The test ensures it fails gracefullly instead of crashing with +// a stack overflow. +TEST(BasicFlatBufferModel, TestUnsupportedRecursion) { + const auto model_path = + "tensorflow/lite/testdata/unsupported_recursion.bin"; + + std::unique_ptr model = + FlatBufferModel::BuildFromFile(model_path); + ASSERT_NE(model, nullptr); + + tflite::ops::builtin::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model, resolver); + std::unique_ptr interpreter; + ASSERT_EQ(builder(&interpreter), kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk); +} // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, diff --git a/tensorflow/lite/testdata/unsupported_recursion.bin b/tensorflow/lite/testdata/unsupported_recursion.bin new file mode 100644 index 0000000000000000000000000000000000000000..525c5383ab4ef6283d687aeb4004b38a8981773a GIT binary patch literal 600 zcmZ9Ky-Nc@5XE0KBoSjwAt)rp6_yhG#1vKvijYXufSnM$!=S-~Aeu6$NGcVy43ffD zv=J10KN>shH2xi)@7-ZXCm}8pCn_Msmv&~WA#tZJ0cCz&tN{cxMIz@J9|!b*ReAQF{b^11cvA(nOQj# zCZ44>NsC+&i(EIS4_B9;x1Szw9@@qB(roJXt*H| z)K!b}Cy_?#d+(@U2g3_29YX)Bc3Md1?cfLgjkUf;)Q?N{Qaur!JZFqWInI}A1=p=W z{yUMC1Hqutn9*~iuSWZfs%Pq$ZsT{)bY7h?O~U>M)(#XZr46U-R&1w=9*%=M1n%g| fP6NS#(KVwVMlX!&zGdTq_