This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 6648c1d2 fix(arrow/cdata, arrow/flight): fix handling of colons in 
values and fix potential panics (#761)
6648c1d2 is described below

commit 6648c1d202bd3d844d532b9167ad292b093a05e6
Author: Sebastiaan van Stijn <[email protected]>
AuthorDate: Mon Apr 13 03:10:40 2026 +0200

    fix(arrow/cdata, arrow/flight): fix handling of colons in values and fix 
potential panics (#761)
    
    ### fix(arrow/cdata): importSchema: handle colons in values
    
    Use strings.Cut, both as an optimization, and to prevent values containing
    a colon (e.g. "tsu:+01:00") from being misinterpreted. This patch also
    removes some intermediate variables, and redundant handling of "defaulttz",
    which assigned an empty string if the value was empty.
    
    
    ### fix(arrow/cdata): importSchema: fix potential panic and optimize
    
    Rewrite the code with strings.Cut and strings.SplitSeq to reduce
    allocations, and to fix a potential panic.
    
    Before this patch, the code would panic if a colon was missing;
    
    CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors ./arrow/cdata/
        --- FAIL: TestUnionSchemaErrors (0.00s)
            --- FAIL: TestUnionSchemaErrors/+us (0.00s)
    panic: runtime error: index out of range [1] with length 1 [recovered,
    repanicked]
    
        goroutine 9 [running]:
        testing.tRunner.func1.2({0x7fc7c0, 0x4000026ab0})
            /usr/local/go/src/testing/testing.go:1872 +0x190
        testing.tRunner.func1()
            /usr/local/go/src/testing/testing.go:1875 +0x31c
        panic({0x7fc7c0?, 0x4000026ab0?})
            /usr/local/go/src/runtime/panic.go:783 +0x120
    github.com/apache/arrow-go/v18/arrow/cdata.importSchema(0x40001c36d0)
            /foo/arrow/cdata/cdata.go:306 +0x1520
        github.com/apache/arrow-go/v18/arrow/cdata.ImportCArrowField(...)
            /foo/arrow/cdata/interface.go:43
    
    
github.com/apache/arrow-go/v18/arrow/cdata.TestUnionSchemaErrors.func1(0x40000e0a80)
            /foo/arrow/cdata/cdata_test.go:188 +0xb0
        testing.tRunner(0x40000e0a80, 0x400020c060)
            /usr/local/go/src/testing/testing.go:1934 +0xc8
        created by testing.(*T).Run in goroutine 8
            /usr/local/go/src/testing/testing.go:1997 +0x364
        FAIL        github.com/apache/arrow-go/v18/arrow/cdata      0.007s
        FAIL
    
    With this patch applied, the code handles the invalid value gracefully;
    
    CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors ./arrow/cdata/
        === RUN   TestUnionSchemaErrors
        === RUN   TestUnionSchemaErrors/+us
        === RUN   TestUnionSchemaErrors/+ud
        --- PASS: TestUnionSchemaErrors (0.00s)
            --- PASS: TestUnionSchemaErrors/+us (0.00s)
            --- PASS: TestUnionSchemaErrors/+ud (0.00s)
        PASS
        ok          github.com/apache/arrow-go/v18/arrow/cdata      0.003s
    
    
    ### fix(arrow/flight): avoid panic on malformed authorization header
    
    Rewrite the code with strings.Cut for readability and ensure missing
    credentials in Basic/Bearer authorization headers return Unauthenticated
    instead of panicking.
    
    Before this patch, the code could panic;
    
        go test -run TestBasicAuthMissingCredential ./arrow/flight/
        panic: runtime error: index out of range [1] with length 1
    
        goroutine 7 [running]:
    
    
github.com/apache/arrow-go/v18/arrow/flight_test.TestBasicAuthMissingCredential.CreateServerBasicAuthMiddleware.createServerBearerTokenStreamInterceptor.func3({0x8d8240,
    0x40002134a0}, {0xa73e68, 0x40000e2000}, 0x40000100c0, 0x96b628)
            /foo/arrow/flight/server_auth.go:188 +0x49c
        ....
    
    With this patch applied, the code handles the invalid header gracefully;
    
        go test -run TestBasicAuthMissingCredential ./arrow/flight/
        ok          github.com/apache/arrow-go/v18/arrow/flight     0.010s
    
    
    
    
    ### Rationale for this change
    
    
    ### What changes are included in this PR?
    
    
    ### Are these changes tested?
    
    
    ### Are there any user-facing changes?
    
    ---------
    
    Signed-off-by: Sebastiaan van Stijn <[email protected]>
---
 arrow/cdata/cdata.go                   | 41 ++++++++++----------------------
 arrow/cdata/cdata_test.go              | 23 +++++++++++++++++-
 arrow/flight/basic_auth_flight_test.go | 43 ++++++++++++++++++++++++++++++++++
 arrow/flight/server_auth.go            | 24 +++++++++++--------
 4 files changed, 92 insertions(+), 39 deletions(-)

diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go
index 63419469..d8eeb5b2 100644
--- a/arrow/cdata/cdata.go
+++ b/arrow/cdata/cdata.go
@@ -202,41 +202,23 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, 
err error) {
        }
 
        // handle types with params via colon
-       typs := strings.Split(f, ":")
-       defaulttz := ""
-       switch typs[0] {
+       switch key, val, _ := strings.Cut(f, ":"); key {
        case "tss":
-               tz := typs[1]
-               if len(typs[1]) == 0 {
-                       tz = defaulttz
-               }
-               dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: tz}
+               dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: val}
        case "tsm":
-               tz := typs[1]
-               if len(typs[1]) == 0 {
-                       tz = defaulttz
-               }
-               dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: tz}
+               dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: 
val}
        case "tsu":
-               tz := typs[1]
-               if len(typs[1]) == 0 {
-                       tz = defaulttz
-               }
-               dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: tz}
+               dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: 
val}
        case "tsn":
-               tz := typs[1]
-               if len(typs[1]) == 0 {
-                       tz = defaulttz
-               }
-               dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: tz}
+               dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: val}
        case "w": // fixed size binary is "w:##" where ## is the byteWidth
-               byteWidth, err := strconv.Atoi(typs[1])
+               byteWidth, err := strconv.Atoi(val)
                if err != nil {
                        return ret, err
                }
                dt = &arrow.FixedSizeBinaryType{ByteWidth: byteWidth}
        case "d": // decimal types are d:<precision>,<scale>[,<bitsize>] size 
is assumed 128 if left out
-               props := typs[1]
+               props := val
                propList := strings.Split(props, ",")
                bitwidth := 128
                var precision, scale int
@@ -317,9 +299,12 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, 
err error) {
                                return
                        }
 
-                       codes := strings.Split(strings.Split(f, ":")[1], ",")
-                       typeCodes := make([]arrow.UnionTypeCode, 0, len(codes))
-                       for _, i := range codes {
+                       _, val, ok := strings.Cut(f, ":")
+                       if !ok {
+                               return ret, fmt.Errorf("invalid union type code 
spec %q", f)
+                       }
+                       var typeCodes []arrow.UnionTypeCode
+                       for i := range strings.SplitSeq(val, ",") {
                                v, e := strconv.ParseInt(i, 10, 8)
                                if e != nil {
                                        err = fmt.Errorf("%w: invalid type 
code: %s", arrow.ErrInvalid, e)
diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go
index 8fa690f2..c196b215 100644
--- a/arrow/cdata/cdata_test.go
+++ b/arrow/cdata/cdata_test.go
@@ -174,6 +174,24 @@ func TestDecimalSchemaErrors(t *testing.T) {
        }
 }
 
+func TestUnionSchemaErrors(t *testing.T) {
+       tests := []struct {
+               fmt string
+       }{
+               {"+us"}, // missing ":<type_codes>"
+               {"+ud"}, // missing ":<type_codes>"
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.fmt, func(t *testing.T) {
+                       sc := testPrimitive(tt.fmt)
+
+                       _, err := ImportCArrowField(&sc)
+                       assert.Error(t, err)
+               })
+       }
+}
+
 func TestImportTemporalSchema(t *testing.T) {
        tests := []struct {
                typ arrow.DataType
@@ -195,9 +213,12 @@ func TestImportTemporalSchema(t *testing.T) {
                {arrow.FixedWidthTypes.Timestamp_s, "tss:UTC"},
                {&arrow.TimestampType{Unit: arrow.Second}, "tss:"},
                {&arrow.TimestampType{Unit: arrow.Second, TimeZone: 
"Europe/Paris"}, "tss:Europe/Paris"},
+               {&arrow.TimestampType{Unit: arrow.Second, TimeZone: 
"Etc/GMT+1"}, "tss:Etc/GMT+1"},
+               {&arrow.TimestampType{Unit: arrow.Second, TimeZone: "+01:00"}, 
"tss:+01:00"},
                {arrow.FixedWidthTypes.Timestamp_ms, "tsm:UTC"},
                {&arrow.TimestampType{Unit: arrow.Millisecond}, "tsm:"},
                {&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: 
"Europe/Paris"}, "tsm:Europe/Paris"},
+               {&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: 
"-07:30"}, "tsm:-07:30"},
                {arrow.FixedWidthTypes.Timestamp_us, "tsu:UTC"},
                {&arrow.TimestampType{Unit: arrow.Microsecond}, "tsu:"},
                {&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: 
"Europe/Paris"}, "tsu:Europe/Paris"},
@@ -207,7 +228,7 @@ func TestImportTemporalSchema(t *testing.T) {
        }
 
        for _, tt := range tests {
-               t.Run(tt.typ.Name(), func(t *testing.T) {
+               t.Run(tt.fmt, func(t *testing.T) {
                        sc := testPrimitive(tt.fmt)
 
                        f, err := ImportCArrowField(&sc)
diff --git a/arrow/flight/basic_auth_flight_test.go 
b/arrow/flight/basic_auth_flight_test.go
index 849b25d8..744ea29f 100644
--- a/arrow/flight/basic_auth_flight_test.go
+++ b/arrow/flight/basic_auth_flight_test.go
@@ -206,3 +206,46 @@ func TestBasicAuthHelpers(t *testing.T) {
                t.Fatal("should have received carebears")
        }
 }
+
+func TestBasicAuthMissingCredential(t *testing.T) {
+       s := 
flight.NewServerWithMiddleware([]flight.ServerMiddleware{flight.CreateServerBasicAuthMiddleware(&validator{})})
+       s.Init("localhost:0")
+       f := &HeaderAuthTestFlight{}
+       s.RegisterFlightService(f)
+       go s.Serve()
+       defer s.Shutdown()
+
+       client, err := flight.NewFlightClient(s.Addr().String(), nil, 
grpc.WithTransportCredentials(insecure.NewCredentials()))
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       ctx := metadata.NewOutgoingContext(context.Background(), 
metadata.New(map[string]string{
+               "authorization": "Basic",
+       }))
+
+       fc, err := client.Handshake(ctx)
+       if err != nil {
+               st, ok := status.FromError(err)
+               if !ok {
+                       t.Fatalf("expected gRPC status error, got %T: %v", err, 
err)
+               }
+               if got, want := st.Code(), codes.Unauthenticated; got != want {
+                       t.Fatalf("unexpected code: got %v, want %v", got, want)
+               }
+               return
+       }
+
+       _, err = fc.Recv()
+       if err == nil {
+               t.Fatal("expected error")
+       }
+
+       st, ok := status.FromError(err)
+       if !ok {
+               t.Fatalf("expected gRPC status error, got %T: %v", err, err)
+       }
+       if got, want := st.Code(), codes.Unauthenticated; got != want {
+               t.Fatalf("unexpected code: got %v, want %v", got, want)
+       }
+}
diff --git a/arrow/flight/server_auth.go b/arrow/flight/server_auth.go
index cc78d85a..7135e43a 100644
--- a/arrow/flight/server_auth.go
+++ b/arrow/flight/server_auth.go
@@ -170,24 +170,26 @@ func createServerBearerTokenUnaryInterceptor(validator 
BasicAuthValidator) grpc.
 
 func createServerBearerTokenStreamInterceptor(validator BasicAuthValidator) 
grpc.StreamServerInterceptor {
        return func(srv interface{}, stream grpc.ServerStream, info 
*grpc.StreamServerInfo, handler grpc.StreamHandler) error {
-               var auth []string
+               var scheme, credential string
                md, ok := metadata.FromIncomingContext(stream.Context())
                if ok {
-                       auth = md.Get(basicAuthHeader)
+                       auth := md.Get(basicAuthHeader)
                        if len(auth) > 0 {
-                               auth = strings.Split(auth[0], " ")
+                               s := strings.TrimSpace(auth[0])
+                               scheme, credential, _ = strings.Cut(s, " ")
+                               credential = strings.TrimLeft(credential, " ") 
// only trim SP per HTTP auth format, keep trailing spaces.
                        }
                }
 
-               if len(auth) == 0 {
+               if scheme == "" || credential == "" {
                        return status.Error(codes.Unauthenticated, "must 
authenticate first")
                }
 
                if strings.HasSuffix(info.FullMethod, "/Handshake") {
-                       if auth[0] == basicAuthPrefix {
-                               val, err := 
base64.RawStdEncoding.DecodeString(auth[1])
+                       if scheme == basicAuthPrefix {
+                               val, err := 
base64.RawStdEncoding.DecodeString(credential)
                                if err != nil {
-                                       val, err = 
base64.StdEncoding.DecodeString(auth[1])
+                                       val, err = 
base64.StdEncoding.DecodeString(credential)
                                        if err != nil {
                                                return 
status.Errorf(codes.Unauthenticated, "invalid basic auth encoding: %s", err)
                                        }
@@ -199,14 +201,16 @@ func createServerBearerTokenStreamInterceptor(validator 
BasicAuthValidator) grpc
                                        return err
                                }
 
-                               
stream.SetTrailer(metadata.New(map[string]string{basicAuthHeader: 
strings.Join([]string{bearerTokenPrefix, token}, " ")}))
+                               
stream.SetTrailer(metadata.New(map[string]string{
+                                       basicAuthHeader: bearerTokenPrefix + " 
" + token,
+                               }))
                                return handler(srv, stream)
                        }
                        return status.Errorf(codes.Unauthenticated, "only Basic 
Auth implemented")
                }
 
-               if auth[0] == bearerTokenPrefix {
-                       identity, err := validator.IsValid(auth[1])
+               if scheme == bearerTokenPrefix {
+                       identity, err := validator.IsValid(credential)
                        if err != nil {
                                return err
                        }

Reply via email to