This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch gh-1804
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git

commit 6145ed6f05a974d22b7c0ed01235816a7f724809
Author: David Li <li.david...@gmail.com>
AuthorDate: Thu May 2 21:09:42 2024 -0400

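    fix(go/adbc/driver/snowflake): handle empty result sets
    
    newRecordReader previously read batches[0] unconditionally, which
    causes an index-out-of-range panic when the loader returns no result
    batches. Now the prefetch goroutines are started only when at least
    one batch exists; otherwise the reader's schema is derived with
    rowTypesToArrowSchema. Also adds a SqlQueryEmpty validation test
    ("SELECT 42 WHERE 1=0").
    
    Fixes #1804.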
---
 c/validation/adbc_validation.h            |   2 +
 c/validation/adbc_validation_statement.cc |  24 +++++
 go/adbc/driver/snowflake/record_reader.go | 149 ++++++++++++++++--------------
 3 files changed, 105 insertions(+), 70 deletions(-)

diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 6c59d95e0..abe9a7686 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -407,6 +407,7 @@ class StatementTest {
   void TestSqlPrepareErrorNoQuery();
   void TestSqlPrepareErrorParamCountMismatch();
 
+  void TestSqlQueryEmpty();
   void TestSqlQueryInts();
   void TestSqlQueryFloats();
   void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) {                         
         \
     TestSqlPrepareErrorParamCountMismatch();                                   
         \
   }                                                                            
         \
+  TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); }                      
         \
   TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }                        
         \
   TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }                    
         \
   TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }                  
         \
diff --git a/c/validation/adbc_validation_statement.cc 
b/c/validation/adbc_validation_statement.cc
index 59f3f3f9a..7da58c698 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -2062,6 +2062,30 @@ void 
StatementTest::TestSqlPrepareErrorParamCountMismatch() {
       ::testing::Not(IsOkStatus(&error)));
 }
 
+void StatementTest::TestSqlQueryEmpty() {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42 WHERE 1=0", 
&error),
+              IsOkStatus(&error));
+  StreamReader reader;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(reader.rows_affected,
+              ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_EQ(1, reader.schema->n_children);
+
+  while (true) {
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    if (!reader.array->release) {
+      break;
+    }
+    ASSERT_EQ(0, reader.array->length);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
 void StatementTest::TestSqlQueryInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/go/adbc/driver/snowflake/record_reader.go 
b/go/adbc/driver/snowflake/record_reader.go
index bda3e8f70..5c7132220 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -571,23 +571,9 @@ func newRecordReader(ctx context.Context, alloc 
memory.Allocator, ld gosnowflake
        }
 
        ch := make(chan arrow.Record, bufferSize)
-       r, err := batches[0].GetStream(ctx)
-       if err != nil {
-               return nil, errToAdbcErr(adbc.StatusIO, err)
-       }
-
-       rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
-       if err != nil {
-               return nil, adbc.Error{
-                       Msg:  err.Error(),
-                       Code: adbc.StatusInvalidState,
-               }
-       }
-
        group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
        ctx, cancelFn := context.WithCancel(ctx)
-
-       schema, recTransform := getTransformer(rr.Schema(), ld, 
useHighPrecision)
+       group.SetLimit(prefetchConcurrency)
 
        defer func() {
                if err != nil {
@@ -596,80 +582,103 @@ func newRecordReader(ctx context.Context, alloc 
memory.Allocator, ld gosnowflake
                }
        }()
 
-       group.SetLimit(prefetchConcurrency)
-       group.Go(func() error {
-               defer rr.Release()
-               defer r.Close()
-               if len(batches) > 1 {
-                       defer close(ch)
-               }
-
-               for rr.Next() && ctx.Err() == nil {
-                       rec := rr.Record()
-                       rec, err = recTransform(ctx, rec)
-                       if err != nil {
-                               return err
-                       }
-                       ch <- rec
-               }
-               return rr.Err()
-       })
-
        chs := make([]chan arrow.Record, len(batches))
-       chs[0] = ch
        rdr := &reader{
                refCount: 1,
                chs:      chs,
                err:      nil,
                cancelFn: cancelFn,
-               schema:   schema,
        }
 
-       lastChannelIndex := len(chs) - 1
-       go func() {
-               for i, b := range batches[1:] {
-                       batch, batchIdx := b, i+1
-                       chs[batchIdx] = make(chan arrow.Record, bufferSize)
-                       group.Go(func() error {
-                               // close channels (except the last) so that 
Next can move on to the next channel properly
-                               if batchIdx != lastChannelIndex {
-                                       defer close(chs[batchIdx])
-                               }
+       if len(batches) > 0 {
+               r, err := batches[0].GetStream(ctx)
+               if err != nil {
+                       return nil, errToAdbcErr(adbc.StatusIO, err)
+               }
 
-                               rdr, err := batch.GetStream(ctx)
-                               if err != nil {
-                                       return err
-                               }
-                               defer rdr.Close()
+               rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
+               if err != nil {
+                       return nil, adbc.Error{
+                               Msg:  err.Error(),
+                               Code: adbc.StatusInvalidState,
+                       }
+               }
+
+               var recTransform recordTransformer
+               rdr.schema, recTransform = getTransformer(rr.Schema(), ld, 
useHighPrecision)
 
-                               rr, err := ipc.NewReader(rdr, 
ipc.WithAllocator(alloc))
+               group.Go(func() error {
+                       defer rr.Release()
+                       defer r.Close()
+                       if len(batches) > 1 {
+                               defer close(ch)
+                       }
+
+                       for rr.Next() && ctx.Err() == nil {
+                               rec := rr.Record()
+                               rec, err = recTransform(ctx, rec)
                                if err != nil {
                                        return err
                                }
-                               defer rr.Release()
+                               ch <- rec
+                       }
+                       return rr.Err()
+               })
+
+               chs[0] = ch
+
+               lastChannelIndex := len(chs) - 1
+               go func() {
+                       for i, b := range batches[1:] {
+                               batch, batchIdx := b, i+1
+                               chs[batchIdx] = make(chan arrow.Record, 
bufferSize)
+                               group.Go(func() error {
+                                       // close channels (except the last) so 
that Next can move on to the next channel properly
+                                       if batchIdx != lastChannelIndex {
+                                               defer close(chs[batchIdx])
+                                       }
 
-                               for rr.Next() && ctx.Err() == nil {
-                                       rec := rr.Record()
-                                       rec, err = recTransform(ctx, rec)
+                                       rdr, err := batch.GetStream(ctx)
                                        if err != nil {
                                                return err
                                        }
-                                       chs[batchIdx] <- rec
-                               }
+                                       defer rdr.Close()
 
-                               return rr.Err()
-                       })
-               }
+                                       rr, err := ipc.NewReader(rdr, 
ipc.WithAllocator(alloc))
+                                       if err != nil {
+                                               return err
+                                       }
+                                       defer rr.Release()
 
-               // place this here so that we always clean up, but they can't 
be in a
-               // separate goroutine. Otherwise we'll have a race condition 
between
-               // the call to wait and the calls to group.Go to kick off the 
jobs
-               // to perform the pre-fetching (GH-1283).
-               rdr.err = group.Wait()
-               // don't close the last channel until after the group is 
finished,
-               // so that Next() can only return after reader.err may have 
been set
-               close(chs[lastChannelIndex])
-       }()
+                                       for rr.Next() && ctx.Err() == nil {
+                                               rec := rr.Record()
+                                               rec, err = recTransform(ctx, 
rec)
+                                               if err != nil {
+                                                       return err
+                                               }
+                                               chs[batchIdx] <- rec
+                                       }
+
+                                       return rr.Err()
+                               })
+                       }
+
+                       // place this here so that we always clean up, but they 
can't be in a
+                       // separate goroutine. Otherwise we'll have a race 
condition between
+                       // the call to wait and the calls to group.Go to kick 
off the jobs
+                       // to perform the pre-fetching (GH-1283).
+                       rdr.err = group.Wait()
+                       // don't close the last channel until after the group 
is finished,
+                       // so that Next() can only return after reader.err may 
have been set
+                       close(chs[lastChannelIndex])
+               }()
+       } else {
+               schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+               if err != nil {
+                       return nil, err
+               }
+               rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+       }
 
        return rdr, nil
 }

Reply via email to