This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new ddbfaecc fix(go/adbc/driver/flightsql): Have GetTableSchema check for table name match instead of the first schema it receives (#980)
ddbfaecc is described below

commit ddbfaeccba2be01fe0e54cacd29c058fdf5359e3
Author: Solomon Choe <[email protected]>
AuthorDate: Tue Aug 22 11:42:44 2023 -0700

    fix(go/adbc/driver/flightsql): Have GetTableSchema check for table name match instead of the first schema it receives (#980)
    
    Fixes #934.
---
 go/adbc/driver/flightsql/flightsql_adbc.go         | 42 +++++++---
 .../driver/flightsql/flightsql_adbc_server_test.go | 92 ++++++++++++++++++++++
 2 files changed, 122 insertions(+), 12 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go 
b/go/adbc/driver/flightsql/flightsql_adbc.go
index 1ae99a6a..e00310cf 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -1231,24 +1231,42 @@ func (c *cnxn) GetTableSchema(ctx context.Context, 
catalog *string, dbSchema *st
                return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
        }
 
-       if rec.NumRows() == 0 {
+       numRows := rec.NumRows()
+       switch {
+       case numRows == 0:
                return nil, adbc.Error{
                        Code: adbc.StatusNotFound,
                }
+       case numRows > math.MaxInt32:
+               return nil, adbc.Error{
+                       Msg:  "[Flight SQL] GetTableSchema cannot handle tables 
with number of rows > 2^31 - 1",
+                       Code: adbc.StatusNotImplemented,
+               }
        }
 
-       // returned schema should be
-       //    0: catalog_name: utf8
-       //    1: db_schema_name: utf8
-       //    2: table_name: utf8 not null
-       //    3: table_type: utf8 not null
-       //    4: table_schema: bytes not null
-       schemaBytes := rec.Column(4).(*array.Binary).Value(0)
-       s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
-       if err != nil {
-               return nil, adbcFromFlightStatus(err, "GetTableSchema")
+       var s *arrow.Schema
+       for i := 0; i < int(numRows); i++ {
+               currentTableName := rec.Column(2).(*array.String).Value(i)
+               if currentTableName == tableName {
+                       // returned schema should be
+                       //    0: catalog_name: utf8
+                       //    1: db_schema_name: utf8
+                       //    2: table_name: utf8 not null
+                       //    3: table_type: utf8 not null
+                       //    4: table_schema: bytes not null
+                       schemaBytes := rec.Column(4).(*array.Binary).Value(i)
+                       s, err = flight.DeserializeSchema(schemaBytes, 
c.db.alloc)
+                       if err != nil {
+                               return nil, adbcFromFlightStatus(err, 
"GetTableSchema")
+                       }
+                       return s, nil
+               }
+       }
+
+       return s, adbc.Error{
+               Msg:  "[Flight SQL] GetTableSchema could not find a table with 
a matching schema",
+               Code: adbc.StatusNotFound,
        }
-       return s, nil
 }
 
 // GetTableTypes returns a list of the table types in the database.
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go 
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index dd6171c4..d8af6a65 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -35,6 +35,7 @@ import (
        "github.com/apache/arrow/go/v13/arrow/array"
        "github.com/apache/arrow/go/v13/arrow/flight"
        "github.com/apache/arrow/go/v13/arrow/flight/flightsql"
+       "github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
        "github.com/apache/arrow/go/v13/arrow/memory"
        "github.com/stretchr/testify/suite"
        "golang.org/x/exp/maps"
@@ -107,6 +108,10 @@ func TestDataType(t *testing.T) {
        suite.Run(t, &DataTypeTests{})
 }
 
+func TestMultiTable(t *testing.T) {
+       suite.Run(t, &MultiTableTests{})
+}
+
 // ---- AuthN Tests --------------------
 
 type AuthnTestServer struct {
@@ -627,3 +632,90 @@ func (suite *DataTypeTests) TestListInt() {
 func (suite *DataTypeTests) TestMapIntInt() {
        suite.DoTestCase("map[int]int", SchemaMapIntInt)
 }
+
+// ---- Multi Table Tests --------------------
+
+type MultiTableTestServer struct {
+       flightsql.BaseServer
+}
+
+func (server *MultiTableTestServer) GetFlightInfoStatement(ctx 
context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) 
(*flight.FlightInfo, error) {
+       query := cmd.GetQuery()
+       tkt, err := flightsql.CreateStatementQueryTicket([]byte(query))
+       if err != nil {
+               return nil, err
+       }
+
+       return &flight.FlightInfo{
+               Endpoint:         []*flight.FlightEndpoint{{Ticket: 
&flight.Ticket{Ticket: tkt}}},
+               FlightDescriptor: desc,
+               TotalRecords:     -1,
+               TotalBytes:       -1,
+       }, nil
+}
+
+func (server *MultiTableTestServer) GetFlightInfoTables(ctx context.Context, 
cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, 
error) {
+       schema := schema_ref.Tables
+       if cmd.GetIncludeSchema() {
+               schema = schema_ref.TablesWithIncludedSchema
+       }
+       server.Alloc = memory.NewCheckedAllocator(memory.DefaultAllocator)
+       info := &flight.FlightInfo{
+               Endpoint: []*flight.FlightEndpoint{
+                       {Ticket: &flight.Ticket{Ticket: desc.Cmd}},
+               },
+               FlightDescriptor: desc,
+               Schema:           flight.SerializeSchema(schema, server.Alloc),
+               TotalRecords:     -1,
+               TotalBytes:       -1,
+       }
+
+       return info, nil
+}
+
+func (server *MultiTableTestServer) DoGetTables(ctx context.Context, cmd 
flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+       bldr := array.NewRecordBuilder(server.Alloc, adbc.GetTableSchemaSchema)
+
+       bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)
+       bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)
+       bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"tbl1", 
"tbl2"}, nil)
+       bldr.Field(3).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)
+
+       sc1 := arrow.NewSchema([]arrow.Field{{Name: "a", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+       sc2 := arrow.NewSchema([]arrow.Field{{Name: "b", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+       buf1 := flight.SerializeSchema(sc1, server.Alloc)
+       buf2 := flight.SerializeSchema(sc2, server.Alloc)
+
+       bldr.Field(4).(*array.BinaryBuilder).AppendValues([][]byte{buf1, buf2}, 
nil)
+       defer bldr.Release()
+
+       rec := bldr.NewRecord()
+
+       ch := make(chan flight.StreamChunk)
+       go func() {
+               defer close(ch)
+               ch <- flight.StreamChunk{
+                       Data: rec,
+                       Desc: nil,
+                       Err:  nil,
+               }
+       }()
+       return adbc.GetTableSchemaSchema, ch, nil
+}
+
+type MultiTableTests struct {
+       ServerBasedTests
+}
+
+func (suite *MultiTableTests) SetupSuite() {
+       suite.DoSetupSuite(&MultiTableTestServer{}, nil, map[string]string{})
+}
+
+// Regression test for https://github.com/apache/arrow-adbc/issues/934
+func (suite *MultiTableTests) TestGetTableSchema() {
+       actualSchema, err := suite.cnxn.GetTableSchema(context.Background(), 
nil, nil, "tbl2")
+       suite.NoError(err)
+
+       expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+       suite.Equal(expectedSchema, actualSchema)
+}

Reply via email to