mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Improve C++ SQLite veneer
- Use shared_ptr for Sqlite - Don't need unique_ptr on SqliteStatement - Don't need db namespace - Include SQL in error statuses PiperOrigin-RevId: 173802267
This commit is contained in:
committed by
TensorFlower Gardener
parent
0eba15fe63
commit
325c8e5efa
@@ -191,6 +191,10 @@ file(GLOB_RECURSE tf_core_lib_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/*.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/*.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/public/*.h"
|
||||
# TODO(@jart): Move StatusOr into core.
|
||||
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h"
|
||||
"${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h"
|
||||
)
|
||||
|
||||
file(GLOB tf_core_platform_srcs
|
||||
|
||||
@@ -15,13 +15,11 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/tensorboard/db/schema.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
namespace {
|
||||
|
||||
class SqliteSchema {
|
||||
public:
|
||||
explicit SqliteSchema(Sqlite* db) : db_(db) {}
|
||||
~SqliteSchema() { db_ = nullptr; }
|
||||
explicit SqliteSchema(std::shared_ptr<Sqlite> db) : db_(std::move(db)) {}
|
||||
|
||||
/// \brief Creates Tensors table.
|
||||
///
|
||||
@@ -371,18 +369,18 @@ class SqliteSchema {
|
||||
|
||||
Status Run(const char* sql) {
|
||||
auto stmt = db_->Prepare(sql);
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt->StepAndReset(), sql);
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
Sqlite* db_;
|
||||
std::shared_ptr<Sqlite> db_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SetupTensorboardSqliteDb(Sqlite* db) {
|
||||
SqliteSchema s(db);
|
||||
Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) {
|
||||
SqliteSchema s(std::move(db));
|
||||
TF_RETURN_IF_ERROR(s.CreateTensorsTable());
|
||||
TF_RETURN_IF_ERROR(s.CreateTensorChunksTable());
|
||||
TF_RETURN_IF_ERROR(s.CreateTagsTable());
|
||||
@@ -408,5 +406,4 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace db
|
||||
} // namespace tensorflow
|
||||
|
||||
@@ -15,19 +15,19 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_
|
||||
#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/db/sqlite.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
|
||||
/// \brief Creates TensorBoard SQLite tables and indexes.
|
||||
///
|
||||
/// If they are already created, this has no effect. If schema
|
||||
/// migrations are necessary, they will be performed with logging.
|
||||
Status SetupTensorboardSqliteDb(Sqlite* db);
|
||||
Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db);
|
||||
|
||||
} // namespace db
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_
|
||||
|
||||
@@ -20,15 +20,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
namespace {
|
||||
|
||||
TEST(SchemaTest, SmokeTestTensorboardSchema) {
|
||||
std::unique_ptr<Sqlite> db;
|
||||
TF_ASSERT_OK(Sqlite::Open(":memory:", &db));
|
||||
TF_ASSERT_OK(SetupTensorboardSqliteDb(db.get()));
|
||||
auto db = Sqlite::Open(":memory:").ValueOrDie();
|
||||
TF_ASSERT_OK(SetupTensorboardSqliteDb(db));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace db
|
||||
} // namespace tensorflow
|
||||
|
||||
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/sql/sqlite_query_connection.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@@ -29,17 +30,18 @@ Status SqliteQueryConnection::Open(const string& data_source_name,
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to open query connection: Connection already opeend.");
|
||||
}
|
||||
Status s = db::Sqlite::Open(data_source_name, &db_);
|
||||
auto s = Sqlite::Open(data_source_name);
|
||||
if (s.ok()) {
|
||||
db_ = std::move(s.ValueOrDie());
|
||||
query_ = query;
|
||||
output_types_ = output_types;
|
||||
}
|
||||
return s;
|
||||
return s.status();
|
||||
}
|
||||
|
||||
Status SqliteQueryConnection::Close() {
|
||||
Status s;
|
||||
s.Update(stmt_->Close());
|
||||
s.Update(stmt_.Close());
|
||||
s.Update(db_->Close());
|
||||
return s;
|
||||
}
|
||||
@@ -52,7 +54,7 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
|
||||
return s;
|
||||
}
|
||||
}
|
||||
Status s = stmt_->Step(end_of_sequence);
|
||||
Status s = stmt_.Step(end_of_sequence);
|
||||
if (!*end_of_sequence) {
|
||||
for (int i = 0; i < column_count_; i++) {
|
||||
DataType dt = output_types_[i];
|
||||
@@ -66,9 +68,9 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
|
||||
|
||||
Status SqliteQueryConnection::PrepareQuery() {
|
||||
stmt_ = db_->Prepare(query_);
|
||||
Status s = stmt_->status();
|
||||
Status s = stmt_.status();
|
||||
if (s.ok()) {
|
||||
int column_count = stmt_->ColumnCount();
|
||||
int column_count = stmt_.ColumnCount();
|
||||
if (column_count != output_types_.size()) {
|
||||
return errors::InvalidArgument(tensorflow::strings::Printf(
|
||||
"The number of columns in query (%d) must match the number of "
|
||||
@@ -84,40 +86,40 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
|
||||
const DataType& data_type, int column_index, Tensor* tensor) {
|
||||
switch (data_type) {
|
||||
case DT_STRING:
|
||||
tensor->scalar<string>()() = stmt_->ColumnString(column_index);
|
||||
tensor->scalar<string>()() = stmt_.ColumnString(column_index);
|
||||
break;
|
||||
case DT_INT8:
|
||||
tensor->scalar<int8>()() =
|
||||
static_cast<int8>(stmt_->ColumnInt(column_index));
|
||||
static_cast<int8>(stmt_.ColumnInt(column_index));
|
||||
break;
|
||||
case DT_INT16:
|
||||
tensor->scalar<int16>()() =
|
||||
static_cast<int16>(stmt_->ColumnInt(column_index));
|
||||
static_cast<int16>(stmt_.ColumnInt(column_index));
|
||||
break;
|
||||
case DT_INT32:
|
||||
tensor->scalar<int32>()() =
|
||||
static_cast<int32>(stmt_->ColumnInt(column_index));
|
||||
static_cast<int32>(stmt_.ColumnInt(column_index));
|
||||
break;
|
||||
case DT_INT64:
|
||||
tensor->scalar<int64>()() = stmt_->ColumnInt(column_index);
|
||||
tensor->scalar<int64>()() = stmt_.ColumnInt(column_index);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
tensor->scalar<uint8>()() =
|
||||
static_cast<uint8>(stmt_->ColumnInt(column_index));
|
||||
static_cast<uint8>(stmt_.ColumnInt(column_index));
|
||||
break;
|
||||
case DT_UINT16:
|
||||
tensor->scalar<uint16>()() =
|
||||
static_cast<uint16>(stmt_->ColumnInt(column_index));
|
||||
static_cast<uint16>(stmt_.ColumnInt(column_index));
|
||||
break;
|
||||
case DT_BOOL:
|
||||
tensor->scalar<bool>()() = stmt_->ColumnInt(column_index) != 0;
|
||||
tensor->scalar<bool>()() = stmt_.ColumnInt(column_index) != 0;
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
tensor->scalar<float>()() =
|
||||
static_cast<float>(stmt_->ColumnDouble(column_index));
|
||||
static_cast<float>(stmt_.ColumnDouble(column_index));
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
tensor->scalar<double>()() = stmt_->ColumnDouble(column_index);
|
||||
tensor->scalar<double>()() = stmt_.ColumnDouble(column_index);
|
||||
break;
|
||||
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
||||
default: {
|
||||
|
||||
@@ -42,8 +42,8 @@ class SqliteQueryConnection : public QueryConnection {
|
||||
// `stmt_`.
|
||||
void FillTensorWithResultSetEntry(const DataType& data_type, int column_index,
|
||||
Tensor* tensor);
|
||||
std::unique_ptr<db::Sqlite> db_ = nullptr;
|
||||
std::unique_ptr<db::SqliteStatement> stmt_ = nullptr;
|
||||
std::shared_ptr<Sqlite> db_ = nullptr;
|
||||
SqliteStatement stmt_;
|
||||
int column_count_ = 0;
|
||||
string query_;
|
||||
DataTypeVector output_types_;
|
||||
|
||||
@@ -12,6 +12,7 @@ cc_library(
|
||||
srcs = ["sqlite.cc"],
|
||||
hdrs = ["sqlite.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:lib",
|
||||
"@sqlite_archive//:sqlite",
|
||||
],
|
||||
|
||||
@@ -18,14 +18,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
|
||||
/* static */
|
||||
Status Sqlite::Open(const string& uri, std::unique_ptr<Sqlite>* db) {
|
||||
xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(const string& uri) {
|
||||
sqlite3* sqlite = nullptr;
|
||||
Status s = MakeStatus(sqlite3_open(uri.c_str(), &sqlite));
|
||||
if (s.ok()) {
|
||||
*db = std::unique_ptr<Sqlite>(new Sqlite(sqlite));
|
||||
return std::shared_ptr<Sqlite>(new Sqlite(sqlite));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
@@ -87,6 +86,9 @@ Sqlite::~Sqlite() {
|
||||
}
|
||||
|
||||
Status Sqlite::Close() {
|
||||
if (db_ == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
// If Close is explicitly called, ordering must be correct.
|
||||
Status s = MakeStatus(sqlite3_close(db_));
|
||||
if (s.ok()) {
|
||||
@@ -95,23 +97,42 @@ Status Sqlite::Close() {
|
||||
return s;
|
||||
}
|
||||
|
||||
std::unique_ptr<SqliteStatement> Sqlite::Prepare(const string& sql) {
|
||||
SqliteStatement Sqlite::Prepare(const string& sql) {
|
||||
sqlite3_stmt* stmt = nullptr;
|
||||
int rc = sqlite3_prepare_v2(db_, sql.c_str(), sql.size() + 1, &stmt, nullptr);
|
||||
return std::unique_ptr<SqliteStatement>(new SqliteStatement(stmt, rc));
|
||||
if (rc == SQLITE_OK) {
|
||||
return {stmt, SQLITE_OK, std::unique_ptr<string>(nullptr)};
|
||||
} else {
|
||||
return {nullptr, rc, std::unique_ptr<string>(new string(sql))};
|
||||
}
|
||||
}
|
||||
|
||||
SqliteStatement::SqliteStatement(sqlite3_stmt* stmt, int error)
|
||||
: stmt_(stmt), error_(error) {}
|
||||
Status SqliteStatement::status() const {
|
||||
Status s = Sqlite::MakeStatus(error_);
|
||||
if (!s.ok()) {
|
||||
if (stmt_ != nullptr) {
|
||||
errors::AppendToMessage(&s, sqlite3_sql(stmt_));
|
||||
} else {
|
||||
errors::AppendToMessage(&s, *prepare_error_sql_);
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
SqliteStatement::~SqliteStatement() {
|
||||
int rc = sqlite3_finalize(stmt_);
|
||||
if (rc != SQLITE_OK) {
|
||||
LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
|
||||
void SqliteStatement::CloseOrLog() {
|
||||
if (stmt_ != nullptr) {
|
||||
int rc = sqlite3_finalize(stmt_);
|
||||
if (rc != SQLITE_OK) {
|
||||
LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
|
||||
}
|
||||
stmt_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Status SqliteStatement::Close() {
|
||||
if (stmt_ == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
int rc = sqlite3_finalize(stmt_);
|
||||
if (rc == SQLITE_OK) {
|
||||
stmt_ = nullptr;
|
||||
@@ -121,8 +142,10 @@ Status SqliteStatement::Close() {
|
||||
}
|
||||
|
||||
void SqliteStatement::Reset() {
|
||||
sqlite3_reset(stmt_);
|
||||
sqlite3_clear_bindings(stmt_);
|
||||
if (TF_PREDICT_TRUE(stmt_ != nullptr)) {
|
||||
sqlite3_reset(stmt_);
|
||||
sqlite3_clear_bindings(stmt_); // not nullptr friendly
|
||||
}
|
||||
error_ = SQLITE_OK;
|
||||
}
|
||||
|
||||
@@ -163,5 +186,4 @@ Status SqliteStatement::StepAndReset() {
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace db
|
||||
} // namespace tensorflow
|
||||
|
||||
@@ -17,15 +17,16 @@ limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "sqlite3.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
|
||||
class SqliteStatement;
|
||||
|
||||
@@ -46,7 +47,7 @@ class Sqlite {
|
||||
/// `file::memory:` for testing.
|
||||
///
|
||||
/// See https://sqlite.org/c3ref/open.html
|
||||
static Status Open(const string& uri, std::unique_ptr<Sqlite>* db);
|
||||
static xla::StatusOr<std::shared_ptr<Sqlite>> Open(const string& uri);
|
||||
|
||||
/// \brief Makes tensorflow::Status for SQLite result code.
|
||||
///
|
||||
@@ -65,7 +66,7 @@ class Sqlite {
|
||||
/// \brief Frees underlying SQLite object.
|
||||
///
|
||||
/// Unlike the destructor, all SqliteStatement objects must be closed
|
||||
/// beforehand.
|
||||
/// beforehand. This is a no-op if already closed
|
||||
Status Close();
|
||||
|
||||
/// \brief Creates SQLite statement.
|
||||
@@ -74,7 +75,7 @@ class Sqlite {
|
||||
/// failed. It is also possible to punt the error checking to after
|
||||
/// the values have been binded and Step() or ExecuteWriteQuery() is
|
||||
/// called.
|
||||
std::unique_ptr<SqliteStatement> Prepare(const string& sql);
|
||||
SqliteStatement Prepare(const string& sql);
|
||||
|
||||
private:
|
||||
explicit Sqlite(sqlite3* db);
|
||||
@@ -89,21 +90,34 @@ class Sqlite {
|
||||
/// Instances of this class are not thread safe.
|
||||
class SqliteStatement {
|
||||
public:
|
||||
/// \brief Destroys object and finalizes statement if needed.
|
||||
~SqliteStatement();
|
||||
/// \brief Constructs empty statement that should be assigned later.
|
||||
SqliteStatement() : stmt_(nullptr), error_(SQLITE_OK) {}
|
||||
|
||||
/// \brief Empties object and finalizes statement if needed.
|
||||
~SqliteStatement() { CloseOrLog(); }
|
||||
|
||||
/// \brief Move constructor, after which <other> should not be used.
|
||||
SqliteStatement(SqliteStatement&& other);
|
||||
|
||||
/// \brief Move assignment, after which <other> should not be used.
|
||||
SqliteStatement& operator=(SqliteStatement&& other);
|
||||
|
||||
/// \brief Returns true if statement is not empty.
|
||||
operator bool() const { return stmt_ != nullptr; }
|
||||
|
||||
/// \brief Returns SQLite result code state.
|
||||
///
|
||||
/// This will be SQLITE_OK unless an error happened. If multiple
|
||||
/// errors happened, only the first error code will be returned.
|
||||
int error() { return error_; }
|
||||
int error() const { return error_; }
|
||||
|
||||
/// \brief Returns error() as a tensorflow::Status.
|
||||
Status status() { return Sqlite::MakeStatus(error_); }
|
||||
Status status() const;
|
||||
|
||||
/// \brief Finalize statement object.
|
||||
///
|
||||
/// Please note that the destructor can also do this.
|
||||
/// Please note that the destructor can also do this. This method is
|
||||
/// a no-op if already closed.
|
||||
Status Close();
|
||||
|
||||
/// \brief Executes query and/or fetches next row.
|
||||
@@ -247,7 +261,12 @@ class SqliteStatement {
|
||||
|
||||
private:
|
||||
friend Sqlite;
|
||||
SqliteStatement(sqlite3_stmt* stmt, int error); // takes ownership
|
||||
SqliteStatement(sqlite3_stmt* stmt, int error,
|
||||
std::unique_ptr<string> prepare_error_sql)
|
||||
: stmt_(stmt),
|
||||
error_(error),
|
||||
prepare_error_sql_(std::move(prepare_error_sql)) {}
|
||||
void CloseOrLog();
|
||||
|
||||
void Update(int rc) {
|
||||
if (TF_PREDICT_FALSE(rc != SQLITE_OK)) {
|
||||
@@ -268,11 +287,31 @@ class SqliteStatement {
|
||||
|
||||
sqlite3_stmt* stmt_;
|
||||
int error_;
|
||||
std::unique_ptr<string> prepare_error_sql_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SqliteStatement);
|
||||
};
|
||||
|
||||
} // namespace db
|
||||
inline SqliteStatement::SqliteStatement(SqliteStatement&& other)
|
||||
: stmt_(other.stmt_),
|
||||
error_(other.error_),
|
||||
prepare_error_sql_(std::move(other.prepare_error_sql_)) {
|
||||
other.stmt_ = nullptr;
|
||||
other.error_ = SQLITE_OK;
|
||||
}
|
||||
|
||||
inline SqliteStatement& SqliteStatement::operator=(SqliteStatement&& other) {
|
||||
if (&other != this) {
|
||||
CloseOrLog();
|
||||
stmt_ = other.stmt_;
|
||||
error_ = other.error_;
|
||||
prepare_error_sql_ = std::move(other.prepare_error_sql_);
|
||||
other.stmt_ = nullptr;
|
||||
other.error_ = SQLITE_OK;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_LIB_DB_SQLITE_H_
|
||||
|
||||
@@ -24,97 +24,96 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace db {
|
||||
namespace {
|
||||
|
||||
class SqliteTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_ASSERT_OK(Sqlite::Open(":memory:", &db_));
|
||||
db_ = Sqlite::Open(":memory:").ValueOrDie();
|
||||
auto stmt = db_->Prepare("CREATE TABLE T (a BLOB, b BLOB)");
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
}
|
||||
std::unique_ptr<Sqlite> db_;
|
||||
std::shared_ptr<Sqlite> db_;
|
||||
bool is_done_;
|
||||
};
|
||||
|
||||
TEST_F(SqliteTest, InsertAndSelectInt) {
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindInt(1, 3);
|
||||
stmt->BindInt(2, -7);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt->BindInt(1, 123);
|
||||
stmt->BindInt(2, -123);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindInt(1, 3);
|
||||
stmt.BindInt(2, -7);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt.BindInt(1, 123);
|
||||
stmt.BindInt(2, -123);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T ORDER BY b");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
ASSERT_FALSE(is_done_);
|
||||
EXPECT_EQ(123, stmt->ColumnInt(0));
|
||||
EXPECT_EQ(-123, stmt->ColumnInt(1));
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(123, stmt.ColumnInt(0));
|
||||
EXPECT_EQ(-123, stmt.ColumnInt(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
ASSERT_FALSE(is_done_);
|
||||
EXPECT_EQ(3, stmt->ColumnInt(0));
|
||||
EXPECT_EQ(-7, stmt->ColumnInt(1));
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(3, stmt.ColumnInt(0));
|
||||
EXPECT_EQ(-7, stmt.ColumnInt(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
ASSERT_TRUE(is_done_);
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, InsertAndSelectDouble) {
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindDouble(1, 6.28318530);
|
||||
stmt->BindDouble(2, 1.61803399);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindDouble(1, 6.28318530);
|
||||
stmt.BindDouble(2, 1.61803399);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(6.28318530, stmt->ColumnDouble(0));
|
||||
EXPECT_EQ(1.61803399, stmt->ColumnDouble(1));
|
||||
EXPECT_EQ(6, stmt->ColumnInt(0));
|
||||
EXPECT_EQ(1, stmt->ColumnInt(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(6.28318530, stmt.ColumnDouble(0));
|
||||
EXPECT_EQ(1.61803399, stmt.ColumnDouble(1));
|
||||
EXPECT_EQ(6, stmt.ColumnInt(0));
|
||||
EXPECT_EQ(1, stmt.ColumnInt(1));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, NulCharsInString) {
|
||||
string s; // XXX: Want to write {2, '\0'} but not sure why not.
|
||||
s.append(static_cast<size_t>(2), '\0');
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindBlob(1, s);
|
||||
stmt->BindText(2, s);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindBlob(1, s);
|
||||
stmt.BindText(2, s);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(2, stmt->ColumnSize(0));
|
||||
EXPECT_EQ(2, stmt->ColumnString(0).size());
|
||||
EXPECT_EQ('\0', stmt->ColumnString(0).at(0));
|
||||
EXPECT_EQ('\0', stmt->ColumnString(0).at(1));
|
||||
EXPECT_EQ(2, stmt->ColumnSize(1));
|
||||
EXPECT_EQ(2, stmt->ColumnString(1).size());
|
||||
EXPECT_EQ('\0', stmt->ColumnString(1).at(0));
|
||||
EXPECT_EQ('\0', stmt->ColumnString(1).at(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(2, stmt.ColumnSize(0));
|
||||
EXPECT_EQ(2, stmt.ColumnString(0).size());
|
||||
EXPECT_EQ('\0', stmt.ColumnString(0).at(0));
|
||||
EXPECT_EQ('\0', stmt.ColumnString(0).at(1));
|
||||
EXPECT_EQ(2, stmt.ColumnSize(1));
|
||||
EXPECT_EQ(2, stmt.ColumnString(1).size());
|
||||
EXPECT_EQ('\0', stmt.ColumnString(1).at(0));
|
||||
EXPECT_EQ('\0', stmt.ColumnString(1).at(1));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, Unicode) {
|
||||
string s = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非";
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindBlob(1, s);
|
||||
stmt->BindText(2, s);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindBlob(1, s);
|
||||
stmt.BindText(2, s);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(s, stmt->ColumnString(0));
|
||||
EXPECT_EQ(s, stmt->ColumnString(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(s, stmt.ColumnString(0));
|
||||
EXPECT_EQ(s, stmt.ColumnString(1));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, StepAndResetClearsBindings) {
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindInt(1, 1);
|
||||
stmt->BindInt(2, 123);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt->BindInt(1, 2);
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindInt(1, 1);
|
||||
stmt.BindInt(2, 123);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt.BindInt(1, 2);
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(123, stmt->ColumnInt(0));
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(SQLITE_NULL, stmt->ColumnType(0));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(123, stmt.ColumnInt(0));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(SQLITE_NULL, stmt.ColumnType(0));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
|
||||
@@ -128,71 +127,109 @@ TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
|
||||
// is designed to carry the first error state forward to Step().
|
||||
TEST_F(SqliteTest, ErrorPuntingDoesNotReportLibraryAbuse) {
|
||||
auto stmt = db_->Prepare("lol cat");
|
||||
EXPECT_FALSE(stmt->status().ok());
|
||||
EXPECT_EQ(SQLITE_ERROR, stmt->error());
|
||||
stmt->BindInt(1, 1);
|
||||
stmt->BindInt(2, 2);
|
||||
Status s = stmt->Step(&is_done_);
|
||||
EXPECT_EQ(SQLITE_ERROR, stmt->error()); // first error of several
|
||||
EXPECT_FALSE(stmt.status().ok());
|
||||
EXPECT_EQ(SQLITE_ERROR, stmt.error());
|
||||
stmt.BindInt(1, 1);
|
||||
stmt.BindInt(2, 2);
|
||||
Status s = stmt.Step(&is_done_);
|
||||
EXPECT_EQ(SQLITE_ERROR, stmt.error()); // first error of several
|
||||
EXPECT_FALSE(s.ok());
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, SafeBind) {
|
||||
string s = "hello";
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindBlob(1, s);
|
||||
stmt->BindText(2, s);
|
||||
stmt.BindBlob(1, s);
|
||||
stmt.BindText(2, s);
|
||||
s.at(0) = 'y';
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ("hello", stmt->ColumnString(0));
|
||||
EXPECT_EQ("hello", stmt->ColumnString(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ("hello", stmt.ColumnString(0));
|
||||
EXPECT_EQ("hello", stmt.ColumnString(1));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, UnsafeBind) {
|
||||
string s = "hello";
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindBlobUnsafe(1, s);
|
||||
stmt->BindTextUnsafe(2, s);
|
||||
stmt.BindBlobUnsafe(1, s);
|
||||
stmt.BindTextUnsafe(2, s);
|
||||
s.at(0) = 'y';
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT a, b FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ("yello", stmt->ColumnString(0));
|
||||
EXPECT_EQ("yello", stmt->ColumnString(1));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ("yello", stmt.ColumnString(0));
|
||||
EXPECT_EQ("yello", stmt.ColumnString(1));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, UnsafeColumn) {
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
|
||||
stmt->BindInt(1, 1);
|
||||
stmt->BindText(2, "hello");
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt->BindInt(1, 2);
|
||||
stmt->BindText(2, "there");
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindInt(1, 1);
|
||||
stmt.BindText(2, "hello");
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt.BindInt(1, 2);
|
||||
stmt.BindText(2, "there");
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
const char* p = stmt->ColumnStringUnsafe(0);
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
const char* p = stmt.ColumnStringUnsafe(0);
|
||||
EXPECT_EQ('h', *p);
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
// This will actually happen, but it's not safe to test this behavior.
|
||||
// EXPECT_EQ('t', *p);
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, NamedParameterBind) {
|
||||
auto stmt = db_->Prepare("INSERT INTO T (a) VALUES (:a)");
|
||||
stmt->BindText(":a", "lol");
|
||||
TF_ASSERT_OK(stmt->StepAndReset());
|
||||
stmt.BindText(":a", "lol");
|
||||
TF_ASSERT_OK(stmt.StepAndReset());
|
||||
stmt = db_->Prepare("SELECT COUNT(*) FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
EXPECT_EQ(1, stmt->ColumnInt(0));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_EQ(1, stmt.ColumnInt(0));
|
||||
stmt = db_->Prepare("SELECT a FROM T");
|
||||
TF_ASSERT_OK(stmt->Step(&is_done_));
|
||||
TF_ASSERT_OK(stmt.Step(&is_done_));
|
||||
EXPECT_FALSE(is_done_);
|
||||
EXPECT_EQ("lol", stmt->ColumnString(0));
|
||||
EXPECT_EQ("lol", stmt.ColumnString(0));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, Statement_DefaultConstructor) {
|
||||
SqliteStatement stmt;
|
||||
EXPECT_FALSE(stmt);
|
||||
EXPECT_FALSE(stmt.StepAndReset().ok());
|
||||
stmt = db_->Prepare("INSERT INTO T (a) VALUES (1)");
|
||||
EXPECT_TRUE(stmt);
|
||||
EXPECT_TRUE(stmt.StepAndReset().ok());
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, Statement_MoveConstructor) {
|
||||
SqliteStatement stmt{db_->Prepare("INSERT INTO T (a) VALUES (1)")};
|
||||
EXPECT_TRUE(stmt.StepAndReset().ok());
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, Statement_MoveAssignment) {
|
||||
SqliteStatement stmt1 = db_->Prepare("INSERT INTO T (a) VALUES (1)");
|
||||
SqliteStatement stmt2;
|
||||
EXPECT_TRUE(stmt1.StepAndReset().ok());
|
||||
EXPECT_FALSE(stmt2.StepAndReset().ok());
|
||||
stmt2 = std::move(stmt1);
|
||||
EXPECT_TRUE(stmt2.StepAndReset().ok());
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, PrepareFailed) {
|
||||
SqliteStatement s = db_->Prepare("SELECT");
|
||||
EXPECT_FALSE(s.status().ok());
|
||||
EXPECT_NE(string::npos, s.status().error_message().find("SELECT"));
|
||||
}
|
||||
|
||||
TEST_F(SqliteTest, BindFailed) {
|
||||
SqliteStatement s = db_->Prepare("INSERT INTO T (a) VALUES (123)");
|
||||
EXPECT_TRUE(s.status().ok());
|
||||
EXPECT_EQ("", s.status().error_message());
|
||||
s.BindInt(1, 123);
|
||||
EXPECT_FALSE(s.status().ok());
|
||||
EXPECT_NE(string::npos,
|
||||
s.status().error_message().find("INSERT INTO T (a) VALUES (123)"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace db
|
||||
} // namespace tensorflow
|
||||
|
||||
Reference in New Issue
Block a user