This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 96a04ae47 fix(c/driver/postgresql): Fix ingest of streams with zero arrays (#2073)
96a04ae47 is described below

commit 96a04ae47e0158b298ff533ddf3748a5de435042
Author: Dewey Dunnington <[email protected]>
AuthorDate: Mon Aug 12 20:24:50 2024 -0300

    fix(c/driver/postgresql): Fix ingest of streams with zero arrays (#2073)
    
    This didn't work because COPY output was only ever written to the
    connection after each array, so if the stream contained no arrays, the
    COPY header was never sent!
    
    Closes #2071.
    
    After this PR:
    
    ``` r
    library(adbcdrivermanager)
    #> Warning: package 'adbcdrivermanager' was built under R version 4.3.3
    
    con <- adbc_database_init(
      adbcpostgresql::adbcpostgresql(),
      uri = "postgresql://localhost:5432/postgres?user=postgres&password=password"
    ) |>
      adbc_connection_init()
    
    con |>
      execute_adbc("DROP TABLE IF EXISTS no_integers")
    
    nanoarrow::basic_array_stream(
      list(),
      nanoarrow::na_struct(list(x = nanoarrow::na_int32()))
    ) |>
      write_adbc(con, "no_integers")
    
    con |>
      read_adbc("SELECT * from no_integers") |>
      tibble::as_tibble()
    #> # A tibble: 0 × 1
    #> # ℹ 1 variable: x <int>
    ```
    
    <sup>Created on 2024-08-11 with [reprex
    v2.1.0](https://reprex.tidyverse.org)</sup>
---
 c/driver/postgresql/statement.cc          | 51 ++++++++++++++++++-------------
 c/driver/sqlite/statement_reader.c        | 15 +++++++--
 c/validation/adbc_validation.h            |  3 ++
 c/validation/adbc_validation_statement.cc | 46 ++++++++++++++++++++++++++++
 4 files changed, 91 insertions(+), 24 deletions(-)

diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 0fa8a79b9..c6e012581 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -551,12 +551,6 @@ struct BindStream {
 
   AdbcStatusCode ExecuteCopy(const PostgresConnection* conn, int64_t* rows_affected,
                              struct AdbcError* error) {
-    // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max
-    // size for a single message that we need to respect (1 GiB - 1).  Since
-    // the buffer can be chunked up as much as we want, go for 16 MiB as our
-    // limit.
-    // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28
-    constexpr int64_t kMaxCopyBufferSize = 0x1000000;
     if (rows_affected) *rows_affected = 0;
     const auto pg_conn = conn->conn();
 
@@ -592,26 +586,15 @@ struct BindStream {
         return ADBC_STATUS_IO;
       }
 
-      ArrowBuffer buffer = writer.WriteBuffer();
-      {
-        auto* data = reinterpret_cast<char*>(buffer.data);
-        int64_t remaining = buffer.size_bytes;
-        while (remaining > 0) {
-          int64_t to_write = std::min<int64_t>(remaining, kMaxCopyBufferSize);
-          if (PQputCopyData(pg_conn, data, to_write) <= 0) {
-            SetError(error, "Error writing tuple field data: %s",
-                     PQerrorMessage(pg_conn));
-            return ADBC_STATUS_IO;
-          }
-          remaining -= to_write;
-          data += to_write;
-        }
-      }
+      RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error));
 
       if (rows_affected) *rows_affected += array->length;
       writer.Rewind();
     }
 
+    // If there were no arrays in the stream, we haven't flushed yet
+    RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error));
+
     if (PQputCopyEnd(pg_conn, NULL) <= 0) {
       SetError(error, "Error message returned by PQputCopyEnd: %s",
                PQerrorMessage(pg_conn));
@@ -631,6 +614,32 @@ struct BindStream {
     PQclear(result);
     return ADBC_STATUS_OK;
   }
+
+  AdbcStatusCode FlushCopyWriterToConn(PGconn* pg_conn,
+                                       const PostgresCopyStreamWriter& writer,
+                                       struct AdbcError* error) {
+    // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max
+    // size for a single message that we need to respect (1 GiB - 1).  Since
+    // the buffer can be chunked up as much as we want, go for 16 MiB as our
+    // limit.
+    // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28
+    constexpr int64_t kMaxCopyBufferSize = 0x1000000;
+    ArrowBuffer buffer = writer.WriteBuffer();
+
+    auto* data = reinterpret_cast<char*>(buffer.data);
+    int64_t remaining = buffer.size_bytes;
+    while (remaining > 0) {
+      int64_t to_write = std::min<int64_t>(remaining, kMaxCopyBufferSize);
+      if (PQputCopyData(pg_conn, data, to_write) <= 0) {
+        SetError(error, "Error writing tuple field data: %s", PQerrorMessage(pg_conn));
+        return ADBC_STATUS_IO;
+      }
+      remaining -= to_write;
+      data += to_write;
+    }
+
+    return ADBC_STATUS_OK;
+  }
 };
 }  // namespace
 
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index a832e7bfd..dc036a963 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -629,9 +629,18 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out
 
   struct StatementReader* reader = (struct StatementReader*)self->private_data;
   if (reader->initial_batch.release != NULL) {
-    memcpy(out, &reader->initial_batch, sizeof(*out));
-    memset(&reader->initial_batch, 0, sizeof(reader->initial_batch));
-    return 0;
+    // Canonically return zero-row results as a stream with zero batches
+    if (reader->initial_batch.length == 0) {
+      reader->initial_batch.release(&reader->initial_batch);
+      reader->done = true;
+
+      out->release = NULL;
+      return 0;
+    } else {
+      memcpy(out, &reader->initial_batch, sizeof(*out));
+      memset(&reader->initial_batch, 0, sizeof(reader->initial_batch));
+      return 0;
+    }
   } else if (reader->done) {
     out->release = NULL;
     return 0;
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 7535d2070..ab665ac10 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -377,6 +377,8 @@ class StatementTest {
   // Dictionary-encoded
   void TestSqlIngestStringDictionary();
 
+  void TestSqlIngestStreamZeroArrays();
+
   // ---- End Type-specific tests ----------------
 
   void TestSqlIngestTableEscaping();
@@ -478,6 +480,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }                 \
   TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }                       \
   TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); }       \
+  TEST_F(FIXTURE, TestSqlIngestStreamZeroArrays) { TestSqlIngestStreamZeroArrays(); }   \
   TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); }             \
   TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); }           \
   TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }                           \
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 06b379272..431620594 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -491,6 +491,52 @@ void StatementTest::TestSqlIngestStringDictionary() {
                                                          /*dictionary_encode*/ true));
 }
 
+void StatementTest::TestSqlIngestStreamZeroArrays() {
+  if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
+    GTEST_SKIP();
+  }
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
+              IsOkStatus(&error));
+
+  Handle<struct ArrowSchema> schema;
+  ASSERT_THAT(MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT32}}), IsOkErrno());
+
+  Handle<struct ArrowArrayStream> bind;
+  nanoarrow::EmptyArrayStream(&schema.value).ToArrayStream(&bind.value);
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                     "bulk_ingest", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementBindStream(&statement, &bind.value, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "SELECT * FROM \"bulk_ingest\"", &error),
+      IsOkStatus(&error));
+
+  {
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(reader.rows_affected,
+                ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(NANOARROW_TYPE_INT32);
+    ASSERT_NO_FATAL_FAILURE(
+        CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));
+
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(nullptr, reader.array->release);
+  }
+}
+
 void StatementTest::TestSqlIngestTableEscaping() {
   std::string name = "create_table_escaping";
 

Reply via email to