This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d12fad  feat(go/adbc/driver/flightsql): support domain sockets (#516)
3d12fad is described below

commit 3d12fad1bae21029a8ff25604d6e65760c3f65bd
Author: David Li <[email protected]>
AuthorDate: Wed Mar 15 14:43:59 2023 -0400

    feat(go/adbc/driver/flightsql): support domain sockets (#516)
    
    Fixes #514.
---
 go/adbc/driver/flightsql/flightsql_adbc.go      |  7 ++-
 go/adbc/driver/flightsql/flightsql_adbc_test.go | 75 ++++++++++++++++++++++++-
 2 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index d77369d..e78f72a 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -535,12 +535,17 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
                return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s", 
loc, err), Code: adbc.StatusInvalidArgument}
        }
        creds := d.creds
+
+       target := uri.Host
        if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
                creds = insecure.NewCredentials()
+       } else if uri.Scheme == "grpc+unix" {
+               creds = insecure.NewCredentials()
+               target = "unix:" + uri.Path
        }
        dialOpts := append(d.dialOpts.opts, 
grpc.WithTransportCredentials(creds))
 
-       cl, err := flightsql.NewClient(uri.Host, nil, middleware, dialOpts...)
+       cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
        if err != nil {
                return nil, adbc.Error{
                        Msg:  err.Error(),
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 0ff02ef..336679d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -32,6 +32,7 @@ import (
        "math/big"
        "net"
        "os"
+       "path/filepath"
        "strings"
        "testing"
        "time"
@@ -257,6 +258,7 @@ func TestADBCFlightSQL(t *testing.T) {
        suite.Run(t, &TimeoutTestSuite{})
        suite.Run(t, &TLSTests{Quirks: &FlightSQLQuirks{db: db}})
        suite.Run(t, &ConnectionTests{})
+       suite.Run(t, &DomainSocketTests{db: db})
 }
 
 // Driver-specific tests
@@ -739,8 +741,9 @@ func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.S
+// DoPutCommandStatementUpdate blocks until the request context is done when
+// the query is "timeout" (so client-side update timeouts can be exercised);
+// every other query is rejected with arrow.ErrNotImplemented.
 func (ts *TimeoutTestServer) DoPutCommandStatementUpdate(ctx context.Context, 
cmd flightsql.StatementUpdate) (int64, error) {
        if cmd.GetQuery() == "timeout" {
                <-ctx.Done()
+               // Report why the context ended instead of falling through to
+               // the not-implemented error below.
+               return -1, ctx.Err()
        }
-       return 0, arrow.ErrNotImplemented
+       return -1, arrow.ErrNotImplemented
 }
 
 func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd 
flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, 
error) {
@@ -885,7 +888,7 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
 
 func (ts *TimeoutTestSuite) TestDoPutTimeout() {
        ts.NoError(ts.cnxn.(adbc.PostInitOptions).
-               SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "5.1"))
+               SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "1.1"))
 
        stmt, err := ts.cnxn.NewStatement()
        ts.Require().NoError(err)
@@ -1122,3 +1125,71 @@ func (suite *ConnectionTests) TestGetInfo() {
        suite.Require().True(driverVersion)
        suite.Require().True(driverArrowVersion)
 }
+
+// DomainSocketTests runs the Flight SQL driver against a server that listens
+// on a Unix domain socket, addressed through a "grpc+unix://" URI.
+type DomainSocketTests struct {
+       suite.Suite
+
+       // Server side: the example SQLite Flight SQL service built from db and
+       // the flight.Server it is registered with.
+       db      *sql.DB
+       service *example.SQLiteFlightSQLServer
+       server  flight.Server
+
+       // Client side: the checked allocator handed to the driver, plus the
+       // ADBC objects opened against the socket in SetupSuite.
+       alloc  *memory.CheckedAllocator
+       ctx    context.Context
+       Driver adbc.Driver
+       DB     adbc.Database
+       Cnxn   adbc.Connection
+       Stmt   adbc.Statement
+}
+
+// SetupSuite starts the example SQLite Flight SQL server on a Unix domain
+// socket inside a fresh temporary directory, then opens an ADBC database,
+// connection and statement against it through a "grpc+unix://" URI.
+func (suite *DomainSocketTests) SetupSuite() {
+       suite.alloc = memory.NewCheckedAllocator(memory.DefaultAllocator)
+
+       // Setup failures are fatal: carrying on with a nil listener or service
+       // would only surface later as a confusing panic.
+       tempDir, err := os.MkdirTemp("", "adbc-flight-sql-tests-*")
+       suite.Require().NoError(err)
+       // The socket file has to outlive SetupSuite because the server keeps
+       // listening on it for every test in the suite. A plain defer would delete
+       // it as soon as this function returned, so removal is registered on the
+       // suite-level T and runs only after that test has completed.
+       suite.T().Cleanup(func() {
+               // Best-effort removal of scratch files; nothing to recover.
+               _ = os.RemoveAll(tempDir)
+       })
+
+       listenSocket := filepath.Join(tempDir, "adbc.sock")
+
+       listener, err := net.Listen("unix", listenSocket)
+       suite.Require().NoError(err)
+
+       suite.server = flight.NewServerWithMiddleware(nil)
+       suite.service, err = example.NewSQLiteFlightSQLServer(suite.db)
+       suite.Require().NoError(err)
+       suite.server.RegisterFlightService(flightsql.NewFlightServer(suite.service))
+       suite.server.InitListener(listener)
+
+       go func() {
+               // Explicitly ignore error
+               _ = suite.server.Serve()
+       }()
+
+       suite.ctx = context.Background()
+       suite.Driver = driver.Driver{Alloc: suite.alloc}
+       suite.DB, err = suite.Driver.NewDatabase(map[string]string{
+               adbc.OptionKeyURI: "grpc+unix://" + listenSocket,
+       })
+       suite.Require().NoError(err)
+       suite.Cnxn, err = suite.DB.Open(suite.ctx)
+       suite.Require().NoError(err)
+       suite.Stmt, err = suite.Cnxn.NewStatement()
+       suite.Require().NoError(err)
+}
+
+// TearDownSuite closes the client-side ADBC objects, stops the server and
+// checks that the driver's checked allocator has nothing outstanding.
+func (suite *DomainSocketTests) TearDownSuite() {
+       // Close failures are recorded without aborting (no Require) so that the
+       // server below is always shut down. Nil guards cover a SetupSuite that
+       // stopped before creating every object.
+       if suite.Stmt != nil {
+               suite.NoError(suite.Stmt.Close())
+       }
+       if suite.Cnxn != nil {
+               suite.NoError(suite.Cnxn.Close())
+       }
+       if suite.server != nil {
+               suite.server.Shutdown()
+       }
+       // Leak check: the allocator must report zero bytes still allocated.
+       suite.alloc.AssertSize(suite.T(), 0)
+}
+
+func (suite *DomainSocketTests) TestSimpleQueryDomainSocket() {
+       suite.NoError(suite.Stmt.SetSqlQuery("SELECT 1"))
+       reader, _, err := suite.Stmt.ExecuteQuery(suite.ctx)
+       suite.NoError(err)
+       defer reader.Release()
+
+       for reader.Next() {
+       }
+       suite.NoError(reader.Err())
+}

Reply via email to