This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 2e044e85a fix(go/adbc/driver/snowflake): handle empty result sets (#1805)
2e044e85a is described below
commit 2e044e85a70f42130b70532fcd995e76a2671933
Author: David Li <[email protected]>
AuthorDate: Sat May 4 20:13:52 2024 +0900
fix(go/adbc/driver/snowflake): handle empty result sets (#1805)
Fixes #1804.
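For context: per the record_reader.go hunk below, the crash came from
unconditionally calling batches[0].GetStream(ctx); a query such as
SELECT 42 WHERE 1=0 yields zero Arrow batches from Snowflake, so the
index panicked. The fix short-circuits when there are no batches and
returns a schema-only reader built from the row-type metadata. A
minimal sketch of the hazard and the guard (the batch type and both
helpers are illustrative stand-ins, not the driver's actual API):

    package main

    import "fmt"

    // batch stands in for gosnowflake's Arrow batch type; the real
    // reader calls batches[0].GetStream(ctx) to open the first batch.
    type batch struct{ name string }

    // readFirst mirrors the pre-fix logic: it indexes batches[0]
    // unconditionally and panics with "index out of range" when the
    // result set is empty.
    func readFirst(batches []batch) string {
        return batches[0].name
    }

    // readFirstGuarded mirrors the fix: zero batches is a valid,
    // empty result, so return a schema-only placeholder instead of
    // indexing into an empty slice.
    func readFirstGuarded(batches []batch) string {
        if len(batches) == 0 {
            return "<schema-only empty reader>"
        }
        return batches[0].name
    }

    func main() {
        fmt.Println(readFirstGuarded(nil)) // prints the placeholder; no panic
    }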
---
c/driver/flightsql/dremio_flightsql_test.cc | 1 +
c/validation/adbc_validation.h | 2 ++
c/validation/adbc_validation_statement.cc | 35 ++++++++++++++++++++
go/adbc/driver/snowflake/driver_test.go | 19 +++++++++++
go/adbc/driver/snowflake/record_reader.go | 50 +++++++++++++++++------------
5 files changed, 87 insertions(+), 20 deletions(-)
diff --git a/c/driver/flightsql/dremio_flightsql_test.cc b/c/driver/flightsql/dremio_flightsql_test.cc
index 8c59eb4a2..acc068279 100644
--- a/c/driver/flightsql/dremio_flightsql_test.cc
+++ b/c/driver/flightsql/dremio_flightsql_test.cc
@@ -92,6 +92,7 @@ class DremioFlightSqlStatementTest : public ::testing::Test,
void TestSqlIngestColumnEscaping() {
GTEST_SKIP() << "Column escaping not implemented";
}
+ void TestSqlQueryEmpty() { GTEST_SKIP() << "Dremio doesn't support 'acceptPut'"; }
void TestSqlQueryRowsAffectedDelete() {
GTEST_SKIP() << "Cannot query rows affected in delete (not implemented)";
}
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 6c59d95e0..abe9a7686 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -407,6 +407,7 @@ class StatementTest {
void TestSqlPrepareErrorNoQuery();
void TestSqlPrepareErrorParamCountMismatch();
+ void TestSqlQueryEmpty();
void TestSqlQueryInts();
void TestSqlQueryFloats();
void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) {          \
  TestSqlPrepareErrorParamCountMismatch();                    \
}                                                             \
+ TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); }     \
TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }         \
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }     \
TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }   \
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 59f3f3f9a..333baf141 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -2062,6 +2062,41 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
::testing::Not(IsOkStatus(&error)));
}
+void StatementTest::TestSqlQueryEmpty() {
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+ ASSERT_THAT(quirks()->DropTable(&connection, "QUERYEMPTY", &error), IsOkStatus(&error));
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(&statement, "CREATE TABLE QUERYEMPTY (FOO
INT)", &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+ IsOkStatus(&error));
+
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(&statement, "SELECT * FROM QUERYEMPTY WHERE
1=0", &error),
+ IsOkStatus(&error));
+ {
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(reader.rows_affected,
+ ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_EQ(1, reader.schema->n_children);
+
+ while (true) {
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ if (!reader.array->release) {
+ break;
+ }
+ ASSERT_EQ(0, reader.array->length);
+ }
+ }
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
void StatementTest::TestSqlQueryInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index de175c3a0..af94e6108 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -2031,3 +2031,22 @@ func (suite *SnowflakeTests) TestMetadataOnlyQuery() {
// all the rows from each record in the stream.
suite.Equal(n, recv)
}
+
+func (suite *SnowflakeTests) TestEmptyResultSet() {
+ // regression test for apache/arrow-adbc#1804
+ // this would previously crash
+ suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT 42 WHERE 1=0`))
+ rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ recv := int64(0)
+ for rdr.Next() {
+ recv += rdr.Record().NumRows()
+ }
+
+ // verify that we got the expected number of rows if we sum up
+ // all the rows from each record in the stream.
+ suite.Equal(n, recv)
+ suite.Equal(recv, int64(0))
+}
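The regression test above drains the reader without assuming any
batches exist. The invariant the fix below establishes is that an
empty result set is still a reader with a schema and zero records,
not a nil stream. A minimal sketch of that shape using arrow-go's
public API (the arrow/go/v15 module path is an assumption, and this
simple reader is not the driver's internal one):

    package main

    import (
        "fmt"

        "github.com/apache/arrow/go/v15/arrow"
        "github.com/apache/arrow/go/v15/arrow/array"
    )

    func main() {
        // An empty result set still carries a schema, e.g. one INT column.
        schema := arrow.NewSchema([]arrow.Field{
            {Name: "FOO", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
        }, nil)

        // A reader over zero records: Next() is immediately false,
        // but Schema() remains available to callers.
        rdr, err := array.NewRecordReader(schema, nil)
        if err != nil {
            panic(err)
        }
        defer rdr.Release()

        rows := int64(0)
        for rdr.Next() {
            rows += rdr.Record().NumRows()
        }
        fmt.Println("rows:", rows) // rows: 0
    }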
diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go
index bda3e8f70..e404f116d 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -571,6 +571,34 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
}
ch := make(chan arrow.Record, bufferSize)
+ group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
+ ctx, cancelFn := context.WithCancel(ctx)
+ group.SetLimit(prefetchConcurrency)
+
+ defer func() {
+ if err != nil {
+ close(ch)
+ cancelFn()
+ }
+ }()
+
+ chs := make([]chan arrow.Record, len(batches))
+ rdr := &reader{
+ refCount: 1,
+ chs: chs,
+ err: nil,
+ cancelFn: cancelFn,
+ }
+
+ if len(batches) == 0 {
+ schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+ if err != nil {
+ return nil, err
+ }
+ rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+ return rdr, nil
+ }
+
r, err := batches[0].GetStream(ctx)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
@@ -584,19 +612,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
}
}
- group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
- ctx, cancelFn := context.WithCancel(ctx)
-
- schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision)
+ var recTransform recordTransformer
+ rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
- defer func() {
- if err != nil {
- close(ch)
- cancelFn()
- }
- }()
-
- group.SetLimit(prefetchConcurrency)
group.Go(func() error {
defer rr.Release()
defer r.Close()
@@ -615,15 +633,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
return rr.Err()
})
- chs := make([]chan arrow.Record, len(batches))
chs[0] = ch
- rdr := &reader{
- refCount: 1,
- chs: chs,
- err: nil,
- cancelFn: cancelFn,
- schema: schema,
- }
lastChannelIndex := len(chs) - 1
go func() {