This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new 6648c1d2 fix(arrow/cdata, arrow/flight): fix handling of colons in
values and fix potential panics (#761)
6648c1d2 is described below
commit 6648c1d202bd3d844d532b9167ad292b093a05e6
Author: Sebastiaan van Stijn <[email protected]>
AuthorDate: Mon Apr 13 03:10:40 2026 +0200
fix(arrow/cdata, arrow/flight): fix handling of colons in values and fix
potential panics (#761)
### fix(arrow/cdata): importSchema: handle colons in values
Use strings.Cut, both as an optimization, and to prevent values
containing
a colon (e.g. "tsu:+01:00") from being mis-interpreted. This patch also
removes some intermediate variables, and redundant handling of
"defaulttz",
which assigned an empty string if the value was empty.
### fix(arrow/cdata): importSchema: fix potential panic and optimize
Rewrite the code with strings.Cut and strings.SplitSeq to reduce
allocations, and to fix a potential panic.
Before this patch, the code would panic if a colon was missing;
CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors
./arrow/cdata/
--- FAIL: TestUnionSchemaErrors (0.00s)
--- FAIL: TestUnionSchemaErrors/+us (0.00s)
panic: runtime error: index out of range [1] with length 1 [recovered,
repanicked]
goroutine 9 [running]:
testing.tRunner.func1.2({0x7fc7c0, 0x4000026ab0})
/usr/local/go/src/testing/testing.go:1872 +0x190
testing.tRunner.func1()
/usr/local/go/src/testing/testing.go:1875 +0x31c
panic({0x7fc7c0?, 0x4000026ab0?})
/usr/local/go/src/runtime/panic.go:783 +0x120
github.com/apache/arrow-go/v18/arrow/cdata.importSchema(0x40001c36d0)
/foo/arrow/cdata/cdata.go:306 +0x1520
github.com/apache/arrow-go/v18/arrow/cdata.ImportCArrowField(...)
/foo/arrow/cdata/interface.go:43
github.com/apache/arrow-go/v18/arrow/cdata.TestUnionSchemaErrors.func1(0x40000e0a80)
/foo/arrow/cdata/cdata_test.go:188 +0xb0
testing.tRunner(0x40000e0a80, 0x400020c060)
/usr/local/go/src/testing/testing.go:1934 +0xc8
created by testing.(*T).Run in goroutine 8
/usr/local/go/src/testing/testing.go:1997 +0x364
FAIL github.com/apache/arrow-go/v18/arrow/cdata 0.007s
FAIL
With this patch applied, the code handles the invalid value gracefully;
CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors
./arrow/cdata/
=== RUN TestUnionSchemaErrors
=== RUN TestUnionSchemaErrors/+us
=== RUN TestUnionSchemaErrors/+ud
--- PASS: TestUnionSchemaErrors (0.00s)
--- PASS: TestUnionSchemaErrors/+us (0.00s)
--- PASS: TestUnionSchemaErrors/+ud (0.00s)
PASS
ok github.com/apache/arrow-go/v18/arrow/cdata 0.003s
### fix(arrow/flight): avoid panic on malformed authorization header
Rewrite the code with strings.Cut for readability and ensure missing
credentials
in Basic/Bearer authorization headers return Unauthenticated instead of
panicking.
Before this patch, the code could panic;
go test -run TestBasicAuthMissingCredential ./arrow/flight/
panic: runtime error: index out of range [1] with length 1
goroutine 7 [running]:
github.com/apache/arrow-go/v18/arrow/flight_test.TestBasicAuthMissingCredential.CreateServerBasicAuthMiddleware.createServerBearerTokenStreamInterceptor.func3({0x8d8240,
0x40002134a0}, {0xa73e68, 0x40000e2000}, 0x40000100c0, 0x96b628)
/foo/arrow/flight/server_auth.go:188 +0x49c
....
With this patch applied, the code handles the invalid header gracefully;
go test -run TestBasicAuthMissingCredential ./arrow/flight/
ok github.com/apache/arrow-go/v18/arrow/flight 0.010s
### Rationale for this change
### What changes are included in this PR?
### Are these changes tested?
### Are there any user-facing changes?
---------
Signed-off-by: Sebastiaan van Stijn <[email protected]>
---
arrow/cdata/cdata.go | 41 ++++++++++----------------------
arrow/cdata/cdata_test.go | 23 +++++++++++++++++-
arrow/flight/basic_auth_flight_test.go | 43 ++++++++++++++++++++++++++++++++++
arrow/flight/server_auth.go | 24 +++++++++++--------
4 files changed, 92 insertions(+), 39 deletions(-)
diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go
index 63419469..d8eeb5b2 100644
--- a/arrow/cdata/cdata.go
+++ b/arrow/cdata/cdata.go
@@ -202,41 +202,23 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field,
err error) {
}
// handle types with params via colon
- typs := strings.Split(f, ":")
- defaulttz := ""
- switch typs[0] {
+ switch key, val, _ := strings.Cut(f, ":"); key {
case "tss":
- tz := typs[1]
- if len(typs[1]) == 0 {
- tz = defaulttz
- }
- dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: tz}
+ dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: val}
case "tsm":
- tz := typs[1]
- if len(typs[1]) == 0 {
- tz = defaulttz
- }
- dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: tz}
+ dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone:
val}
case "tsu":
- tz := typs[1]
- if len(typs[1]) == 0 {
- tz = defaulttz
- }
- dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: tz}
+ dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone:
val}
case "tsn":
- tz := typs[1]
- if len(typs[1]) == 0 {
- tz = defaulttz
- }
- dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: tz}
+ dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: val}
case "w": // fixed size binary is "w:##" where ## is the byteWidth
- byteWidth, err := strconv.Atoi(typs[1])
+ byteWidth, err := strconv.Atoi(val)
if err != nil {
return ret, err
}
dt = &arrow.FixedSizeBinaryType{ByteWidth: byteWidth}
case "d": // decimal types are d:<precision>,<scale>[,<bitsize>] size
is assumed 128 if left out
- props := typs[1]
+ props := val
propList := strings.Split(props, ",")
bitwidth := 128
var precision, scale int
@@ -317,9 +299,12 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field,
err error) {
return
}
- codes := strings.Split(strings.Split(f, ":")[1], ",")
- typeCodes := make([]arrow.UnionTypeCode, 0, len(codes))
- for _, i := range codes {
+ _, val, ok := strings.Cut(f, ":")
+ if !ok {
+ return ret, fmt.Errorf("invalid union type code
spec %q", f)
+ }
+ var typeCodes []arrow.UnionTypeCode
+ for i := range strings.SplitSeq(val, ",") {
v, e := strconv.ParseInt(i, 10, 8)
if e != nil {
err = fmt.Errorf("%w: invalid type
code: %s", arrow.ErrInvalid, e)
diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go
index 8fa690f2..c196b215 100644
--- a/arrow/cdata/cdata_test.go
+++ b/arrow/cdata/cdata_test.go
@@ -174,6 +174,24 @@ func TestDecimalSchemaErrors(t *testing.T) {
}
}
+func TestUnionSchemaErrors(t *testing.T) {
+ tests := []struct {
+ fmt string
+ }{
+ {"+us"}, // missing ":<type_codes>"
+ {"+ud"}, // missing ":<type_codes>"
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.fmt, func(t *testing.T) {
+ sc := testPrimitive(tt.fmt)
+
+ _, err := ImportCArrowField(&sc)
+ assert.Error(t, err)
+ })
+ }
+}
+
func TestImportTemporalSchema(t *testing.T) {
tests := []struct {
typ arrow.DataType
@@ -195,9 +213,12 @@ func TestImportTemporalSchema(t *testing.T) {
{arrow.FixedWidthTypes.Timestamp_s, "tss:UTC"},
{&arrow.TimestampType{Unit: arrow.Second}, "tss:"},
{&arrow.TimestampType{Unit: arrow.Second, TimeZone:
"Europe/Paris"}, "tss:Europe/Paris"},
+ {&arrow.TimestampType{Unit: arrow.Second, TimeZone:
"Etc/GMT+1"}, "tss:Etc/GMT+1"},
+ {&arrow.TimestampType{Unit: arrow.Second, TimeZone: "+01:00"},
"tss:+01:00"},
{arrow.FixedWidthTypes.Timestamp_ms, "tsm:UTC"},
{&arrow.TimestampType{Unit: arrow.Millisecond}, "tsm:"},
{&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone:
"Europe/Paris"}, "tsm:Europe/Paris"},
+ {&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone:
"-07:30"}, "tsm:-07:30"},
{arrow.FixedWidthTypes.Timestamp_us, "tsu:UTC"},
{&arrow.TimestampType{Unit: arrow.Microsecond}, "tsu:"},
{&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone:
"Europe/Paris"}, "tsu:Europe/Paris"},
@@ -207,7 +228,7 @@ func TestImportTemporalSchema(t *testing.T) {
}
for _, tt := range tests {
- t.Run(tt.typ.Name(), func(t *testing.T) {
+ t.Run(tt.fmt, func(t *testing.T) {
sc := testPrimitive(tt.fmt)
f, err := ImportCArrowField(&sc)
diff --git a/arrow/flight/basic_auth_flight_test.go
b/arrow/flight/basic_auth_flight_test.go
index 849b25d8..744ea29f 100644
--- a/arrow/flight/basic_auth_flight_test.go
+++ b/arrow/flight/basic_auth_flight_test.go
@@ -206,3 +206,46 @@ func TestBasicAuthHelpers(t *testing.T) {
t.Fatal("should have received carebears")
}
}
+
+func TestBasicAuthMissingCredential(t *testing.T) {
+ s :=
flight.NewServerWithMiddleware([]flight.ServerMiddleware{flight.CreateServerBasicAuthMiddleware(&validator{})})
+ s.Init("localhost:0")
+ f := &HeaderAuthTestFlight{}
+ s.RegisterFlightService(f)
+ go s.Serve()
+ defer s.Shutdown()
+
+ client, err := flight.NewFlightClient(s.Addr().String(), nil,
grpc.WithTransportCredentials(insecure.NewCredentials()))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ctx := metadata.NewOutgoingContext(context.Background(),
metadata.New(map[string]string{
+ "authorization": "Basic",
+ }))
+
+ fc, err := client.Handshake(ctx)
+ if err != nil {
+ st, ok := status.FromError(err)
+ if !ok {
+ t.Fatalf("expected gRPC status error, got %T: %v", err,
err)
+ }
+ if got, want := st.Code(), codes.Unauthenticated; got != want {
+ t.Fatalf("unexpected code: got %v, want %v", got, want)
+ }
+ return
+ }
+
+ _, err = fc.Recv()
+ if err == nil {
+ t.Fatal("expected error")
+ }
+
+ st, ok := status.FromError(err)
+ if !ok {
+ t.Fatalf("expected gRPC status error, got %T: %v", err, err)
+ }
+ if got, want := st.Code(), codes.Unauthenticated; got != want {
+ t.Fatalf("unexpected code: got %v, want %v", got, want)
+ }
+}
diff --git a/arrow/flight/server_auth.go b/arrow/flight/server_auth.go
index cc78d85a..7135e43a 100644
--- a/arrow/flight/server_auth.go
+++ b/arrow/flight/server_auth.go
@@ -170,24 +170,26 @@ func createServerBearerTokenUnaryInterceptor(validator
BasicAuthValidator) grpc.
func createServerBearerTokenStreamInterceptor(validator BasicAuthValidator)
grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info
*grpc.StreamServerInfo, handler grpc.StreamHandler) error {
- var auth []string
+ var scheme, credential string
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
- auth = md.Get(basicAuthHeader)
+ auth := md.Get(basicAuthHeader)
if len(auth) > 0 {
- auth = strings.Split(auth[0], " ")
+ s := strings.TrimSpace(auth[0])
+ scheme, credential, _ = strings.Cut(s, " ")
+ credential = strings.TrimLeft(credential, " ")
// only trim SP per HTTP auth format, keep trailing spaces.
}
}
- if len(auth) == 0 {
+ if scheme == "" || credential == "" {
return status.Error(codes.Unauthenticated, "must
authenticate first")
}
if strings.HasSuffix(info.FullMethod, "/Handshake") {
- if auth[0] == basicAuthPrefix {
- val, err :=
base64.RawStdEncoding.DecodeString(auth[1])
+ if scheme == basicAuthPrefix {
+ val, err :=
base64.RawStdEncoding.DecodeString(credential)
if err != nil {
- val, err =
base64.StdEncoding.DecodeString(auth[1])
+ val, err =
base64.StdEncoding.DecodeString(credential)
if err != nil {
return
status.Errorf(codes.Unauthenticated, "invalid basic auth encoding: %s", err)
}
@@ -199,14 +201,16 @@ func createServerBearerTokenStreamInterceptor(validator
BasicAuthValidator) grpc
return err
}
-
stream.SetTrailer(metadata.New(map[string]string{basicAuthHeader:
strings.Join([]string{bearerTokenPrefix, token}, " ")}))
+
stream.SetTrailer(metadata.New(map[string]string{
+ basicAuthHeader: bearerTokenPrefix + "
" + token,
+ }))
return handler(srv, stream)
}
return status.Errorf(codes.Unauthenticated, "only Basic
Auth implemented")
}
- if auth[0] == bearerTokenPrefix {
- identity, err := validator.IsValid(auth[1])
+ if scheme == bearerTokenPrefix {
+ identity, err := validator.IsValid(credential)
if err != nil {
return err
}