This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 3d12fad feat(go/adbc/driver/flightsql): support domain sockets (#516)
3d12fad is described below
commit 3d12fad1bae21029a8ff25604d6e65760c3f65bd
Author: David Li <[email protected]>
AuthorDate: Wed Mar 15 14:43:59 2023 -0400
feat(go/adbc/driver/flightsql): support domain sockets (#516)
Fixes #514.
---
go/adbc/driver/flightsql/flightsql_adbc.go | 7 ++-
go/adbc/driver/flightsql/flightsql_adbc_test.go | 75 ++++++++++++++++++++++++-
2 files changed, 79 insertions(+), 3 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index d77369d..e78f72a 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -535,12 +535,17 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s",
loc, err), Code: adbc.StatusInvalidArgument}
}
creds := d.creds
+
+ target := uri.Host
if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
creds = insecure.NewCredentials()
+ } else if uri.Scheme == "grpc+unix" {
+ creds = insecure.NewCredentials()
+ target = "unix:" + uri.Path
}
dialOpts := append(d.dialOpts.opts,
grpc.WithTransportCredentials(creds))
- cl, err := flightsql.NewClient(uri.Host, nil, middleware, dialOpts...)
+ cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 0ff02ef..336679d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -32,6 +32,7 @@ import (
"math/big"
"net"
"os"
+ "path/filepath"
"strings"
"testing"
"time"
@@ -257,6 +258,7 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &TimeoutTestSuite{})
suite.Run(t, &TLSTests{Quirks: &FlightSQLQuirks{db: db}})
suite.Run(t, &ConnectionTests{})
+ suite.Run(t, &DomainSocketTests{db: db})
}
// Driver-specific tests
@@ -739,8 +741,9 @@ func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.S
+// DoPutCommandStatementUpdate handles update statements for the timeout tests.
+// For the query "timeout" it blocks until ctx is done (cancelled or past its
+// deadline) and returns ctx.Err(); every other update returns
+// arrow.ErrNotImplemented. The -1 row count accompanies an error in both
+// paths, so callers should not rely on it.
func (ts *TimeoutTestServer) DoPutCommandStatementUpdate(ctx context.Context,
cmd flightsql.StatementUpdate) (int64, error) {
if cmd.GetQuery() == "timeout" {
<-ctx.Done()
+ return -1, ctx.Err()
}
- return 0, arrow.ErrNotImplemented
+ return -1, arrow.ErrNotImplemented
}
func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd
flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo,
error) {
@@ -885,7 +888,7 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
func (ts *TimeoutTestSuite) TestDoPutTimeout() {
ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "5.1"))
+ SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "1.1"))
stmt, err := ts.cnxn.NewStatement()
ts.Require().NoError(err)
@@ -1122,3 +1125,71 @@ func (suite *ConnectionTests) TestGetInfo() {
suite.Require().True(driverVersion)
suite.Require().True(driverArrowVersion)
}
+
+// DomainSocketTests exercises the Flight SQL driver against the example
+// SQLite Flight SQL server listening on a Unix domain socket, reached through
+// a "grpc+unix://" URI.
+type DomainSocketTests struct {
+	suite.Suite
+
+	// Server side: the backing database and the Flight SQL service over it.
+	db      *sql.DB
+	server  flight.Server
+	service *example.SQLiteFlightSQLServer
+
+	// Client side: the driver objects opened against the socket, plus the
+	// allocator they use (checked for leaks in TearDownSuite).
+	ctx    context.Context
+	alloc  *memory.CheckedAllocator
+	Driver adbc.Driver
+	DB     adbc.Database
+	Cnxn   adbc.Connection
+	Stmt   adbc.Statement
+}
+
+// SetupSuite starts the example SQLite Flight SQL server on a Unix domain
+// socket inside a fresh temporary directory, then opens a driver database,
+// connection and statement pointed at it with a "grpc+unix://" URI.
+func (suite *DomainSocketTests) SetupSuite() {
+	suite.alloc = memory.NewCheckedAllocator(memory.DefaultAllocator)
+
+	// The socket directory must outlive SetupSuite: the gRPC client may
+	// establish or re-establish its connection by path after this returns,
+	// so deleting the directory here (e.g. via defer) makes the tests flaky.
+	// os.MkdirTemp keeps the path short; Unix socket paths are limited to
+	// roughly 100 bytes on common platforms.
+	tempDir, err := os.MkdirTemp("", "adbc-flight-sql-tests-*")
+	suite.Require().NoError(err)
+	suite.T().Cleanup(func() { _ = os.RemoveAll(tempDir) })
+
+	listenSocket := filepath.Join(tempDir, "adbc.sock")
+	listener, err := net.Listen("unix", listenSocket)
+	suite.Require().NoError(err)
+
+	suite.server = flight.NewServerWithMiddleware(nil)
+	suite.service, err = example.NewSQLiteFlightSQLServer(suite.db)
+	suite.Require().NoError(err)
+	suite.server.RegisterFlightService(flightsql.NewFlightServer(suite.service))
+	suite.server.InitListener(listener)
+
+	go func() {
+		// Explicitly ignore error; the server is stopped via Shutdown in
+		// TearDownSuite.
+		_ = suite.server.Serve()
+	}()
+
+	suite.ctx = context.Background()
+	suite.Driver = driver.Driver{Alloc: suite.alloc}
+	suite.DB, err = suite.Driver.NewDatabase(map[string]string{
+		adbc.OptionKeyURI: "grpc+unix://" + listenSocket,
+	})
+	suite.Require().NoError(err)
+	suite.Cnxn, err = suite.DB.Open(suite.ctx)
+	suite.Require().NoError(err)
+	suite.Stmt, err = suite.Cnxn.NewStatement()
+	suite.Require().NoError(err)
+}
+
+// TearDownSuite closes the statement and connection, stops the server and
+// verifies that the checked allocator has no outstanding allocations.
+func (suite *DomainSocketTests) TearDownSuite() {
+	// Non-fatal assertions so one failing step does not skip the remaining
+	// cleanup. Fields may be nil if SetupSuite stopped early on a Require.
+	if suite.Stmt != nil {
+		suite.NoError(suite.Stmt.Close())
+	}
+	if suite.Cnxn != nil {
+		suite.NoError(suite.Cnxn.Close())
+	}
+	if suite.server != nil {
+		suite.server.Shutdown()
+	}
+	suite.alloc.AssertSize(suite.T(), 0)
+}
+
+func (suite *DomainSocketTests) TestSimpleQueryDomainSocket() {
+ suite.NoError(suite.Stmt.SetSqlQuery("SELECT 1"))
+ reader, _, err := suite.Stmt.ExecuteQuery(suite.ctx)
+ suite.NoError(err)
+ defer reader.Release()
+
+ for reader.Next() {
+ }
+ suite.NoError(reader.Err())
+}