This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 96a04ae47 fix(c/driver/postgresql): Fix ingest of streams with zero arrays (#2073)
96a04ae47 is described below
commit 96a04ae47e0158b298ff533ddf3748a5de435042
Author: Dewey Dunnington <[email protected]>
AuthorDate: Mon Aug 12 20:24:50 2024 -0300
fix(c/driver/postgresql): Fix ingest of streams with zero arrays (#2073)
I think this didn't work because we only ever wrote COPY output to the
connection after each array, so if there were no arrays, there was no
header sent!
Closes #2071.
After this PR:
``` r
library(adbcdrivermanager)
#> Warning: package 'adbcdrivermanager' was built under R version 4.3.3
con <- adbc_database_init(
adbcpostgresql::adbcpostgresql(),
uri =
"postgresql://localhost:5432/postgres?user=postgres&password=password"
) |>
adbc_connection_init()
con |>
execute_adbc("DROP TABLE IF EXISTS no_integers")
nanoarrow::basic_array_stream(
list(),
nanoarrow::na_struct(list(x = nanoarrow::na_int32()))
) |>
write_adbc(con, "no_integers")
con |>
read_adbc("SELECT * from no_integers") |>
tibble::as_tibble()
#> # A tibble: 0 × 1
#> # ℹ 1 variable: x <int>
```
<sup>Created on 2024-08-11 with [reprex v2.1.0](https://reprex.tidyverse.org)</sup>
---
c/driver/postgresql/statement.cc | 51 ++++++++++++++++++-------------
c/driver/sqlite/statement_reader.c | 15 +++++++--
c/validation/adbc_validation.h | 3 ++
c/validation/adbc_validation_statement.cc | 46 ++++++++++++++++++++++++++++
4 files changed, 91 insertions(+), 24 deletions(-)
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 0fa8a79b9..c6e012581 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -551,12 +551,6 @@ struct BindStream {
AdbcStatusCode ExecuteCopy(const PostgresConnection* conn, int64_t* rows_affected,
struct AdbcError* error) {
- // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max
- // size for a single message that we need to respect (1 GiB - 1). Since
- // the buffer can be chunked up as much as we want, go for 16 MiB as our
- // limit.
- //
-    // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28
- constexpr int64_t kMaxCopyBufferSize = 0x1000000;
if (rows_affected) *rows_affected = 0;
const auto pg_conn = conn->conn();
@@ -592,26 +586,15 @@ struct BindStream {
return ADBC_STATUS_IO;
}
- ArrowBuffer buffer = writer.WriteBuffer();
- {
- auto* data = reinterpret_cast<char*>(buffer.data);
- int64_t remaining = buffer.size_bytes;
- while (remaining > 0) {
- int64_t to_write = std::min<int64_t>(remaining, kMaxCopyBufferSize);
- if (PQputCopyData(pg_conn, data, to_write) <= 0) {
- SetError(error, "Error writing tuple field data: %s",
- PQerrorMessage(pg_conn));
- return ADBC_STATUS_IO;
- }
- remaining -= to_write;
- data += to_write;
- }
- }
+ RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error));
if (rows_affected) *rows_affected += array->length;
writer.Rewind();
}
+ // If there were no arrays in the stream, we haven't flushed yet
+ RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error));
+
if (PQputCopyEnd(pg_conn, NULL) <= 0) {
SetError(error, "Error message returned by PQputCopyEnd: %s",
PQerrorMessage(pg_conn));
@@ -631,6 +614,32 @@ struct BindStream {
PQclear(result);
return ADBC_STATUS_OK;
}
+
+ AdbcStatusCode FlushCopyWriterToConn(PGconn* pg_conn,
+ const PostgresCopyStreamWriter& writer,
+ struct AdbcError* error) {
+ // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max
+ // size for a single message that we need to respect (1 GiB - 1). Since
+ // the buffer can be chunked up as much as we want, go for 16 MiB as our
+ // limit.
+ //
+    // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28
+ constexpr int64_t kMaxCopyBufferSize = 0x1000000;
+ ArrowBuffer buffer = writer.WriteBuffer();
+
+ auto* data = reinterpret_cast<char*>(buffer.data);
+ int64_t remaining = buffer.size_bytes;
+ while (remaining > 0) {
+ int64_t to_write = std::min<int64_t>(remaining, kMaxCopyBufferSize);
+ if (PQputCopyData(pg_conn, data, to_write) <= 0) {
+        SetError(error, "Error writing tuple field data: %s",
+                 PQerrorMessage(pg_conn));
+ return ADBC_STATUS_IO;
+ }
+ remaining -= to_write;
+ data += to_write;
+ }
+
+ return ADBC_STATUS_OK;
+ }
};
} // namespace
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index a832e7bfd..dc036a963 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -629,9 +629,18 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out
struct StatementReader* reader = (struct StatementReader*)self->private_data;
if (reader->initial_batch.release != NULL) {
- memcpy(out, &reader->initial_batch, sizeof(*out));
- memset(&reader->initial_batch, 0, sizeof(reader->initial_batch));
- return 0;
+ // Canonically return zero-row results as a stream with zero batches
+ if (reader->initial_batch.length == 0) {
+ reader->initial_batch.release(&reader->initial_batch);
+ reader->done = true;
+
+ out->release = NULL;
+ return 0;
+ } else {
+ memcpy(out, &reader->initial_batch, sizeof(*out));
+ memset(&reader->initial_batch, 0, sizeof(reader->initial_batch));
+ return 0;
+ }
} else if (reader->done) {
out->release = NULL;
return 0;
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 7535d2070..ab665ac10 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -377,6 +377,8 @@ class StatementTest {
// Dictionary-encoded
void TestSqlIngestStringDictionary();
+ void TestSqlIngestStreamZeroArrays();
+
// ---- End Type-specific tests ----------------
void TestSqlIngestTableEscaping();
@@ -478,6 +480,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }
\
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }
\
TEST_F(FIXTURE, SqlIngestStringDictionary) {
TestSqlIngestStringDictionary(); } \
+  TEST_F(FIXTURE, TestSqlIngestStreamZeroArrays) { TestSqlIngestStreamZeroArrays(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); }
\
TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); }
\
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }
\
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 06b379272..431620594 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -491,6 +491,52 @@ void StatementTest::TestSqlIngestStringDictionary() {
/*dictionary_encode*/
true));
}
+void StatementTest::TestSqlIngestStreamZeroArrays() {
+ if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
+ IsOkStatus(&error));
+
+ Handle<struct ArrowSchema> schema;
+  ASSERT_THAT(MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT32}}), IsOkErrno());
+
+ Handle<struct ArrowArrayStream> bind;
+ nanoarrow::EmptyArrayStream(&schema.value).ToArrayStream(&bind.value);
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+ "bulk_ingest", &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBindStream(&statement, &bind.value, &error),
+ IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+ IsOkStatus(&error));
+
+ ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "SELECT * FROM \"bulk_ingest\"", &error),
+ IsOkStatus(&error));
+
+ {
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(reader.rows_affected,
+ ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(NANOARROW_TYPE_INT32);
+ ASSERT_NO_FATAL_FAILURE(
+        CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(nullptr, reader.array->release);
+ }
+}
+
void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";