This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch gh-1144
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git

commit e1553ddf16def724f4a6da5eedb8b044c287ab0f
Author: David Li <li.david...@gmail.com>
AuthorDate: Thu May 2 20:37:42 2024 -0400

    feat(go/adbc/driver/snowflake): support parameter binding
    
    Fixes #1144.
---
 c/driver/snowflake/snowflake_test.cc               |   2 +-
 .../driver/flightsql/flightsql_adbc_server_test.go |   2 +-
 go/adbc/driver/flightsql/flightsql_driver.go       |   2 +-
 go/adbc/driver/flightsql/logging.go                |   2 +-
 go/adbc/driver/snowflake/binding.go                | 138 +++++++++++++++++++++
 go/adbc/driver/snowflake/concat_reader.go          | 102 +++++++++++++++
 go/adbc/driver/snowflake/driver.go                 |   2 +-
 go/adbc/driver/snowflake/statement.go              |  23 +++-
 8 files changed, 265 insertions(+), 8 deletions(-)

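For context, a minimal sketch of the user-facing flow this commit enables: bind an Arrow record to a Snowflake statement and execute a parameterized query through the ADBC Go API. The DSN, query, and values below are placeholders; each bound row is executed as its own query and the result sets are concatenated by the readers added in this commit.

package main

import (
	"context"
	"log"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-adbc/go/adbc/driver/snowflake"
	"github.com/apache/arrow/go/v17/arrow"
	"github.com/apache/arrow/go/v17/arrow/array"
	"github.com/apache/arrow/go/v17/arrow/memory"
)

func main() {
	ctx := context.Background()

	// Placeholder DSN; use your own account and credentials.
	db, err := snowflake.NewDriver(memory.DefaultAllocator).NewDatabase(map[string]string{
		adbc.OptionKeyURI: "user:password@account/database/schema",
	})
	if err != nil {
		log.Fatal(err)
	}
	cnxn, err := db.Open(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer cnxn.Close()
	stmt, err := cnxn.NewStatement()
	if err != nil {
		log.Fatal(err)
	}
	defer stmt.Close()

	// Three rows of parameters: the query runs once per row.
	schema := arrow.NewSchema([]arrow.Field{{Name: "x", Type: arrow.PrimitiveTypes.Int64}}, nil)
	bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
	defer bldr.Release()
	bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
	rec := bldr.NewRecord()
	defer rec.Release()

	if err := stmt.SetSqlQuery("SELECT ? + 1"); err != nil {
		log.Fatal(err)
	}
	if err := stmt.Bind(ctx, rec); err != nil {
		log.Fatal(err)
	}

	rdr, _, err := stmt.ExecuteQuery(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer rdr.Release()
	for rdr.Next() {
		log.Println(rdr.Record())
	}
	if err := rdr.Err(); err != nil {
		log.Fatal(err)
	}
}
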
diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc
index 0fe07ecbd..a4d742491 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
   bool supports_metadata_current_catalog() const override { return false; }
   bool supports_metadata_current_db_schema() const override { return false; }
   bool supports_partitioned_data() const override { return false; }
-  bool supports_dynamic_parameter_binding() const override { return false; }
+  bool supports_dynamic_parameter_binding() const override { return true; }
   bool supports_error_on_incompatible_schema() const override { return false; }
   bool ddl_implicit_commit_txn() const override { return true; }
   std::string db_schema() const override { return schema_; }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 3bf695f0b..7bfa08d90 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -24,6 +24,7 @@ import (
        "encoding/json"
        "errors"
        "fmt"
+       "maps"
        "net"
        "net/textproto"
        "os"
@@ -48,7 +49,6 @@ import (
        "github.com/apache/arrow/go/v17/arrow/memory"
        "github.com/golang/protobuf/ptypes/wrappers"
        "github.com/stretchr/testify/suite"
-       "golang.org/x/exp/maps"
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/metadata"
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index db3e39772..175d685e4 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -32,13 +32,13 @@
 package flightsql
 
 import (
+       "maps"
        "net/url"
        "time"
 
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
        "github.com/apache/arrow/go/v17/arrow/memory"
-       "golang.org/x/exp/maps"
        "google.golang.org/grpc/metadata"
 )
 
diff --git a/go/adbc/driver/flightsql/logging.go b/go/adbc/driver/flightsql/logging.go
index 4fb12c411..187ac7078 100644
--- a/go/adbc/driver/flightsql/logging.go
+++ b/go/adbc/driver/flightsql/logging.go
@@ -21,9 +21,9 @@ import (
        "context"
        "io"
        "log/slog"
+       "maps"
        "time"
 
-       "golang.org/x/exp/maps"
        "golang.org/x/exp/slices"
        "google.golang.org/grpc"
        "google.golang.org/grpc/metadata"
diff --git a/go/adbc/driver/snowflake/binding.go b/go/adbc/driver/snowflake/binding.go
new file mode 100644
index 000000000..e79ecc8c4
--- /dev/null
+++ b/go/adbc/driver/snowflake/binding.go
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "database/sql"
+       "database/sql/driver"
+       "fmt"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
+       // see goTypeToSnowflake in gosnowflake
+       // technically, snowflake can bind an array of values at once, but
+       // only for INSERT, so we can't take advantage of that without
+       // analyzing the query ourselves
+       params := make([]driver.NamedValue, batch.NumCols())
+       for i, field := range batch.Schema().Fields() {
+               rawColumn := batch.Column(i)
+               params[i].Ordinal = i + 1
+               switch column := rawColumn.(type) {
+               case *array.Boolean:
+                       params[i].Value = sql.NullBool{
+                               Bool:  column.Value(index),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Float32:
+                       // Snowflake only recognizes float64
+                       params[i].Value = sql.NullFloat64{
+                               Float64: float64(column.Value(index)),
+                               Valid:   column.IsValid(index),
+                       }
+               case *array.Float64:
+                       params[i].Value = sql.NullFloat64{
+                               Float64: column.Value(index),
+                               Valid:   column.IsValid(index),
+                       }
+               case *array.Int8:
+                       // Snowflake only recognizes int64
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int16:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int32:
+                       params[i].Value = sql.NullInt64{
+                               Int64: int64(column.Value(index)),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.Int64:
+                       params[i].Value = sql.NullInt64{
+                               Int64: column.Value(index),
+                               Valid: column.IsValid(index),
+                       }
+               case *array.String:
+                       params[i].Value = sql.NullString{
+                               String: column.Value(index),
+                               Valid:  column.IsValid(index),
+                       }
+               case *array.LargeString:
+                       params[i].Value = sql.NullString{
+                               String: column.Value(index),
+                               Valid:  column.IsValid(index),
+                       }
+               default:
+                       return nil, adbc.Error{
+                               Code: adbc.StatusNotImplemented,
+                               Msg:  fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()),
+                       }
+               }
+       }
+       return params, nil
+}
+
+type snowflakeBindReader struct {
+       doQuery      func([]driver.NamedValue) (array.RecordReader, error)
+       currentBatch arrow.Record
+       nextIndex    int64
+       // may be nil if we bound only a batch
+       stream array.RecordReader
+}
+
+func (r *snowflakeBindReader) Release() {
+       if r.currentBatch != nil {
+               r.currentBatch.Release()
+       }
+       if r.stream != nil {
+               r.stream.Release()
+       }
+}
+
+func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
+       for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
+               if r.stream != nil && r.stream.Next() {
+                       if r.currentBatch != nil {
+                               r.currentBatch.Release()
+                       }
+                       r.currentBatch = r.stream.Record()
+                       r.nextIndex = 0
+                       continue
+               } else if r.stream != nil && r.stream.Err() != nil {
+                       return nil, r.stream.Err()
+               } else {
+                       // end-of-stream
+                       return nil, nil
+               }
+       }
+
+       params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
+       if err != nil {
+               return nil, err
+       }
+       r.nextIndex++
+
+       return r.doQuery(params)
+}
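
To make the column-to-parameter mapping above concrete, a small sketch (convertArrowToNamedValue is unexported, so this assumes it runs inside the snowflake package with the usual arrow and memory imports; the column names and values are illustrative, not part of the commit):

schema := arrow.NewSchema([]arrow.Field{
	{Name: "id", Type: arrow.PrimitiveTypes.Int64},
	{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)
bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()
bldr.Field(0).(*array.Int64Builder).Append(42)
bldr.Field(1).(*array.StringBuilder).AppendNull()
rec := bldr.NewRecord()
defer rec.Release()

// Convert row 0 of the record into per-column bind parameters.
params, err := convertArrowToNamedValue(rec, 0)
if err != nil {
	// only boolean, integer, float, and string columns are supported
	panic(err)
}
// params[0]: Ordinal 1, Value sql.NullInt64{Int64: 42, Valid: true}
// params[1]: Ordinal 2, Value sql.NullString{Valid: false}
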
diff --git a/go/adbc/driver/snowflake/concat_reader.go b/go/adbc/driver/snowflake/concat_reader.go
new file mode 100644
index 000000000..389bfb886
--- /dev/null
+++ b/go/adbc/driver/snowflake/concat_reader.go
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+       "sync/atomic"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+)
+
+type readerIter interface {
+       Release()
+
+       Next() (array.RecordReader, error)
+}
+
+type concatReader struct {
+       refCount      atomic.Int64
+       readers       readerIter
+       currentReader array.RecordReader
+       schema        *arrow.Schema
+       err           error
+}
+
+func (r *concatReader) nextReader() {
+       if r.currentReader != nil {
+               r.currentReader.Release()
+               r.currentReader = nil
+       }
+       reader, err := r.readers.Next()
+       if err != nil {
+               r.err = err
+       } else {
+               // May be nil
+               r.currentReader = reader
+       }
+}
+func (r *concatReader) Init(readers readerIter) error {
+       r.readers = readers
+       r.refCount.Store(1)
+       r.nextReader()
+       if r.err != nil {
+               return r.err
+       } else if r.currentReader == nil {
+               r.err = adbc.Error{
+                       Code: adbc.StatusInternal,
+                       Msg:  "[Snowflake] No data in this stream",
+               }
+               return r.err
+       }
+       r.schema = r.currentReader.Schema()
+       return nil
+}
+func (r *concatReader) Retain() {
+       r.refCount.Add(1)
+}
+func (r *concatReader) Release() {
+       if r.refCount.Add(-1) == 0 {
+               r.readers.Release()
+               if r.currentReader != nil {
+                       r.currentReader.Release()
+               }
+       }
+}
+func (r *concatReader) Schema() *arrow.Schema {
+       if r.schema == nil {
+               panic("did not call concatReader.Init")
+       }
+       return r.schema
+}
+func (r *concatReader) Next() bool {
+       for r.currentReader != nil && !r.currentReader.Next() {
+               r.nextReader()
+       }
+       if r.currentReader == nil || r.err != nil {
+               return false
+       }
+       return true
+}
+func (r *concatReader) Record() arrow.Record {
+       return r.currentReader.Record()
+}
+func (r *concatReader) Err() error {
+       return r.err
+}
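
To clarify the readerIter contract that concatReader consumes, a hypothetical in-package sketch (sliceReaderIter, consumeAll, and process are illustrative, not part of the commit): Next returns (nil, nil) to signal end-of-stream, ownership of each returned reader passes to concatReader, and all readers are expected to share one schema since concatReader reports the first reader's schema.

// sliceReaderIter is a hypothetical readerIter over a fixed slice of readers.
type sliceReaderIter struct {
	readers []array.RecordReader
	next    int
}

func (it *sliceReaderIter) Release() {
	// release only the readers that concatReader never pulled
	for _, rdr := range it.readers[it.next:] {
		rdr.Release()
	}
}

func (it *sliceReaderIter) Next() (array.RecordReader, error) {
	if it.next >= len(it.readers) {
		return nil, nil // (nil, nil) signals end-of-stream
	}
	rdr := it.readers[it.next]
	it.next++
	return rdr, nil
}

func consumeAll(parts []array.RecordReader, process func(arrow.Record)) error {
	var out concatReader
	// Init pulls the first reader and fails if the iterator yields none.
	if err := out.Init(&sliceReaderIter{readers: parts}); err != nil {
		return err
	}
	defer out.Release()
	for out.Next() {
		process(out.Record())
	}
	return out.Err()
}
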
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index da49a6097..a49dd13b8 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -19,6 +19,7 @@ package snowflake
 
 import (
        "errors"
+       "maps"
        "runtime/debug"
        "strings"
 
@@ -26,7 +27,6 @@ import (
        "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
        "github.com/apache/arrow/go/v17/arrow/memory"
        "github.com/snowflakedb/gosnowflake"
-       "golang.org/x/exp/maps"
 )
 
 const (
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index f61db8f06..283862ce8 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -19,6 +19,7 @@ package snowflake
 
 import (
        "context"
+       "database/sql/driver"
        "fmt"
        "strconv"
        "strings"
@@ -463,10 +464,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
        // concatenate RecordReaders which doesn't exist yet. let's put
        // that off for now.
        if st.streamBind != nil || st.bound != nil {
-               return nil, -1, adbc.Error{
-                       Msg:  "executing non-bulk ingest with bound params not 
yet implemented",
-                       Code: adbc.StatusNotImplemented,
+               bind := snowflakeBindReader{
+                       doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
+                               loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
+                               if err != nil {
+                                       return nil, errToAdbcErr(adbc.StatusInternal, err)
+                               }
+                               return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
+                       },
+                       currentBatch: st.bound,
+                       stream:       st.streamBind,
+               }
+               st.bound = nil
+               st.streamBind = nil
+
+               rdr := concatReader{}
+               err := rdr.Init(&bind)
+               if err != nil {
+                       return nil, -1, err
                }
+               return &rdr, -1, nil
        }
 
        loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
