This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 96f92463 feat(c/driver/postgresql): Implement streaming/chunked output (#870)
96f92463 is described below

commit 96f92463301ddfd278bb7e57c4df58d0cc2f21a4
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Jul 7 11:08:19 2023 -0300

    feat(c/driver/postgresql): Implement streaming/chunked output (#870)
    
    I'm not sure if this is exactly what's required, and I'm happy to implement
    it differently. The existing driver has two limitations: (1) it errors
    if a single column contains more than 2 GB of text in total, and (2) the
    array stream's get_next() blocks until the entire result has been
    fetched. In R, for example, you can run something like
    `read_adbc() |> arrow::as_record_batch_reader() |> arrow::write_dataset()`
    for bigger-than-memory queries; this behaviour is meant to support
    that.
    
    Some open questions:
    
    - Default chunk size? I chose 16 MB... maybe it should be bigger or
    smaller? I prefer a size in bytes over a number of rows because it makes
    no assumptions about how big or small the rows are. When querying a
    PostGIS table, for example, polygon features can be several MB each.
    - How should the chunk size be configured? Should there be a canonical
    statement option for this, or should the PostgreSQL driver define its own?
    - How to test? (Easier if it's configurable, since we can set the limit
    to an obscenely low value and drip rows one at a time.)
    
    Reprex in R, with the default chunk size set to 1024 bytes:
    
    ``` r
    library(adbcdrivermanager)
    
    uri <- Sys.getenv("ADBC_POSTGRESQL_TEST_URI")
    db <- adbc_database_init(adbcpostgresql::adbcpostgresql(), uri = uri)
    con <- adbc_connection_init(db)
    
    rdr <- con |>
      read_adbc("SELECT * from flights")
    
    tibble::as_tibble(rdr$get_next())
    #> # A tibble: 6 × 18
    #>    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
    #> 1  2013     1     1      517            515         2      830            819
    #> 2  2013     1     1      533            529         4      850            830
    #> 3  2013     1     1      542            540         2      923            850
    #> 4  2013     1     1      544            545        -1     1004           1022
    #> 5  2013     1     1      554            600        -6      812            837
    #> 6  2013     1     1      554            558        -4      740            728
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    tibble::as_tibble(rdr$get_next())
    #> # A tibble: 6 × 18
    #>    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
    #> 1  2013     1     1      555            600        -5      913            854
    #> 2  2013     1     1      557            600        -3      709            723
    #> 3  2013     1     1      557            600        -3      838            846
    #> 4  2013     1     1      558            600        -2      753            745
    #> 5  2013     1     1      558            600        -2      849            851
    #> 6  2013     1     1      558            600        -2      853            856
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    tibble::as_tibble(rdr$get_next())
    #> # A tibble: 6 × 18
    #>    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
    #> 1  2013     1     1      558            600        -2      924            917
    #> 2  2013     1     1      558            600        -2      923            937
    #> 3  2013     1     1      559            600        -1      941            910
    #> 4  2013     1     1      559            559         0      702            706
    #> 5  2013     1     1      559            600        -1      854            902
    #> 6  2013     1     1      600            600         0      851            858
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    ```
    
    <sup>Created on 2023-07-05 with [reprex
    v2.0.2](https://reprex.tidyverse.org)</sup>
---
 c/driver/postgresql/postgres_copy_reader.h |   7 ++
 c/driver/postgresql/postgresql_test.cc     |  44 +++++++
 c/driver/postgresql/statement.cc           | 182 ++++++++++++++++++++---------
 c/driver/postgresql/statement.h            |  20 +++-
 4 files changed, 198 insertions(+), 55 deletions(-)

diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 813fd50a..18d1fbd4 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -679,9 +679,12 @@ class PostgresCopyStreamReader {
     }
 
     root_reader_.Init(pg_type);
+    array_size_approx_bytes_ = 0;
     return NANOARROW_OK;
   }
 
+  int64_t array_size_approx_bytes() const { return array_size_approx_bytes_; }
+
   ArrowErrorCode SetOutputSchema(ArrowSchema* schema, ArrowError* error) {
     if (std::string(schema_->format) != "+s") {
       ArrowErrorSet(
@@ -776,9 +779,12 @@ class PostgresCopyStreamReader {
           ArrowArrayInitFromSchema(array_.get(), schema_.get(), error));
       NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(array_.get()));
       NANOARROW_RETURN_NOT_OK(root_reader_.InitArray(array_.get()));
+      array_size_approx_bytes_ = 0;
     }
 
+    const uint8_t* start = data->data.as_uint8;
     NANOARROW_RETURN_NOT_OK(root_reader_.Read(data, -1, array_.get(), error));
+    array_size_approx_bytes_ += (data->data.as_uint8 - start);
     return NANOARROW_OK;
   }
 
@@ -800,6 +806,7 @@ class PostgresCopyStreamReader {
   PostgresCopyFieldTupleReader root_reader_;
   nanoarrow::UniqueSchema schema_;
   nanoarrow::UniqueArray array_;
+  int64_t array_size_approx_bytes_ = 0;  // COPY bytes read into array_
 };

 }  // namespace adbcpq
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 159d2ab5..153d8eb2 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -638,6 +638,50 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
   }
 }
 
+TEST_F(PostgresStatementTest, BatchSizeHint) {
+  ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", 
&error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+
+  // Setting the batch size hint to a negative or non-integer value should fail
+  ASSERT_EQ(AdbcStatementSetOption(&statement, 
"adbc.postgresql.batch_size_hint_bytes",
+                                   "-1", nullptr),
+            ADBC_STATUS_INVALID_ARGUMENT);
+  ASSERT_EQ(AdbcStatementSetOption(&statement, 
"adbc.postgresql.batch_size_hint_bytes",
+                                   "not a valid number", nullptr),
+            ADBC_STATUS_INVALID_ARGUMENT);
+
+  // For this test, use a batch size of 1 byte to force every row to be its 
own batch
+  ASSERT_THAT(AdbcStatementSetOption(&statement, 
"adbc.postgresql.batch_size_hint_bytes",
+                                     "1", &error),
+              IsOkStatus(&error));
+
+  {
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(
+            &statement, "SELECT int64s from batch_size_hint_test ORDER BY 
int64s LIMIT 3",
+            &error),
+        IsOkStatus(&error));
+
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->length, 1);
+    ASSERT_EQ(ArrowArrayViewGetIntUnsafe(reader.array_view->children[0], 0), 
-42);
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->length, 1);
+    ASSERT_EQ(ArrowArrayViewGetIntUnsafe(reader.array_view->children[0], 0), 
42);
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->length, 1);
+    ASSERT_TRUE(ArrowArrayViewIsNull(reader.array_view->children[0], 0));
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+}
+
 struct TypeTestCase {
   std::string name;
   std::string sql_type;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 8bd80a3c..73141362 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -438,91 +438,156 @@ int TupleReader::GetSchema(struct ArrowSchema* out) {
   return na_res;
 }
 
-int TupleReader::GetNext(struct ArrowArray* out) {
-  if (!result_) {
-    out->release = nullptr;
-    return 0;
+int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) {
+  ResetQuery();
+
+  // Fetch + parse the header
+  int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
+  data_.size_bytes = get_copy_res;
+  data_.data.as_char = pgbuf_;
+
+  if (get_copy_res == -2) {
+    StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s",
+                        PQerrorMessage(conn_));
+    return EIO;
   }
 
-  // Clear the result, since the data is actually read from the connection
-  PQclear(result_);
-  result_ = nullptr;
+  int na_res = copy_reader_->ReadHeader(&data_, error);
+  if (na_res != NANOARROW_OK) {
+    StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", 
error->message);
+    return EIO;
+  }
 
-  // Clear the error builder
-  error_builder_.size = 0;
+  return NANOARROW_OK;
+}
 
-  struct ArrowError error;
-  error.message[0] = '\0';
-  struct ArrowBufferView data;
-  data.data.data = nullptr;
-  data.size_bytes = 0;
+int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) {
+  // Parse the result (the header AND the first row are included in the first
+  // call to PQgetCopyData())
+  int na_res = copy_reader_->ReadRecord(&data_, error);
+  if (na_res != NANOARROW_OK && na_res != ENODATA) {
+    StringBuilderAppend(&error_builder_,
+                        "[libpq] ReadRecord failed at row %" PRId64 ": %s", 
row_id_,
+                        error->message);
+    return na_res;
+  }
 
-  // Fetch + parse the header
+  row_id_++;
+
+  // Fetch + check
+  PQfreemem(pgbuf_);
+  pgbuf_ = nullptr;
   int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
+  data_.size_bytes = get_copy_res;
+  data_.data.as_char = pgbuf_;
+
   if (get_copy_res == -2) {
-    StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s",
+    StringBuilderAppend(&error_builder_,
+                        "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", 
row_id_,
                         PQerrorMessage(conn_));
     return EIO;
+  } else if (get_copy_res == -1) {
+    // Returned when COPY has finished successfully
+    return ENODATA;
+  } else if ((copy_reader_->array_size_approx_bytes() + get_copy_res) >=
+             batch_size_hint_bytes_) {
+    // Appending the next row will result in an array larger than requested.
+    // Return EOVERFLOW to force GetNext() to build the current result and 
return.
+    return EOVERFLOW;
+  } else {
+    return NANOARROW_OK;
+  }
+}
+
+int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) 
{
+  if (copy_reader_->array_size_approx_bytes() == 0) {
+    out->release = nullptr;
+    return NANOARROW_OK;
   }
 
-  data.size_bytes = get_copy_res;
-  data.data.as_char = pgbuf_;
-  int na_res = copy_reader_->ReadHeader(&data, &error);
+  int na_res = copy_reader_->GetArray(out, error);
   if (na_res != NANOARROW_OK) {
-    StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", 
error.message);
+    StringBuilderAppend(&error_builder_, "[libpq] Failed to build result 
array: %s",
+                        error->message);
     return na_res;
   }
 
-  int64_t row_id = 0;
-  do {
-    // Parse the result (the header AND the first row are included in the first
-    // call to PQgetCopyData())
-    na_res = copy_reader_->ReadRecord(&data, &error);
-    if (na_res != NANOARROW_OK && na_res != ENODATA) {
-      StringBuilderAppend(&error_builder_, "[libpq] ReadRecord failed at row 
%ld: %s",
-                          static_cast<long>(row_id),  // NOLINT(runtime/int)
-                          error.message);
-      return na_res;
-    }
+  return NANOARROW_OK;
+}
 
-    row_id++;
+void TupleReader::ResetQuery() {
+  // Clear result
+  if (result_) {
+    PQclear(result_);
+    result_ = nullptr;
+  }
 
-    // Fetch + check
+  // Reset result buffer
+  if (pgbuf_ != nullptr) {
     PQfreemem(pgbuf_);
     pgbuf_ = nullptr;
-    get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
-    if (get_copy_res == -2) {
-      StringBuilderAppend(&error_builder_, "[libpq] Fetch row %ld failed: %s",
-                          static_cast<long>(row_id),  // NOLINT(runtime/int)
-                          PQerrorMessage(conn_));
-      return EIO;
-    } else if (get_copy_res == -1) {
-      // Returned when COPY has finished
-      break;
-    }
+  }
+
+  // Clear the error builder
+  error_builder_.size = 0;
 
-    data.size_bytes = get_copy_res;
-    data.data.as_char = pgbuf_;
-  } while (true);
+  row_id_ = -1;
+}
 
-  na_res = copy_reader_->GetArray(out, &error);
-  if (na_res != NANOARROW_OK) {
-    StringBuilderAppend(&error_builder_, "[libpq] Failed to build result 
array: %s",
-                        error.message);
+int TupleReader::GetNext(struct ArrowArray* out) {
+  if (!copy_reader_) {
+    out->release = nullptr;
+    return 0;
+  }
+
+  struct ArrowError error;
+  error.message[0] = '\0';
+
+  if (row_id_ == -1) {
+    NANOARROW_RETURN_NOT_OK(InitQueryAndFetchFirst(&error));
+    row_id_++;
+  }
+
+  int na_res;
+  do {
+    na_res = AppendRowAndFetchNext(&error);
+    if (na_res == EOVERFLOW) {
+      // The result would be too big to return if we appended the row. When 
EOVERFLOW is
+      // returned, the copy reader leaves the output in a valid state. The 
data is left in
+      // pg_buf_/data_ and will attempt to be appended on the next call to 
GetNext()
+      return BuildOutput(out, &error);
+    }
+  } while (na_res == NANOARROW_OK);
+
+  if (na_res != ENODATA) {
     return na_res;
   }
 
+  // Finish the result properly and return the last result. Note that 
BuildOutput() may
+  // set tmp.release = nullptr if there were zero rows in the copy reader (can
+  // occur in an overflow scenario).
+  struct ArrowArray tmp;
+  NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error));
+
+  // Clear the copy reader to mark this reader as finished
+  copy_reader_.reset();
+
   // Check the server-side response
   result_ = PQgetResult(conn_);
   const int pq_status = PQresultStatus(result_);
   if (pq_status != PGRES_COMMAND_OK) {
     StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", 
pq_status,
                         PQresultErrorMessage(result_));
+
+    if (tmp.release != nullptr) {
+      tmp.release(&tmp);
+    }
+
     return EIO;
   }
 
-  PQclear(result_);
-  result_ = nullptr;
+  ResetQuery();
+  ArrowArrayMove(&tmp, out);
   return NANOARROW_OK;
 }
 
@@ -533,6 +598,7 @@ void TupleReader::Release() {
     PQclear(result_);
     result_ = nullptr;
   }
+
   if (pgbuf_) {
     PQfreemem(pgbuf_);
     pgbuf_ = nullptr;
@@ -934,11 +1000,19 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
     } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) {
       ingest_.append = true;
     } else {
-      SetError(error, "%s%s%s%s", "[libpq] Invalid value ", value, " for option ", key);
+      SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+  } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
+    int64_t int_value = std::atoll(value);  // non-numeric -> 0 -> rejected
+    if (int_value <= 0) {
+      SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
+
+    this->reader_.batch_size_hint_bytes_ = int_value;
   } else {
-    SetError(error, "%s%s", "[libq] Unknown statement option ", key);
+    SetError(error, "[libpq] Unknown statement option '%s'", key);
     return ADBC_STATUS_NOT_IMPLEMENTED;
   }
   return ADBC_STATUS_OK;
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 0ff6cb89..62af2457 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -30,6 +30,9 @@
 #include "postgres_copy_reader.h"
 #include "postgres_type.h"
 
+#define ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES \
+  "adbc.postgresql.batch_size_hint_bytes"
+
 namespace adbcpq {
 class PostgresConnection;
 class PostgresStatement;
@@ -38,8 +41,15 @@ class PostgresStatement;
 class TupleReader final {
  public:
   TupleReader(PGconn* conn)
-      : conn_(conn), result_(nullptr), pgbuf_(nullptr), copy_reader_(nullptr) {
+      : conn_(conn),
+        result_(nullptr),
+        pgbuf_(nullptr),
+        copy_reader_(nullptr),
+        row_id_(-1),
+        batch_size_hint_bytes_(16777216) {
     StringBuilderInit(&error_builder_, 0);
+    data_.data.as_char = nullptr;
+    data_.size_bytes = 0;
   }
 
   int GetSchema(struct ArrowSchema* out);
@@ -57,6 +67,11 @@ class TupleReader final {
  private:
   friend class PostgresStatement;
 
+  int InitQueryAndFetchFirst(struct ArrowError* error);
+  int AppendRowAndFetchNext(struct ArrowError* error);
+  int BuildOutput(struct ArrowArray* out, struct ArrowError* error);
+  void ResetQuery();
+
   static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct 
ArrowSchema* out);
   static int GetNextTrampoline(struct ArrowArrayStream* self, struct 
ArrowArray* out);
   static const char* GetLastErrorTrampoline(struct ArrowArrayStream* self);
@@ -65,8 +80,11 @@ class TupleReader final {
   PGconn* conn_;
   PGresult* result_;
   char* pgbuf_;
+  struct ArrowBufferView data_;
   struct StringBuilder error_builder_;
   std::unique_ptr<PostgresCopyStreamReader> copy_reader_;
+  int64_t row_id_;
+  int64_t batch_size_hint_bytes_;
 };
 
 class PostgresStatement {

Reply via email to