Improve C++ SQLite veneer

- Use shared_ptr for Sqlite
- Don't need unique_ptr on SqliteStatement
- Don't need db namespace
- Include SQL in error statuses

PiperOrigin-RevId: 173802267
This commit is contained in:
Justine Tunney
2017-10-28 23:45:55 -07:00
committed by TensorFlower Gardener
parent 0eba15fe63
commit 325c8e5efa
10 changed files with 244 additions and 145 deletions

View File

@@ -191,6 +191,10 @@ file(GLOB_RECURSE tf_core_lib_srcs
"${tensorflow_source_dir}/tensorflow/core/lib/*.h"
"${tensorflow_source_dir}/tensorflow/core/lib/*.cc"
"${tensorflow_source_dir}/tensorflow/core/public/*.h"
# TODO(@jart): Move StatusOr into core.
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc"
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h"
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h"
)
file(GLOB tf_core_platform_srcs

View File

@@ -15,13 +15,11 @@ limitations under the License.
#include "tensorflow/contrib/tensorboard/db/schema.h"
namespace tensorflow {
namespace db {
namespace {
class SqliteSchema {
public:
explicit SqliteSchema(Sqlite* db) : db_(db) {}
~SqliteSchema() { db_ = nullptr; }
explicit SqliteSchema(std::shared_ptr<Sqlite> db) : db_(std::move(db)) {}
/// \brief Creates Tensors table.
///
@@ -371,18 +369,18 @@ class SqliteSchema {
Status Run(const char* sql) {
auto stmt = db_->Prepare(sql);
TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt->StepAndReset(), sql);
TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql);
return Status::OK();
}
private:
Sqlite* db_;
std::shared_ptr<Sqlite> db_;
};
} // namespace
Status SetupTensorboardSqliteDb(Sqlite* db) {
SqliteSchema s(db);
Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) {
SqliteSchema s(std::move(db));
TF_RETURN_IF_ERROR(s.CreateTensorsTable());
TF_RETURN_IF_ERROR(s.CreateTensorChunksTable());
TF_RETURN_IF_ERROR(s.CreateTagsTable());
@@ -408,5 +406,4 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
return Status::OK();
}
} // namespace db
} // namespace tensorflow

View File

@@ -15,19 +15,19 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_
#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_
#include <memory>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/db/sqlite.h"
namespace tensorflow {
namespace db {
/// \brief Creates TensorBoard SQLite tables and indexes.
///
/// If they are already created, this has no effect. If schema
/// migrations are necessary, they will be performed with logging.
Status SetupTensorboardSqliteDb(Sqlite* db);
Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db);
} // namespace db
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_

View File

@@ -20,15 +20,12 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace db {
namespace {
TEST(SchemaTest, SmokeTestTensorboardSchema) {
std::unique_ptr<Sqlite> db;
TF_ASSERT_OK(Sqlite::Open(":memory:", &db));
TF_ASSERT_OK(SetupTensorboardSqliteDb(db.get()));
auto db = Sqlite::Open(":memory:").ValueOrDie();
TF_ASSERT_OK(SetupTensorboardSqliteDb(db));
}
} // namespace
} // namespace db
} // namespace tensorflow

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/sql/sqlite_query_connection.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
@@ -29,17 +30,18 @@ Status SqliteQueryConnection::Open(const string& data_source_name,
return errors::FailedPrecondition(
"Failed to open query connection: Connection already opeend.");
}
Status s = db::Sqlite::Open(data_source_name, &db_);
auto s = Sqlite::Open(data_source_name);
if (s.ok()) {
db_ = std::move(s.ValueOrDie());
query_ = query;
output_types_ = output_types;
}
return s;
return s.status();
}
Status SqliteQueryConnection::Close() {
Status s;
s.Update(stmt_->Close());
s.Update(stmt_.Close());
s.Update(db_->Close());
return s;
}
@@ -52,7 +54,7 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
return s;
}
}
Status s = stmt_->Step(end_of_sequence);
Status s = stmt_.Step(end_of_sequence);
if (!*end_of_sequence) {
for (int i = 0; i < column_count_; i++) {
DataType dt = output_types_[i];
@@ -66,9 +68,9 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
Status SqliteQueryConnection::PrepareQuery() {
stmt_ = db_->Prepare(query_);
Status s = stmt_->status();
Status s = stmt_.status();
if (s.ok()) {
int column_count = stmt_->ColumnCount();
int column_count = stmt_.ColumnCount();
if (column_count != output_types_.size()) {
return errors::InvalidArgument(tensorflow::strings::Printf(
"The number of columns in query (%d) must match the number of "
@@ -84,40 +86,40 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
const DataType& data_type, int column_index, Tensor* tensor) {
switch (data_type) {
case DT_STRING:
tensor->scalar<string>()() = stmt_->ColumnString(column_index);
tensor->scalar<string>()() = stmt_.ColumnString(column_index);
break;
case DT_INT8:
tensor->scalar<int8>()() =
static_cast<int8>(stmt_->ColumnInt(column_index));
static_cast<int8>(stmt_.ColumnInt(column_index));
break;
case DT_INT16:
tensor->scalar<int16>()() =
static_cast<int16>(stmt_->ColumnInt(column_index));
static_cast<int16>(stmt_.ColumnInt(column_index));
break;
case DT_INT32:
tensor->scalar<int32>()() =
static_cast<int32>(stmt_->ColumnInt(column_index));
static_cast<int32>(stmt_.ColumnInt(column_index));
break;
case DT_INT64:
tensor->scalar<int64>()() = stmt_->ColumnInt(column_index);
tensor->scalar<int64>()() = stmt_.ColumnInt(column_index);
break;
case DT_UINT8:
tensor->scalar<uint8>()() =
static_cast<uint8>(stmt_->ColumnInt(column_index));
static_cast<uint8>(stmt_.ColumnInt(column_index));
break;
case DT_UINT16:
tensor->scalar<uint16>()() =
static_cast<uint16>(stmt_->ColumnInt(column_index));
static_cast<uint16>(stmt_.ColumnInt(column_index));
break;
case DT_BOOL:
tensor->scalar<bool>()() = stmt_->ColumnInt(column_index) != 0;
tensor->scalar<bool>()() = stmt_.ColumnInt(column_index) != 0;
break;
case DT_FLOAT:
tensor->scalar<float>()() =
static_cast<float>(stmt_->ColumnDouble(column_index));
static_cast<float>(stmt_.ColumnDouble(column_index));
break;
case DT_DOUBLE:
tensor->scalar<double>()() = stmt_->ColumnDouble(column_index);
tensor->scalar<double>()() = stmt_.ColumnDouble(column_index);
break;
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
default: {

View File

@@ -42,8 +42,8 @@ class SqliteQueryConnection : public QueryConnection {
// `stmt_`.
void FillTensorWithResultSetEntry(const DataType& data_type, int column_index,
Tensor* tensor);
std::unique_ptr<db::Sqlite> db_ = nullptr;
std::unique_ptr<db::SqliteStatement> stmt_ = nullptr;
std::shared_ptr<Sqlite> db_ = nullptr;
SqliteStatement stmt_;
int column_count_ = 0;
string query_;
DataTypeVector output_types_;

View File

@@ -12,6 +12,7 @@ cc_library(
srcs = ["sqlite.cc"],
hdrs = ["sqlite.h"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
"@sqlite_archive//:sqlite",
],

View File

@@ -18,14 +18,13 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace db {
/* static */
Status Sqlite::Open(const string& uri, std::unique_ptr<Sqlite>* db) {
xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(const string& uri) {
sqlite3* sqlite = nullptr;
Status s = MakeStatus(sqlite3_open(uri.c_str(), &sqlite));
if (s.ok()) {
*db = std::unique_ptr<Sqlite>(new Sqlite(sqlite));
return std::shared_ptr<Sqlite>(new Sqlite(sqlite));
}
return s;
}
@@ -87,6 +86,9 @@ Sqlite::~Sqlite() {
}
Status Sqlite::Close() {
if (db_ == nullptr) {
return Status::OK();
}
// If Close is explicitly called, ordering must be correct.
Status s = MakeStatus(sqlite3_close(db_));
if (s.ok()) {
@@ -95,23 +97,42 @@ Status Sqlite::Close() {
return s;
}
std::unique_ptr<SqliteStatement> Sqlite::Prepare(const string& sql) {
SqliteStatement Sqlite::Prepare(const string& sql) {
sqlite3_stmt* stmt = nullptr;
int rc = sqlite3_prepare_v2(db_, sql.c_str(), sql.size() + 1, &stmt, nullptr);
return std::unique_ptr<SqliteStatement>(new SqliteStatement(stmt, rc));
if (rc == SQLITE_OK) {
return {stmt, SQLITE_OK, std::unique_ptr<string>(nullptr)};
} else {
return {nullptr, rc, std::unique_ptr<string>(new string(sql))};
}
}
SqliteStatement::SqliteStatement(sqlite3_stmt* stmt, int error)
: stmt_(stmt), error_(error) {}
Status SqliteStatement::status() const {
Status s = Sqlite::MakeStatus(error_);
if (!s.ok()) {
if (stmt_ != nullptr) {
errors::AppendToMessage(&s, sqlite3_sql(stmt_));
} else {
errors::AppendToMessage(&s, *prepare_error_sql_);
}
}
return s;
}
SqliteStatement::~SqliteStatement() {
int rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
void SqliteStatement::CloseOrLog() {
if (stmt_ != nullptr) {
int rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
}
stmt_ = nullptr;
}
}
Status SqliteStatement::Close() {
if (stmt_ == nullptr) {
return Status::OK();
}
int rc = sqlite3_finalize(stmt_);
if (rc == SQLITE_OK) {
stmt_ = nullptr;
@@ -121,8 +142,10 @@ Status SqliteStatement::Close() {
}
void SqliteStatement::Reset() {
sqlite3_reset(stmt_);
sqlite3_clear_bindings(stmt_);
if (TF_PREDICT_TRUE(stmt_ != nullptr)) {
sqlite3_reset(stmt_);
sqlite3_clear_bindings(stmt_); // not nullptr friendly
}
error_ = SQLITE_OK;
}
@@ -163,5 +186,4 @@ Status SqliteStatement::StepAndReset() {
return s;
}
} // namespace db
} // namespace tensorflow

View File

@@ -17,15 +17,16 @@ limitations under the License.
#include <stddef.h>
#include <memory>
#include <utility>
#include "sqlite3.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace db {
class SqliteStatement;
@@ -46,7 +47,7 @@ class Sqlite {
/// `file::memory:` for testing.
///
/// See https://sqlite.org/c3ref/open.html
static Status Open(const string& uri, std::unique_ptr<Sqlite>* db);
static xla::StatusOr<std::shared_ptr<Sqlite>> Open(const string& uri);
/// \brief Makes tensorflow::Status for SQLite result code.
///
@@ -65,7 +66,7 @@ class Sqlite {
/// \brief Frees underlying SQLite object.
///
/// Unlike the destructor, all SqliteStatement objects must be closed
/// beforehand.
/// beforehand. This is a no-op if already closed
Status Close();
/// \brief Creates SQLite statement.
@@ -74,7 +75,7 @@ class Sqlite {
/// failed. It is also possible to punt the error checking to after
/// the values have been binded and Step() or ExecuteWriteQuery() is
/// called.
std::unique_ptr<SqliteStatement> Prepare(const string& sql);
SqliteStatement Prepare(const string& sql);
private:
explicit Sqlite(sqlite3* db);
@@ -89,21 +90,34 @@ class Sqlite {
/// Instances of this class are not thread safe.
class SqliteStatement {
public:
/// \brief Destroys object and finalizes statement if needed.
~SqliteStatement();
/// \brief Constructs empty statement that should be assigned later.
SqliteStatement() : stmt_(nullptr), error_(SQLITE_OK) {}
/// \brief Empties object and finalizes statement if needed.
~SqliteStatement() { CloseOrLog(); }
/// \brief Move constructor, after which <other> should not be used.
SqliteStatement(SqliteStatement&& other);
/// \brief Move assignment, after which <other> should not be used.
SqliteStatement& operator=(SqliteStatement&& other);
/// \brief Returns true if statement is not empty.
operator bool() const { return stmt_ != nullptr; }
/// \brief Returns SQLite result code state.
///
/// This will be SQLITE_OK unless an error happened. If multiple
/// errors happened, only the first error code will be returned.
int error() { return error_; }
int error() const { return error_; }
/// \brief Returns error() as a tensorflow::Status.
Status status() { return Sqlite::MakeStatus(error_); }
Status status() const;
/// \brief Finalize statement object.
///
/// Please note that the destructor can also do this.
/// Please note that the destructor can also do this. This method is
/// a no-op if already closed.
Status Close();
/// \brief Executes query and/or fetches next row.
@@ -247,7 +261,12 @@ class SqliteStatement {
private:
friend Sqlite;
SqliteStatement(sqlite3_stmt* stmt, int error); // takes ownership
SqliteStatement(sqlite3_stmt* stmt, int error,
std::unique_ptr<string> prepare_error_sql)
: stmt_(stmt),
error_(error),
prepare_error_sql_(std::move(prepare_error_sql)) {}
void CloseOrLog();
void Update(int rc) {
if (TF_PREDICT_FALSE(rc != SQLITE_OK)) {
@@ -268,11 +287,31 @@ class SqliteStatement {
sqlite3_stmt* stmt_;
int error_;
std::unique_ptr<string> prepare_error_sql_;
TF_DISALLOW_COPY_AND_ASSIGN(SqliteStatement);
};
} // namespace db
inline SqliteStatement::SqliteStatement(SqliteStatement&& other)
: stmt_(other.stmt_),
error_(other.error_),
prepare_error_sql_(std::move(other.prepare_error_sql_)) {
other.stmt_ = nullptr;
other.error_ = SQLITE_OK;
}
inline SqliteStatement& SqliteStatement::operator=(SqliteStatement&& other) {
if (&other != this) {
CloseOrLog();
stmt_ = other.stmt_;
error_ = other.error_;
prepare_error_sql_ = std::move(other.prepare_error_sql_);
other.stmt_ = nullptr;
other.error_ = SQLITE_OK;
}
return *this;
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_LIB_DB_SQLITE_H_

View File

@@ -24,97 +24,96 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace db {
namespace {
class SqliteTest : public ::testing::Test {
protected:
void SetUp() override {
TF_ASSERT_OK(Sqlite::Open(":memory:", &db_));
db_ = Sqlite::Open(":memory:").ValueOrDie();
auto stmt = db_->Prepare("CREATE TABLE T (a BLOB, b BLOB)");
TF_ASSERT_OK(stmt->StepAndReset());
TF_ASSERT_OK(stmt.StepAndReset());
}
std::unique_ptr<Sqlite> db_;
std::shared_ptr<Sqlite> db_;
bool is_done_;
};
TEST_F(SqliteTest, InsertAndSelectInt) {
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindInt(1, 3);
stmt->BindInt(2, -7);
TF_ASSERT_OK(stmt->StepAndReset());
stmt->BindInt(1, 123);
stmt->BindInt(2, -123);
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindInt(1, 3);
stmt.BindInt(2, -7);
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 123);
stmt.BindInt(2, -123);
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T ORDER BY b");
TF_ASSERT_OK(stmt->Step(&is_done_));
TF_ASSERT_OK(stmt.Step(&is_done_));
ASSERT_FALSE(is_done_);
EXPECT_EQ(123, stmt->ColumnInt(0));
EXPECT_EQ(-123, stmt->ColumnInt(1));
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(123, stmt.ColumnInt(0));
EXPECT_EQ(-123, stmt.ColumnInt(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
ASSERT_FALSE(is_done_);
EXPECT_EQ(3, stmt->ColumnInt(0));
EXPECT_EQ(-7, stmt->ColumnInt(1));
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(3, stmt.ColumnInt(0));
EXPECT_EQ(-7, stmt.ColumnInt(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
ASSERT_TRUE(is_done_);
}
TEST_F(SqliteTest, InsertAndSelectDouble) {
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindDouble(1, 6.28318530);
stmt->BindDouble(2, 1.61803399);
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindDouble(1, 6.28318530);
stmt.BindDouble(2, 1.61803399);
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(6.28318530, stmt->ColumnDouble(0));
EXPECT_EQ(1.61803399, stmt->ColumnDouble(1));
EXPECT_EQ(6, stmt->ColumnInt(0));
EXPECT_EQ(1, stmt->ColumnInt(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(6.28318530, stmt.ColumnDouble(0));
EXPECT_EQ(1.61803399, stmt.ColumnDouble(1));
EXPECT_EQ(6, stmt.ColumnInt(0));
EXPECT_EQ(1, stmt.ColumnInt(1));
}
TEST_F(SqliteTest, NulCharsInString) {
string s; // XXX: Want to write {2, '\0'} but not sure why not.
s.append(static_cast<size_t>(2), '\0');
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindBlob(1, s);
stmt->BindText(2, s);
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindBlob(1, s);
stmt.BindText(2, s);
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(2, stmt->ColumnSize(0));
EXPECT_EQ(2, stmt->ColumnString(0).size());
EXPECT_EQ('\0', stmt->ColumnString(0).at(0));
EXPECT_EQ('\0', stmt->ColumnString(0).at(1));
EXPECT_EQ(2, stmt->ColumnSize(1));
EXPECT_EQ(2, stmt->ColumnString(1).size());
EXPECT_EQ('\0', stmt->ColumnString(1).at(0));
EXPECT_EQ('\0', stmt->ColumnString(1).at(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(2, stmt.ColumnSize(0));
EXPECT_EQ(2, stmt.ColumnString(0).size());
EXPECT_EQ('\0', stmt.ColumnString(0).at(0));
EXPECT_EQ('\0', stmt.ColumnString(0).at(1));
EXPECT_EQ(2, stmt.ColumnSize(1));
EXPECT_EQ(2, stmt.ColumnString(1).size());
EXPECT_EQ('\0', stmt.ColumnString(1).at(0));
EXPECT_EQ('\0', stmt.ColumnString(1).at(1));
}
TEST_F(SqliteTest, Unicode) {
string s = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非";
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindBlob(1, s);
stmt->BindText(2, s);
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindBlob(1, s);
stmt.BindText(2, s);
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(s, stmt->ColumnString(0));
EXPECT_EQ(s, stmt->ColumnString(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(s, stmt.ColumnString(0));
EXPECT_EQ(s, stmt.ColumnString(1));
}
TEST_F(SqliteTest, StepAndResetClearsBindings) {
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindInt(1, 1);
stmt->BindInt(2, 123);
TF_ASSERT_OK(stmt->StepAndReset());
stmt->BindInt(1, 2);
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindInt(1, 1);
stmt.BindInt(2, 123);
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 2);
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(123, stmt->ColumnInt(0));
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(SQLITE_NULL, stmt->ColumnType(0));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(123, stmt.ColumnInt(0));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(SQLITE_NULL, stmt.ColumnType(0));
}
TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
@@ -128,71 +127,109 @@ TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
// is designed to carry the first error state forward to Step().
TEST_F(SqliteTest, ErrorPuntingDoesNotReportLibraryAbuse) {
auto stmt = db_->Prepare("lol cat");
EXPECT_FALSE(stmt->status().ok());
EXPECT_EQ(SQLITE_ERROR, stmt->error());
stmt->BindInt(1, 1);
stmt->BindInt(2, 2);
Status s = stmt->Step(&is_done_);
EXPECT_EQ(SQLITE_ERROR, stmt->error()); // first error of several
EXPECT_FALSE(stmt.status().ok());
EXPECT_EQ(SQLITE_ERROR, stmt.error());
stmt.BindInt(1, 1);
stmt.BindInt(2, 2);
Status s = stmt.Step(&is_done_);
EXPECT_EQ(SQLITE_ERROR, stmt.error()); // first error of several
EXPECT_FALSE(s.ok());
}
TEST_F(SqliteTest, SafeBind) {
string s = "hello";
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindBlob(1, s);
stmt->BindText(2, s);
stmt.BindBlob(1, s);
stmt.BindText(2, s);
s.at(0) = 'y';
TF_ASSERT_OK(stmt->StepAndReset());
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ("hello", stmt->ColumnString(0));
EXPECT_EQ("hello", stmt->ColumnString(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ("hello", stmt.ColumnString(0));
EXPECT_EQ("hello", stmt.ColumnString(1));
}
TEST_F(SqliteTest, UnsafeBind) {
string s = "hello";
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindBlobUnsafe(1, s);
stmt->BindTextUnsafe(2, s);
stmt.BindBlobUnsafe(1, s);
stmt.BindTextUnsafe(2, s);
s.at(0) = 'y';
TF_ASSERT_OK(stmt->StepAndReset());
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT a, b FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ("yello", stmt->ColumnString(0));
EXPECT_EQ("yello", stmt->ColumnString(1));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ("yello", stmt.ColumnString(0));
EXPECT_EQ("yello", stmt.ColumnString(1));
}
TEST_F(SqliteTest, UnsafeColumn) {
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
stmt->BindInt(1, 1);
stmt->BindText(2, "hello");
TF_ASSERT_OK(stmt->StepAndReset());
stmt->BindInt(1, 2);
stmt->BindText(2, "there");
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindInt(1, 1);
stmt.BindText(2, "hello");
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 2);
stmt.BindText(2, "there");
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
TF_ASSERT_OK(stmt->Step(&is_done_));
const char* p = stmt->ColumnStringUnsafe(0);
TF_ASSERT_OK(stmt.Step(&is_done_));
const char* p = stmt.ColumnStringUnsafe(0);
EXPECT_EQ('h', *p);
TF_ASSERT_OK(stmt->Step(&is_done_));
TF_ASSERT_OK(stmt.Step(&is_done_));
// This will actually happen, but it's not safe to test this behavior.
// EXPECT_EQ('t', *p);
}
TEST_F(SqliteTest, NamedParameterBind) {
auto stmt = db_->Prepare("INSERT INTO T (a) VALUES (:a)");
stmt->BindText(":a", "lol");
TF_ASSERT_OK(stmt->StepAndReset());
stmt.BindText(":a", "lol");
TF_ASSERT_OK(stmt.StepAndReset());
stmt = db_->Prepare("SELECT COUNT(*) FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
EXPECT_EQ(1, stmt->ColumnInt(0));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(1, stmt.ColumnInt(0));
stmt = db_->Prepare("SELECT a FROM T");
TF_ASSERT_OK(stmt->Step(&is_done_));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_FALSE(is_done_);
EXPECT_EQ("lol", stmt->ColumnString(0));
EXPECT_EQ("lol", stmt.ColumnString(0));
}
TEST_F(SqliteTest, Statement_DefaultConstructor) {
SqliteStatement stmt;
EXPECT_FALSE(stmt);
EXPECT_FALSE(stmt.StepAndReset().ok());
stmt = db_->Prepare("INSERT INTO T (a) VALUES (1)");
EXPECT_TRUE(stmt);
EXPECT_TRUE(stmt.StepAndReset().ok());
}
TEST_F(SqliteTest, Statement_MoveConstructor) {
SqliteStatement stmt{db_->Prepare("INSERT INTO T (a) VALUES (1)")};
EXPECT_TRUE(stmt.StepAndReset().ok());
}
TEST_F(SqliteTest, Statement_MoveAssignment) {
SqliteStatement stmt1 = db_->Prepare("INSERT INTO T (a) VALUES (1)");
SqliteStatement stmt2;
EXPECT_TRUE(stmt1.StepAndReset().ok());
EXPECT_FALSE(stmt2.StepAndReset().ok());
stmt2 = std::move(stmt1);
EXPECT_TRUE(stmt2.StepAndReset().ok());
}
TEST_F(SqliteTest, PrepareFailed) {
SqliteStatement s = db_->Prepare("SELECT");
EXPECT_FALSE(s.status().ok());
EXPECT_NE(string::npos, s.status().error_message().find("SELECT"));
}
TEST_F(SqliteTest, BindFailed) {
SqliteStatement s = db_->Prepare("INSERT INTO T (a) VALUES (123)");
EXPECT_TRUE(s.status().ok());
EXPECT_EQ("", s.status().error_message());
s.BindInt(1, 123);
EXPECT_FALSE(s.status().ok());
EXPECT_NE(string::npos,
s.status().error_message().find("INSERT INTO T (a) VALUES (123)"));
}
} // namespace
} // namespace db
} // namespace tensorflow