diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 1b64a52ecef..c3dc8531bb9 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -191,6 +191,10 @@ file(GLOB_RECURSE tf_core_lib_srcs "${tensorflow_source_dir}/tensorflow/core/lib/*.h" "${tensorflow_source_dir}/tensorflow/core/lib/*.cc" "${tensorflow_source_dir}/tensorflow/core/public/*.h" + # TODO(@jart): Move StatusOr into core. + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h" ) file(GLOB tf_core_platform_srcs diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index f5a8e02a9bb..98fff9e0ae4 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -15,13 +15,11 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/schema.h" namespace tensorflow { -namespace db { namespace { class SqliteSchema { public: - explicit SqliteSchema(Sqlite* db) : db_(db) {} - ~SqliteSchema() { db_ = nullptr; } + explicit SqliteSchema(std::shared_ptr db) : db_(std::move(db)) {} /// \brief Creates Tensors table. /// @@ -371,18 +369,18 @@ class SqliteSchema { Status Run(const char* sql) { auto stmt = db_->Prepare(sql); - TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt->StepAndReset(), sql); + TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql); return Status::OK(); } private: - Sqlite* db_; + std::shared_ptr db_; }; } // namespace -Status SetupTensorboardSqliteDb(Sqlite* db) { - SqliteSchema s(db); +Status SetupTensorboardSqliteDb(std::shared_ptr db) { + SqliteSchema s(std::move(db)); TF_RETURN_IF_ERROR(s.CreateTensorsTable()); TF_RETURN_IF_ERROR(s.CreateTensorChunksTable()); TF_RETURN_IF_ERROR(s.CreateTagsTable()); @@ -408,5 +406,4 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { return Status::OK(); } -} // namespace db } // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/db/schema.h b/tensorflow/contrib/tensorboard/db/schema.h index d3a6922d94a..900c10298ce 100644 --- a/tensorflow/contrib/tensorboard/db/schema.h +++ b/tensorflow/contrib/tensorboard/db/schema.h @@ -15,19 +15,19 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ #define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ +#include + #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/db/sqlite.h" namespace tensorflow { -namespace db { /// \brief Creates TensorBoard SQLite tables and indexes. /// /// If they are already created, this has no effect. If schema /// migrations are necessary, they will be performed with logging. -Status SetupTensorboardSqliteDb(Sqlite* db); +Status SetupTensorboardSqliteDb(std::shared_ptr db); -} // namespace db } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ diff --git a/tensorflow/contrib/tensorboard/db/schema_test.cc b/tensorflow/contrib/tensorboard/db/schema_test.cc index a4302dda447..463c4e59e7e 100644 --- a/tensorflow/contrib/tensorboard/db/schema_test.cc +++ b/tensorflow/contrib/tensorboard/db/schema_test.cc @@ -20,15 +20,12 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace db { namespace { TEST(SchemaTest, SmokeTestTensorboardSchema) { - std::unique_ptr db; - TF_ASSERT_OK(Sqlite::Open(":memory:", &db)); - TF_ASSERT_OK(SetupTensorboardSqliteDb(db.get())); + auto db = Sqlite::Open(":memory:").ValueOrDie(); + TF_ASSERT_OK(SetupTensorboardSqliteDb(db)); } } // namespace -} // namespace db } // namespace tensorflow diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/sql/sqlite_query_connection.cc index a9e6ee09694..1330506d28c 100644 --- a/tensorflow/core/kernels/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/sql/sqlite_query_connection.h" + #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { @@ -29,17 +30,18 @@ Status SqliteQueryConnection::Open(const string& data_source_name, return errors::FailedPrecondition( "Failed to open query connection: Connection already opeend."); } - Status s = db::Sqlite::Open(data_source_name, &db_); + auto s = Sqlite::Open(data_source_name); if (s.ok()) { + db_ = std::move(s.ValueOrDie()); query_ = query; output_types_ = output_types; } - return s; + return s.status(); } Status SqliteQueryConnection::Close() { Status s; - s.Update(stmt_->Close()); + s.Update(stmt_.Close()); s.Update(db_->Close()); return s; } @@ -52,7 +54,7 @@ Status SqliteQueryConnection::GetNext(std::vector* out_tensors, return s; } } - Status s = stmt_->Step(end_of_sequence); + Status s = stmt_.Step(end_of_sequence); if (!*end_of_sequence) { for (int i = 0; i < column_count_; i++) { DataType dt = output_types_[i]; @@ -66,9 +68,9 @@ Status SqliteQueryConnection::GetNext(std::vector* out_tensors, Status SqliteQueryConnection::PrepareQuery() { stmt_ = db_->Prepare(query_); - Status s = stmt_->status(); + Status s = stmt_.status(); if (s.ok()) { - int column_count = stmt_->ColumnCount(); + int column_count = stmt_.ColumnCount(); if (column_count != output_types_.size()) { return errors::InvalidArgument(tensorflow::strings::Printf( "The number of columns in query (%d) must match the number of " @@ -84,40 +86,40 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry( const DataType& data_type, int column_index, Tensor* tensor) { switch (data_type) { case DT_STRING: - tensor->scalar()() = stmt_->ColumnString(column_index); + tensor->scalar()() = stmt_.ColumnString(column_index); break; case DT_INT8: tensor->scalar()() = - static_cast(stmt_->ColumnInt(column_index)); + static_cast(stmt_.ColumnInt(column_index)); break; case DT_INT16: tensor->scalar()() = - static_cast(stmt_->ColumnInt(column_index)); + static_cast(stmt_.ColumnInt(column_index)); break; case DT_INT32: tensor->scalar()() = - static_cast(stmt_->ColumnInt(column_index)); + static_cast(stmt_.ColumnInt(column_index)); break; case DT_INT64: - tensor->scalar()() = stmt_->ColumnInt(column_index); + tensor->scalar()() = stmt_.ColumnInt(column_index); break; case DT_UINT8: tensor->scalar()() = - static_cast(stmt_->ColumnInt(column_index)); + static_cast(stmt_.ColumnInt(column_index)); break; case DT_UINT16: tensor->scalar()() = - static_cast(stmt_->ColumnInt(column_index)); + static_cast(stmt_.ColumnInt(column_index)); break; case DT_BOOL: - tensor->scalar()() = stmt_->ColumnInt(column_index) != 0; + tensor->scalar()() = stmt_.ColumnInt(column_index) != 0; break; case DT_FLOAT: tensor->scalar()() = - static_cast(stmt_->ColumnDouble(column_index)); + static_cast(stmt_.ColumnDouble(column_index)); break; case DT_DOUBLE: - tensor->scalar()() = stmt_->ColumnDouble(column_index); + tensor->scalar()() = stmt_.ColumnDouble(column_index); break; // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case. default: { diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.h b/tensorflow/core/kernels/sql/sqlite_query_connection.h index b0b4737a1ea..435dd8e234c 100644 --- a/tensorflow/core/kernels/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.h @@ -42,8 +42,8 @@ class SqliteQueryConnection : public QueryConnection { // `stmt_`. void FillTensorWithResultSetEntry(const DataType& data_type, int column_index, Tensor* tensor); - std::unique_ptr db_ = nullptr; - std::unique_ptr stmt_ = nullptr; + std::shared_ptr db_ = nullptr; + SqliteStatement stmt_; int column_count_ = 0; string query_; DataTypeVector output_types_; diff --git a/tensorflow/core/lib/db/BUILD b/tensorflow/core/lib/db/BUILD index 367686c16a8..41b7af1b699 100644 --- a/tensorflow/core/lib/db/BUILD +++ b/tensorflow/core/lib/db/BUILD @@ -12,6 +12,7 @@ cc_library( srcs = ["sqlite.cc"], hdrs = ["sqlite.h"], deps = [ + "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "@sqlite_archive//:sqlite", ], diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc index 108be452a22..701655f622a 100644 --- a/tensorflow/core/lib/db/sqlite.cc +++ b/tensorflow/core/lib/db/sqlite.cc @@ -18,14 +18,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { -namespace db { /* static */ -Status Sqlite::Open(const string& uri, std::unique_ptr* db) { +xla::StatusOr> Sqlite::Open(const string& uri) { sqlite3* sqlite = nullptr; Status s = MakeStatus(sqlite3_open(uri.c_str(), &sqlite)); if (s.ok()) { - *db = std::unique_ptr(new Sqlite(sqlite)); + return std::shared_ptr(new Sqlite(sqlite)); } return s; } @@ -87,6 +86,9 @@ Sqlite::~Sqlite() { } Status Sqlite::Close() { + if (db_ == nullptr) { + return Status::OK(); + } // If Close is explicitly called, ordering must be correct. Status s = MakeStatus(sqlite3_close(db_)); if (s.ok()) { @@ -95,23 +97,42 @@ Status Sqlite::Close() { return s; } -std::unique_ptr Sqlite::Prepare(const string& sql) { +SqliteStatement Sqlite::Prepare(const string& sql) { sqlite3_stmt* stmt = nullptr; int rc = sqlite3_prepare_v2(db_, sql.c_str(), sql.size() + 1, &stmt, nullptr); - return std::unique_ptr(new SqliteStatement(stmt, rc)); + if (rc == SQLITE_OK) { + return {stmt, SQLITE_OK, std::unique_ptr(nullptr)}; + } else { + return {nullptr, rc, std::unique_ptr(new string(sql))}; + } } -SqliteStatement::SqliteStatement(sqlite3_stmt* stmt, int error) - : stmt_(stmt), error_(error) {} +Status SqliteStatement::status() const { + Status s = Sqlite::MakeStatus(error_); + if (!s.ok()) { + if (stmt_ != nullptr) { + errors::AppendToMessage(&s, sqlite3_sql(stmt_)); + } else { + errors::AppendToMessage(&s, *prepare_error_sql_); + } + } + return s; +} -SqliteStatement::~SqliteStatement() { - int rc = sqlite3_finalize(stmt_); - if (rc != SQLITE_OK) { - LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc); +void SqliteStatement::CloseOrLog() { + if (stmt_ != nullptr) { + int rc = sqlite3_finalize(stmt_); + if (rc != SQLITE_OK) { + LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc); + } + stmt_ = nullptr; } } Status SqliteStatement::Close() { + if (stmt_ == nullptr) { + return Status::OK(); + } int rc = sqlite3_finalize(stmt_); if (rc == SQLITE_OK) { stmt_ = nullptr; @@ -121,8 +142,10 @@ Status SqliteStatement::Close() { } void SqliteStatement::Reset() { - sqlite3_reset(stmt_); - sqlite3_clear_bindings(stmt_); + if (TF_PREDICT_TRUE(stmt_ != nullptr)) { + sqlite3_reset(stmt_); + sqlite3_clear_bindings(stmt_); // not nullptr friendly + } error_ = SQLITE_OK; } @@ -163,5 +186,4 @@ Status SqliteStatement::StepAndReset() { return s; } -} // namespace db } // namespace tensorflow diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h index 316e938f1b8..774852efea7 100644 --- a/tensorflow/core/lib/db/sqlite.h +++ b/tensorflow/core/lib/db/sqlite.h @@ -17,15 +17,16 @@ limitations under the License. #include #include +#include #include "sqlite3.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -namespace db { class SqliteStatement; @@ -46,7 +47,7 @@ class Sqlite { /// `file::memory:` for testing. /// /// See https://sqlite.org/c3ref/open.html - static Status Open(const string& uri, std::unique_ptr* db); + static xla::StatusOr> Open(const string& uri); /// \brief Makes tensorflow::Status for SQLite result code. /// @@ -65,7 +66,7 @@ class Sqlite { /// \brief Frees underlying SQLite object. /// /// Unlike the destructor, all SqliteStatement objects must be closed - /// beforehand. + /// beforehand. This is a no-op if already closed Status Close(); /// \brief Creates SQLite statement. @@ -74,7 +75,7 @@ class Sqlite { /// failed. It is also possible to punt the error checking to after /// the values have been binded and Step() or ExecuteWriteQuery() is /// called. - std::unique_ptr Prepare(const string& sql); + SqliteStatement Prepare(const string& sql); private: explicit Sqlite(sqlite3* db); @@ -89,21 +90,34 @@ class Sqlite { /// Instances of this class are not thread safe. class SqliteStatement { public: - /// \brief Destroys object and finalizes statement if needed. - ~SqliteStatement(); + /// \brief Constructs empty statement that should be assigned later. + SqliteStatement() : stmt_(nullptr), error_(SQLITE_OK) {} + + /// \brief Empties object and finalizes statement if needed. + ~SqliteStatement() { CloseOrLog(); } + + /// \brief Move constructor, after which should not be used. + SqliteStatement(SqliteStatement&& other); + + /// \brief Move assignment, after which should not be used. + SqliteStatement& operator=(SqliteStatement&& other); + + /// \brief Returns true if statement is not empty. + operator bool() const { return stmt_ != nullptr; } /// \brief Returns SQLite result code state. /// /// This will be SQLITE_OK unless an error happened. If multiple /// errors happened, only the first error code will be returned. - int error() { return error_; } + int error() const { return error_; } /// \brief Returns error() as a tensorflow::Status. - Status status() { return Sqlite::MakeStatus(error_); } + Status status() const; /// \brief Finalize statement object. /// - /// Please note that the destructor can also do this. + /// Please note that the destructor can also do this. This method is + /// a no-op if already closed. Status Close(); /// \brief Executes query and/or fetches next row. @@ -247,7 +261,12 @@ class SqliteStatement { private: friend Sqlite; - SqliteStatement(sqlite3_stmt* stmt, int error); // takes ownership + SqliteStatement(sqlite3_stmt* stmt, int error, + std::unique_ptr prepare_error_sql) + : stmt_(stmt), + error_(error), + prepare_error_sql_(std::move(prepare_error_sql)) {} + void CloseOrLog(); void Update(int rc) { if (TF_PREDICT_FALSE(rc != SQLITE_OK)) { @@ -268,11 +287,31 @@ class SqliteStatement { sqlite3_stmt* stmt_; int error_; + std::unique_ptr prepare_error_sql_; TF_DISALLOW_COPY_AND_ASSIGN(SqliteStatement); }; -} // namespace db +inline SqliteStatement::SqliteStatement(SqliteStatement&& other) + : stmt_(other.stmt_), + error_(other.error_), + prepare_error_sql_(std::move(other.prepare_error_sql_)) { + other.stmt_ = nullptr; + other.error_ = SQLITE_OK; +} + +inline SqliteStatement& SqliteStatement::operator=(SqliteStatement&& other) { + if (&other != this) { + CloseOrLog(); + stmt_ = other.stmt_; + error_ = other.error_; + prepare_error_sql_ = std::move(other.prepare_error_sql_); + other.stmt_ = nullptr; + other.error_ = SQLITE_OK; + } + return *this; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_LIB_DB_SQLITE_H_ diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index ce22379d97d..ba045274adc 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -24,97 +24,96 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace db { namespace { class SqliteTest : public ::testing::Test { protected: void SetUp() override { - TF_ASSERT_OK(Sqlite::Open(":memory:", &db_)); + db_ = Sqlite::Open(":memory:").ValueOrDie(); auto stmt = db_->Prepare("CREATE TABLE T (a BLOB, b BLOB)"); - TF_ASSERT_OK(stmt->StepAndReset()); + TF_ASSERT_OK(stmt.StepAndReset()); } - std::unique_ptr db_; + std::shared_ptr db_; bool is_done_; }; TEST_F(SqliteTest, InsertAndSelectInt) { auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindInt(1, 3); - stmt->BindInt(2, -7); - TF_ASSERT_OK(stmt->StepAndReset()); - stmt->BindInt(1, 123); - stmt->BindInt(2, -123); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindInt(1, 3); + stmt.BindInt(2, -7); + TF_ASSERT_OK(stmt.StepAndReset()); + stmt.BindInt(1, 123); + stmt.BindInt(2, -123); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T ORDER BY b"); - TF_ASSERT_OK(stmt->Step(&is_done_)); + TF_ASSERT_OK(stmt.Step(&is_done_)); ASSERT_FALSE(is_done_); - EXPECT_EQ(123, stmt->ColumnInt(0)); - EXPECT_EQ(-123, stmt->ColumnInt(1)); - TF_ASSERT_OK(stmt->Step(&is_done_)); + EXPECT_EQ(123, stmt.ColumnInt(0)); + EXPECT_EQ(-123, stmt.ColumnInt(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); ASSERT_FALSE(is_done_); - EXPECT_EQ(3, stmt->ColumnInt(0)); - EXPECT_EQ(-7, stmt->ColumnInt(1)); - TF_ASSERT_OK(stmt->Step(&is_done_)); + EXPECT_EQ(3, stmt.ColumnInt(0)); + EXPECT_EQ(-7, stmt.ColumnInt(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); ASSERT_TRUE(is_done_); } TEST_F(SqliteTest, InsertAndSelectDouble) { auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindDouble(1, 6.28318530); - stmt->BindDouble(2, 1.61803399); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindDouble(1, 6.28318530); + stmt.BindDouble(2, 1.61803399); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(6.28318530, stmt->ColumnDouble(0)); - EXPECT_EQ(1.61803399, stmt->ColumnDouble(1)); - EXPECT_EQ(6, stmt->ColumnInt(0)); - EXPECT_EQ(1, stmt->ColumnInt(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(6.28318530, stmt.ColumnDouble(0)); + EXPECT_EQ(1.61803399, stmt.ColumnDouble(1)); + EXPECT_EQ(6, stmt.ColumnInt(0)); + EXPECT_EQ(1, stmt.ColumnInt(1)); } TEST_F(SqliteTest, NulCharsInString) { string s; // XXX: Want to write {2, '\0'} but not sure why not. s.append(static_cast(2), '\0'); auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindBlob(1, s); - stmt->BindText(2, s); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindBlob(1, s); + stmt.BindText(2, s); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(2, stmt->ColumnSize(0)); - EXPECT_EQ(2, stmt->ColumnString(0).size()); - EXPECT_EQ('\0', stmt->ColumnString(0).at(0)); - EXPECT_EQ('\0', stmt->ColumnString(0).at(1)); - EXPECT_EQ(2, stmt->ColumnSize(1)); - EXPECT_EQ(2, stmt->ColumnString(1).size()); - EXPECT_EQ('\0', stmt->ColumnString(1).at(0)); - EXPECT_EQ('\0', stmt->ColumnString(1).at(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(2, stmt.ColumnSize(0)); + EXPECT_EQ(2, stmt.ColumnString(0).size()); + EXPECT_EQ('\0', stmt.ColumnString(0).at(0)); + EXPECT_EQ('\0', stmt.ColumnString(0).at(1)); + EXPECT_EQ(2, stmt.ColumnSize(1)); + EXPECT_EQ(2, stmt.ColumnString(1).size()); + EXPECT_EQ('\0', stmt.ColumnString(1).at(0)); + EXPECT_EQ('\0', stmt.ColumnString(1).at(1)); } TEST_F(SqliteTest, Unicode) { string s = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非"; auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindBlob(1, s); - stmt->BindText(2, s); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindBlob(1, s); + stmt.BindText(2, s); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(s, stmt->ColumnString(0)); - EXPECT_EQ(s, stmt->ColumnString(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(s, stmt.ColumnString(0)); + EXPECT_EQ(s, stmt.ColumnString(1)); } TEST_F(SqliteTest, StepAndResetClearsBindings) { auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindInt(1, 1); - stmt->BindInt(2, 123); - TF_ASSERT_OK(stmt->StepAndReset()); - stmt->BindInt(1, 2); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindInt(1, 1); + stmt.BindInt(2, 123); + TF_ASSERT_OK(stmt.StepAndReset()); + stmt.BindInt(1, 2); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT b FROM T ORDER BY a"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(123, stmt->ColumnInt(0)); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(SQLITE_NULL, stmt->ColumnType(0)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(123, stmt.ColumnInt(0)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(SQLITE_NULL, stmt.ColumnType(0)); } TEST_F(SqliteTest, CloseBeforeFinalizeFails) { @@ -128,71 +127,109 @@ TEST_F(SqliteTest, CloseBeforeFinalizeFails) { // is designed to carry the first error state forward to Step(). TEST_F(SqliteTest, ErrorPuntingDoesNotReportLibraryAbuse) { auto stmt = db_->Prepare("lol cat"); - EXPECT_FALSE(stmt->status().ok()); - EXPECT_EQ(SQLITE_ERROR, stmt->error()); - stmt->BindInt(1, 1); - stmt->BindInt(2, 2); - Status s = stmt->Step(&is_done_); - EXPECT_EQ(SQLITE_ERROR, stmt->error()); // first error of several + EXPECT_FALSE(stmt.status().ok()); + EXPECT_EQ(SQLITE_ERROR, stmt.error()); + stmt.BindInt(1, 1); + stmt.BindInt(2, 2); + Status s = stmt.Step(&is_done_); + EXPECT_EQ(SQLITE_ERROR, stmt.error()); // first error of several EXPECT_FALSE(s.ok()); } TEST_F(SqliteTest, SafeBind) { string s = "hello"; auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindBlob(1, s); - stmt->BindText(2, s); + stmt.BindBlob(1, s); + stmt.BindText(2, s); s.at(0) = 'y'; - TF_ASSERT_OK(stmt->StepAndReset()); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ("hello", stmt->ColumnString(0)); - EXPECT_EQ("hello", stmt->ColumnString(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ("hello", stmt.ColumnString(0)); + EXPECT_EQ("hello", stmt.ColumnString(1)); } TEST_F(SqliteTest, UnsafeBind) { string s = "hello"; auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindBlobUnsafe(1, s); - stmt->BindTextUnsafe(2, s); + stmt.BindBlobUnsafe(1, s); + stmt.BindTextUnsafe(2, s); s.at(0) = 'y'; - TF_ASSERT_OK(stmt->StepAndReset()); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT a, b FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ("yello", stmt->ColumnString(0)); - EXPECT_EQ("yello", stmt->ColumnString(1)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ("yello", stmt.ColumnString(0)); + EXPECT_EQ("yello", stmt.ColumnString(1)); } TEST_F(SqliteTest, UnsafeColumn) { auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)"); - stmt->BindInt(1, 1); - stmt->BindText(2, "hello"); - TF_ASSERT_OK(stmt->StepAndReset()); - stmt->BindInt(1, 2); - stmt->BindText(2, "there"); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindInt(1, 1); + stmt.BindText(2, "hello"); + TF_ASSERT_OK(stmt.StepAndReset()); + stmt.BindInt(1, 2); + stmt.BindText(2, "there"); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT b FROM T ORDER BY a"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - const char* p = stmt->ColumnStringUnsafe(0); + TF_ASSERT_OK(stmt.Step(&is_done_)); + const char* p = stmt.ColumnStringUnsafe(0); EXPECT_EQ('h', *p); - TF_ASSERT_OK(stmt->Step(&is_done_)); + TF_ASSERT_OK(stmt.Step(&is_done_)); // This will actually happen, but it's not safe to test this behavior. // EXPECT_EQ('t', *p); } TEST_F(SqliteTest, NamedParameterBind) { auto stmt = db_->Prepare("INSERT INTO T (a) VALUES (:a)"); - stmt->BindText(":a", "lol"); - TF_ASSERT_OK(stmt->StepAndReset()); + stmt.BindText(":a", "lol"); + TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->Prepare("SELECT COUNT(*) FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); - EXPECT_EQ(1, stmt->ColumnInt(0)); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_EQ(1, stmt.ColumnInt(0)); stmt = db_->Prepare("SELECT a FROM T"); - TF_ASSERT_OK(stmt->Step(&is_done_)); + TF_ASSERT_OK(stmt.Step(&is_done_)); EXPECT_FALSE(is_done_); - EXPECT_EQ("lol", stmt->ColumnString(0)); + EXPECT_EQ("lol", stmt.ColumnString(0)); +} + +TEST_F(SqliteTest, Statement_DefaultConstructor) { + SqliteStatement stmt; + EXPECT_FALSE(stmt); + EXPECT_FALSE(stmt.StepAndReset().ok()); + stmt = db_->Prepare("INSERT INTO T (a) VALUES (1)"); + EXPECT_TRUE(stmt); + EXPECT_TRUE(stmt.StepAndReset().ok()); +} + +TEST_F(SqliteTest, Statement_MoveConstructor) { + SqliteStatement stmt{db_->Prepare("INSERT INTO T (a) VALUES (1)")}; + EXPECT_TRUE(stmt.StepAndReset().ok()); +} + +TEST_F(SqliteTest, Statement_MoveAssignment) { + SqliteStatement stmt1 = db_->Prepare("INSERT INTO T (a) VALUES (1)"); + SqliteStatement stmt2; + EXPECT_TRUE(stmt1.StepAndReset().ok()); + EXPECT_FALSE(stmt2.StepAndReset().ok()); + stmt2 = std::move(stmt1); + EXPECT_TRUE(stmt2.StepAndReset().ok()); +} + +TEST_F(SqliteTest, PrepareFailed) { + SqliteStatement s = db_->Prepare("SELECT"); + EXPECT_FALSE(s.status().ok()); + EXPECT_NE(string::npos, s.status().error_message().find("SELECT")); +} + +TEST_F(SqliteTest, BindFailed) { + SqliteStatement s = db_->Prepare("INSERT INTO T (a) VALUES (123)"); + EXPECT_TRUE(s.status().ok()); + EXPECT_EQ("", s.status().error_message()); + s.BindInt(1, 123); + EXPECT_FALSE(s.status().ok()); + EXPECT_NE(string::npos, + s.status().error_message().find("INSERT INTO T (a) VALUES (123)")); } } // namespace -} // namespace db } // namespace tensorflow