This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 631f068  fix(go/adbc/driver/flightsql): guard against inconsistent schemas (#409)
631f068 is described below

commit 631f068d794e4a3eb2299eaa98d12618ee6d2c90
Author: David Li <[email protected]>
AuthorDate: Fri Feb 3 11:20:54 2023 -0500

    fix(go/adbc/driver/flightsql): guard against inconsistent schemas (#409)
    
    In case the FlightInfo schema doesn't match the DoGet schema, return an
    error instead of allowing the client to misinterpret the result.
---
 .github/workflows/native-unix.yml              |  4 +-
 go/adbc/driver/flightsql/record_reader.go      | 54 ++++++++++++++++++++++++++
 go/adbc/driver/flightsql/record_reader_test.go | 29 ++++++++++++++
 3 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 35b480e..1211936 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -392,7 +392,7 @@ jobs:
           cache: true
           cache-dependency-path: go/adbc/go.sum
       - name: Install staticcheck
-        run: go install honnef.co/go/tools/cmd/staticcheck@latest
+        run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3
       - name: Go Build
         run: |
           ./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
@@ -450,7 +450,7 @@ jobs:
       - name: Install staticcheck
         shell: bash -l {0}
         if: ${{ !contains('macos-latest', matrix.os) }}
-        run: go install honnef.co/go/tools/cmd/staticcheck@latest
+        run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3
 
       - uses: actions/download-artifact@v3
         with:
diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go
index 042c661..9e204c0 100644
--- a/go/adbc/driver/flightsql/record_reader.go
+++ b/go/adbc/driver/flightsql/record_reader.go
@@ -19,6 +19,7 @@ package flightsql
 
 import (
        "context"
+       "fmt"
        "sync/atomic"
 
        "github.com/apache/arrow-adbc/go/adbc"
@@ -116,6 +117,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
 
        lastChannelIndex := len(chs) - 1
 
+       referenceSchema := removeSchemaMetadata(schema)
        for i, ep := range endpoints {
                endpoint := ep
                endpointIndex := i
@@ -132,6 +134,11 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
                        }
                        defer rdr.Release()
 
+                       streamSchema := removeSchemaMetadata(rdr.Schema())
+                       if !streamSchema.Equal(referenceSchema) {
+                               return fmt.Errorf("endpoint %d returned 
inconsistent schema: expected %s but got %s", endpointIndex, 
referenceSchema.String(), streamSchema.String())
+                       }
+
                        for rdr.Next() && ctx.Err() == nil {
                                rec := rdr.Record()
                                rec.Retain()
@@ -201,3 +208,50 @@ func (r *reader) Schema() *arrow.Schema {
+// Record returns the reader's current record (r.rec). It does not Retain the
+// record itself; callers that keep it presumably must Retain it (usual arrow
+// RecordReader convention — confirm).
 func (r *reader) Record() arrow.Record {
        return r.rec
 }
+
+// removeSchemaMetadata builds a new schema from schema's fields, each passed
+// through removeFieldMetadata, with nil schema-level metadata. It is used to
+// compare schemas structurally while ignoring any metadata.
+func removeSchemaMetadata(schema *arrow.Schema) *arrow.Schema {
+       src := schema.Fields()
+       stripped := make([]arrow.Field, 0, len(src))
+       for idx := range src {
+               // removeFieldMetadata only reads its argument, so pointing into
+               // src is safe.
+               stripped = append(stripped, removeFieldMetadata(&src[idx]))
+       }
+       return arrow.NewSchema(stripped, nil)
+}
+
+// removeFieldMetadata returns a copy of *field whose Metadata is an empty
+// arrow.Metadata{}, keeping Name and Nullable. For the nested types handled
+// in the switch below, the type is rebuilt from children that were stripped
+// recursively, so their metadata is cleared too. The input field is only
+// read, never modified.
+func removeFieldMetadata(field *arrow.Field) arrow.Field {
+       fieldType := field.Type
+
+       if nestedType, ok := field.Type.(arrow.NestedType); ok {
+               // Strip every child first. The loop variable shadows the
+               // parameter only inside the loop; the switch reads the parameter.
+               childFields := make([]arrow.Field, len(nestedType.Fields()))
+               for i, field := range nestedType.Fields() {
+                       childFields[i] = removeFieldMetadata(&field)
+               }
+
+               // Rebuild the nested type from the stripped children. The list
+               // cases use childFields[0] as the element field.
+               switch ty := field.Type.(type) {
+               case *arrow.DenseUnionType:
+                       fieldType = arrow.DenseUnionOf(childFields, ty.TypeCodes())
+               case *arrow.FixedSizeListType:
+                       fieldType = arrow.FixedSizeListOfField(ty.Len(), childFields[0])
+               case *arrow.ListType:
+                       fieldType = arrow.ListOfField(childFields[0])
+               case *arrow.LargeListType:
+                       fieldType = arrow.LargeListOfField(childFields[0])
+               case *arrow.MapType:
+                       // NOTE(review): assumes MapType.Fields() yields [key, item]
+                       // for the arrow version in use — verify, otherwise
+                       // childFields[1] is out of range. Only the child *types*
+                       // reach MapOf, so the children's Name/Nullable are not
+                       // forwarded; presumably MapOf applies its own defaults, so
+                       // maps differing only in item nullability compare equal.
+                       mapType := arrow.MapOf(childFields[0].Type, childFields[1].Type)
+                       mapType.KeysSorted = ty.KeysSorted
+                       fieldType = mapType
+               case *arrow.SparseUnionType:
+                       fieldType = arrow.SparseUnionOf(childFields, ty.TypeCodes())
+               case *arrow.StructType:
+                       fieldType = arrow.StructOf(childFields...)
+               default:
+                       // XXX: other nested types keep their original type, so any
+                       // metadata on their children is not stripped.
+               }
+       }
+
+       return arrow.Field{
+               Name:     field.Name,
+               Type:     fieldType,
+               Nullable: field.Nullable,
+               Metadata: arrow.Metadata{},
+       }
+}
diff --git a/go/adbc/driver/flightsql/record_reader_test.go b/go/adbc/driver/flightsql/record_reader_test.go
index fd4d31a..c210122 100644
--- a/go/adbc/driver/flightsql/record_reader_test.go
+++ b/go/adbc/driver/flightsql/record_reader_test.go
@@ -282,6 +282,35 @@ func (suite *RecordReaderTests) TestNoSchema() {
        suite.NoError(reader.Err())
 }
 
+func (suite *RecordReaderTests) TestSchemaEndpointMismatch() {
+       location := "grpc://" + suite.server.Addr().String()
+       badSchema := arrow.NewSchema([]arrow.Field{
+               {Name: "epIndex", Type: arrow.PrimitiveTypes.Int32},
+               {Name: "batchIndex", Type: arrow.PrimitiveTypes.Int32},
+       }, nil)
+       info := flight.FlightInfo{
+               Schema: flight.SerializeSchema(badSchema, suite.alloc),
+               Endpoint: []*flight.FlightEndpoint{
+                       {
+                               Ticket:   &flight.Ticket{Ticket: []byte{0}},
+                               Location: []*flight.Location{{Uri: location}},
+                       },
+                       {
+                               Ticket:   &flight.Ticket{Ticket: []byte{1}},
+                               Location: []*flight.Location{{Uri: location}},
+                       },
+               },
+       }
+
+       reader, err := newRecordReader(context.Background(), suite.alloc, 
suite.cl, &info, suite.clCache, 3)
+       suite.NoError(err)
+       defer reader.Release()
+
+       suite.True(reader.Schema().Equal(badSchema))
+       suite.False(reader.Next())
+       suite.ErrorContains(reader.Err(), "returned inconsistent schema: 
expected schema:")
+}
+
 func (suite *RecordReaderTests) TestOrdering() {
        // Info with a ton of endpoints; we want to make sure data comes back 
in order
        location := "grpc://" + suite.server.Addr().String()

Reply via email to