This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9822129 test(c/driver): improve coverage of types in select/ingest (#559)
9822129 is described below
commit 98221290c780edb119833a8df8cf1329347f9cd4
Author: David Li <[email protected]>
AuthorDate: Mon Mar 27 21:16:10 2023 -0400
test(c/driver): improve coverage of types in select/ingest (#559)
Fixes #197.
---
c/driver/flightsql/sqlite_flightsql_test.cc | 4 ++
c/driver/postgresql/postgresql_test.cc | 12 ++++
c/driver/sqlite/sqlite_test.cc | 22 ++++++
c/driver_manager/adbc_driver_manager_test.cc | 22 ++++++
c/validation/adbc_validation.cc | 101 +++++++++++++++++++++++----
c/validation/adbc_validation.h | 52 +++++++++++++-
c/validation/adbc_validation_util.h | 38 +++++++++-
7 files changed, 232 insertions(+), 19 deletions(-)
diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index 71d4e99..2bf0441 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -38,6 +38,10 @@ class SqliteFlightSqlQuirks : public
adbc_validation::DriverQuirks {
AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,
struct AdbcError* error) const override {
const char* uri = std::getenv("ADBC_SQLITE_FLIGHTSQL_URI");
+ if (!uri || std::strlen(uri) == 0) {
+ ADD_FAILURE() << "Must set ADBC_SQLITE_FLIGHTSQL_URI";
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri, error),
IsOkStatus(error));
return ADBC_STATUS_OK;
}
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index ba118f0..26d6f6d 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -107,6 +107,18 @@ class PostgresStatementTest : public ::testing::Test,
void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
+ void TestSqlIngestInt8() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestInt16() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestInt32() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestUInt8() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestFloat32() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestFloat64() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestString() { GTEST_SKIP() << "TODO(apache/arrow-adbc#557)"; }
+ void TestSqlIngestBinary() { GTEST_SKIP() << "Not implemented"; }
+
void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareSelectParams() { GTEST_SKIP() << "Not yet implemented"; }
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 3af4a8b..7e7d840 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -45,6 +45,25 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
std::string BindParameter(int index) const override { return "?"; }
+ ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
+ switch (ingest_type) {
+ case NANOARROW_TYPE_INT8:
+ case NANOARROW_TYPE_INT16:
+ case NANOARROW_TYPE_INT32:
+ case NANOARROW_TYPE_INT64:
+ case NANOARROW_TYPE_UINT8:
+ case NANOARROW_TYPE_UINT16:
+ case NANOARROW_TYPE_UINT32:
+ case NANOARROW_TYPE_UINT64:
+ return NANOARROW_TYPE_INT64;
+ case NANOARROW_TYPE_FLOAT:
+ case NANOARROW_TYPE_DOUBLE:
+ return NANOARROW_TYPE_DOUBLE;
+ default:
+ return ingest_type;
+ }
+ }
+
bool supports_concurrent_statements() const override { return true; }
};
@@ -78,6 +97,9 @@ class SqliteStatementTest : public ::testing::Test,
void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
+ void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; }
+ void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
+
protected:
SqliteQuirks quirks_;
};
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 82cc82f..99fa477 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -172,6 +172,25 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
std::string BindParameter(int index) const override { return "?"; }
+ ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
+ switch (ingest_type) {
+ case NANOARROW_TYPE_INT8:
+ case NANOARROW_TYPE_INT16:
+ case NANOARROW_TYPE_INT32:
+ case NANOARROW_TYPE_INT64:
+ case NANOARROW_TYPE_UINT8:
+ case NANOARROW_TYPE_UINT16:
+ case NANOARROW_TYPE_UINT32:
+ case NANOARROW_TYPE_UINT64:
+ return NANOARROW_TYPE_INT64;
+ case NANOARROW_TYPE_FLOAT:
+ case NANOARROW_TYPE_DOUBLE:
+ return NANOARROW_TYPE_DOUBLE;
+ default:
+ return ingest_type;
+ }
+ }
+
bool supports_concurrent_statements() const override { return true; }
};
@@ -205,6 +224,9 @@ class SqliteStatementTest : public ::testing::Test,
void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
+ void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; }
+ void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
+
protected:
SqliteQuirks quirks_;
};
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index c99f5cc..ad1bf3d 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -19,7 +19,7 @@
#include <cerrno>
#include <cstring>
-#include <iostream>
+#include <limits>
#include <optional>
#include <string>
#include <string_view>
@@ -867,7 +867,9 @@ void StatementTest::TestRelease() {
ASSERT_EQ(NULL, statement.private_data);
}
-void StatementTest::TestSqlIngestInts() {
+template <typename CType>
+void StatementTest::TestSqlIngestType(ArrowType type,
+ const std::vector<std::optional<CType>>&
values) {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
}
@@ -878,10 +880,9 @@ void StatementTest::TestSqlIngestInts() {
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
- ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}),
IsOkErrno());
- ASSERT_THAT(
- MakeBatch<int64_t>(&schema.value, &array.value, &na_error, {42, -42,
std::nullopt}),
- IsOkErrno());
+ ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno());
+ ASSERT_THAT(MakeBatch<CType>(&schema.value, &array.value, &na_error, values),
+ IsOkErrno());
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement,
ADBC_INGEST_OPTION_TARGET_TABLE,
@@ -893,7 +894,8 @@ void StatementTest::TestSqlIngestInts() {
int64_t rows_affected = 0;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected,
&error),
IsOkStatus(&error));
- ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(3),
::testing::Eq(-1)));
+ ASSERT_THAT(rows_affected,
+ ::testing::AnyOf(::testing::Eq(values.size()),
::testing::Eq(-1)));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error),
IsOkStatus(&error));
@@ -903,25 +905,98 @@ void StatementTest::TestSqlIngestInts() {
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
- ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));
+ ::testing::AnyOf(::testing::Eq(values.size()),
::testing::Eq(-1)));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
- ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value,
- {{"int64s", NANOARROW_TYPE_INT64,
NULLABLE}}));
+ ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
+ ASSERT_NO_FATAL_FAILURE(
+ CompareSchema(&reader.schema.value, {{"col", round_trip_type,
NULLABLE}}));
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
- ASSERT_EQ(3, reader.array->length);
+ ASSERT_EQ(values.size(), reader.array->length);
ASSERT_EQ(1, reader.array->n_children);
- ASSERT_NO_FATAL_FAILURE(
- CompareArray<int64_t>(reader.array_view->children[0], {42, -42,
std::nullopt}));
+ if (round_trip_type == type) {
+ // XXX: for now we can't compare values; we would need casting
+ ASSERT_NO_FATAL_FAILURE(
+ CompareArray<CType>(reader.array_view->children[0], values));
+ }
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(nullptr, reader.array->release);
}
}
+template <typename CType>
+void StatementTest::TestSqlIngestNumericType(ArrowType type) {
+ std::vector<std::optional<CType>> values = {
+ std::nullopt,
+ };
+
+ if constexpr (std::is_floating_point_v<CType>) {
+ // XXX: sqlite and others seem to have trouble with extreme
+ // values. Likely a bug on our side, but for now, avoid them.
+ values.push_back(-1.0);
+ values.push_back(1.0);
+ } else {
+ values.push_back(std::numeric_limits<CType>::lowest());
+ values.push_back(std::numeric_limits<CType>::max());
+ }
+
+ return TestSqlIngestType(type, values);
+}
+
+void StatementTest::TestSqlIngestUInt8() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint8_t>(NANOARROW_TYPE_UINT8));
+}
+
+void StatementTest::TestSqlIngestUInt16() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint16_t>(NANOARROW_TYPE_UINT16));
+}
+
+void StatementTest::TestSqlIngestUInt32() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint32_t>(NANOARROW_TYPE_UINT32));
+}
+
+void StatementTest::TestSqlIngestUInt64() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint64_t>(NANOARROW_TYPE_UINT64));
+}
+
+void StatementTest::TestSqlIngestInt8() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int8_t>(NANOARROW_TYPE_INT8));
+}
+
+void StatementTest::TestSqlIngestInt16() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int16_t>(NANOARROW_TYPE_INT16));
+}
+
+void StatementTest::TestSqlIngestInt32() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int32_t>(NANOARROW_TYPE_INT32));
+}
+
+void StatementTest::TestSqlIngestInt64() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int64_t>(NANOARROW_TYPE_INT64));
+}
+
+void StatementTest::TestSqlIngestFloat32() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<float>(NANOARROW_TYPE_FLOAT));
+}
+
+void StatementTest::TestSqlIngestFloat64() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<double>(NANOARROW_TYPE_DOUBLE));
+}
+
+void StatementTest::TestSqlIngestString() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
+ NANOARROW_TYPE_STRING, {std::nullopt, "", "1234", "", "例"}));
+}
+
+void StatementTest::TestSqlIngestBinary() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
+ NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "",
"\xFE\xFF"}));
+}
+
void StatementTest::TestSqlIngestAppend() {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 5cba030..03d6bf8 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -18,10 +18,13 @@
#ifndef ADBC_VALIDATION_H
#define ADBC_VALIDATION_H
+#include <optional>
#include <string>
+#include <vector>
#include <adbc.h>
#include <gtest/gtest.h>
+#include <nanoarrow/nanoarrow.h>
namespace adbc_validation {
@@ -60,6 +63,12 @@ class DriverQuirks {
/// \brief Return the SQL to reference the bind parameter of the given index
virtual std::string BindParameter(int index) const { return "?"; }
+ /// \brief For a given Arrow type of ingested data, what Arrow type
+ /// will the database return when that column is selected?
+ virtual ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const {
+ return ingest_type;
+ }
+
/// \brief Whether two statements can be used at the same time on a
/// single connection
virtual bool supports_concurrent_statements() const { return false; }
@@ -168,8 +177,28 @@ class StatementTest {
void TestNewInit();
void TestRelease();
- // TODO: these should be parameterized tests
- void TestSqlIngestInts();
+ // ---- Type-specific tests --------------------
+
+ // Integers
+ void TestSqlIngestInt8();
+ void TestSqlIngestInt16();
+ void TestSqlIngestInt32();
+ void TestSqlIngestInt64();
+ void TestSqlIngestUInt8();
+ void TestSqlIngestUInt16();
+ void TestSqlIngestUInt32();
+ void TestSqlIngestUInt64();
+
+ // Floats
+ void TestSqlIngestFloat32();
+ void TestSqlIngestFloat64();
+
+ // Strings
+ void TestSqlIngestString();
+ void TestSqlIngestBinary();
+
+ // ---- End Type-specific tests ----------------
+
void TestSqlIngestAppend();
void TestSqlIngestErrors();
void TestSqlIngestMultipleConnections();
@@ -201,6 +230,12 @@ class StatementTest {
struct AdbcDatabase database;
struct AdbcConnection connection;
struct AdbcStatement statement;
+
+ template <typename CType>
+ void TestSqlIngestType(ArrowType type, const
std::vector<std::optional<CType>>& values);
+
+ template <typename CType>
+ void TestSqlIngestNumericType(ArrowType type);
};
#define ADBCV_TEST_STATEMENT(FIXTURE)
\
@@ -208,7 +243,18 @@ class StatementTest {
ADBCV_STRINGIFY(FIXTURE) " must inherit from StatementTest");
\
TEST_F(FIXTURE, NewInit) { TestNewInit(); }
\
TEST_F(FIXTURE, Release) { TestRelease(); }
\
- TEST_F(FIXTURE, SqlIngestInts) { TestSqlIngestInts(); }
\
+ TEST_F(FIXTURE, SqlIngestInt8) { TestSqlIngestInt8(); }
\
+ TEST_F(FIXTURE, SqlIngestInt16) { TestSqlIngestInt16(); }
\
+ TEST_F(FIXTURE, SqlIngestInt32) { TestSqlIngestInt32(); }
\
+ TEST_F(FIXTURE, SqlIngestInt64) { TestSqlIngestInt64(); }
\
+ TEST_F(FIXTURE, SqlIngestUInt8) { TestSqlIngestUInt8(); }
\
+ TEST_F(FIXTURE, SqlIngestUInt16) { TestSqlIngestUInt16(); }
\
+ TEST_F(FIXTURE, SqlIngestUInt32) { TestSqlIngestUInt32(); }
\
+ TEST_F(FIXTURE, SqlIngestUInt64) { TestSqlIngestUInt64(); }
\
+ TEST_F(FIXTURE, SqlIngestFloat32) { TestSqlIngestFloat32(); }
\
+ TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); }
\
+ TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); }
\
+ TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); }
\
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }
\
TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); }
\
TEST_F(FIXTURE, SqlIngestMultipleConnections) {
TestSqlIngestMultipleConnections(); } \
diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h
index 9c23ab3..66d1a05 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -224,11 +224,22 @@ int MakeArray(struct ArrowArray* parent, struct
ArrowArray* array,
const std::vector<std::optional<T>>& values) {
for (const auto& v : values) {
if (v.has_value()) {
- if constexpr (std::is_same<T, int64_t>::value) {
+ if constexpr (std::is_same<T, int8_t>::value || std::is_same<T,
int16_t>::value ||
+ std::is_same<T, int32_t>::value || std::is_same<T,
int64_t>::value) {
if (int errno_res = ArrowArrayAppendInt(array, *v); errno_res != 0) {
return errno_res;
}
- } else if constexpr (std::is_same<T, double>::value) {
+ // XXX: cpplint gets weird here and thinks this is an unbraced if
+ } else if constexpr (std::is_same<T, // NOLINT(readability/braces)
+ uint8_t>::value ||
+ std::is_same<T, uint16_t>::value ||
+ std::is_same<T, uint32_t>::value ||
+ std::is_same<T, uint64_t>::value) {
+ if (int errno_res = ArrowArrayAppendUInt(array, *v); errno_res != 0) {
+ return errno_res;
+ }
+ } else if constexpr (std::is_same<T, float>::value || //
NOLINT(readability/braces)
+ std::is_same<T, double>::value) {
if (int errno_res = ArrowArrayAppendDouble(array, *v); errno_res != 0)
{
return errno_res;
}
@@ -313,18 +324,39 @@ void CompareArray(struct ArrowArrayView* array,
SCOPED_TRACE("Array index " + std::to_string(i));
if (v.has_value()) {
ASSERT_FALSE(ArrowArrayViewIsNull(array, i));
- if constexpr (std::is_same<T, double>::value) {
+ if constexpr (std::is_same<T, float>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
+ } else if constexpr (std::is_same<T, double>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_double[i]);
} else if constexpr (std::is_same<T, float>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
+ } else if constexpr (std::is_same<T, int8_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]);
+ } else if constexpr (std::is_same<T, int16_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_int16[i]);
} else if constexpr (std::is_same<T, int32_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int32[i]);
} else if constexpr (std::is_same<T, int64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int64[i]);
+ } else if constexpr (std::is_same<T, uint8_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_uint8[i]);
+ } else if constexpr (std::is_same<T, uint16_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_uint16[i]);
+ } else if constexpr (std::is_same<T, uint32_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_uint32[i]);
+ } else if constexpr (std::is_same<T, uint64_t>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, array->buffer_views[1].data.as_uint64[i]);
} else if constexpr (std::is_same<T, std::string>::value) {
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
std::string str(view.data, view.size_bytes);