This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new cdc768ed fix(arrow/flight): deliver response headers eagerly in
streaming client middleware (#801)
cdc768ed is described below
commit cdc768edeb2f0a4017c0d27efc27efbb1f278214
Author: Matt Topol <[email protected]>
AuthorDate: Wed May 6 15:26:48 2026 -0400
fix(arrow/flight): deliver response headers eagerly in streaming client
middleware (#801)
### Rationale for this change
Fixes #755. The cookie middleware (`NewClientCookieMiddleware`) does not
capture `Set-Cookie` headers returned in response to a streaming RPC
like `Handshake` when the server also sends back a response payload.
`ClientHeadersMiddleware.HeadersReceived` was only invoked from
`finishFn`, which fires when `Recv()` returns `io.EOF` or a non-`io.EOF`
error. `AuthenticateBasicToken` calls `Recv()` exactly once; if the
server sends a `HandshakeResponse` payload (common when the Handshake
carries auth data or a session cookie), `Recv()` returns that message
rather than `io.EOF` and `finishFn` never fires. The cookie middleware
never sees the response headers, so the session cookie is dropped and
subsequent RPCs go out without it. This is also why the reporter still
sees cookies delivered on other endpoints like `GetFlightInfo`: unary
RPCs capture headers synchronously via `grpc.Header(&md)`.
### What changes are included in this PR?
- `clientStream.Header()` now delivers response metadata to
`ClientHeadersMiddleware` at most once (guarded by
`atomic.Bool.CompareAndSwap`), the first time headers are successfully
retrieved for a streaming RPC.
- The existing `finishFn` path is unchanged so:
- trailers are still captured when the stream completes, and
- callers that never explicitly invoke `Header()` get the exact same
behavior as before.
- Added four regression tests in `arrow/flight/handshake_cookie_test.go`
covering:
1. `Set-Cookie` in Handshake response **headers** (via
`AuthenticateBasicToken`)
2. `Set-Cookie` in Handshake response **trailers**
3. `Set-Cookie` + server-sent `HandshakeResponse` payload (the precise
scenario reported in #755 — fails without this fix)
4. Eager capture when `stream.Header()` is inspected before draining the
stream (also fails without this fix)
### Are these changes tested?
Yes. The four new tests in `arrow/flight/handshake_cookie_test.go`
reproduce the regression. Tests 3 and 4 fail without the fix and pass
with it. The existing middleware/cookie tests continue to pass,
including with `-race`.
### Are there any user-facing changes?
Minor behavioral refinement of `ClientHeadersMiddleware` for streaming
RPCs: `HeadersReceived` may now be invoked up to twice on a streaming
RPC on which `stream.Header()` is called (directly, or indirectly
through a helper such as `AuthenticateBasicToken`) — once with just the
response headers (from `Header()`), and again with headers+trailers
joined (from the existing `finishFn` path at stream completion). This is
backward compatible for the in-tree `clientCookieMiddleware` (cookie
updates are keyed by `name+path` and idempotent). Streams on which
`Header()` is never called see no change in behavior.
---
arrow/flight/client.go | 25 +++
arrow/flight/flightsql/driver/driver_test.go | 5 +
arrow/flight/handshake_cookie_test.go | 324 +++++++++++++++++++++++++++
3 files changed, 354 insertions(+)
diff --git a/arrow/flight/client.go b/arrow/flight/client.go
index 96eb0e6b..92699a59 100644
--- a/arrow/flight/client.go
+++ b/arrow/flight/client.go
@@ -175,6 +175,19 @@ func CreateClientMiddleware(middleware
CustomClientMiddleware) ClientMiddleware
desc: desc,
finishFn: finishFunc,
}
+ if isHdrs {
+ // Deliver response headers to the middleware
as soon as they
+ // are first retrieved via Header(), rather
than waiting for
+ // the stream to finish. This is necessary for
streaming RPCs
+ // like Handshake where the caller may inspect
headers (e.g.
+ // Set-Cookie) and issue subsequent RPCs before
the stream
+ // reaches io.EOF (e.g. when the server sends a
response
+ // payload that causes Recv to return a message
instead of
+ // EOF). See GH-755.
+ newCS.onHeaders = func(md metadata.MD) {
+ hdrs.HeadersReceived(csCtx, md)
+ }
+ }
// The `ClientStream` interface allows one to omit
calling `Recv` if it's
// known that the result will be `io.EOF`. See
// http://stackoverflow.com/q/42915337
@@ -193,12 +206,24 @@ type clientStream struct {
grpc.ClientStream
desc *grpc.StreamDesc
finishFn func(error)
+
+ // onHeaders, when non-nil, is invoked at most once with the response
+ // metadata the first time Header() returns successfully. It allows
+ // middleware (e.g. cookie middleware) to observe server headers as
+ // soon as they arrive on streaming RPCs, rather than waiting for the
+ // stream to finish via finishFn. See GH-755.
+ onHeaders func(md metadata.MD)
+ headersObserved atomic.Bool
}
func (cs *clientStream) Header() (metadata.MD, error) {
md, err := cs.ClientStream.Header()
if err != nil {
cs.finishFn(err)
+ return md, err
+ }
+ if cs.onHeaders != nil && cs.headersObserved.CompareAndSwap(false,
true) {
+ cs.onHeaders(md)
}
return md, err
}
diff --git a/arrow/flight/flightsql/driver/driver_test.go
b/arrow/flight/flightsql/driver/driver_test.go
index 39d9dfd9..82c66753 100644
--- a/arrow/flight/flightsql/driver/driver_test.go
+++ b/arrow/flight/flightsql/driver/driver_test.go
@@ -1819,6 +1819,11 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx
context.Context, qry flight
if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
return nil, errors.New("parameter schema: unexpected")
}
+ // See GH-35328: drain remaining batches before returning to
avoid
+ // the io.EOF race between server close and client Write. The
other
+ // success path below already does this; this branch must too.
+ for r.Next() {
+ }
return qry.GetPreparedStatementHandle(), nil
}
diff --git a/arrow/flight/handshake_cookie_test.go
b/arrow/flight/handshake_cookie_test.go
new file mode 100644
index 00000000..c5c8464e
--- /dev/null
+++ b/arrow/flight/handshake_cookie_test.go
@@ -0,0 +1,324 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight_test
+
+import (
+ "context"
+ "encoding/base64"
+ "errors"
+ "io"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/apache/arrow-go/v18/arrow/flight"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/metadata"
+)
+
+// handshakeCookieFlightServer is a flight server that emits Set-Cookie
+// response headers (and trailers) during Handshake, simulating a server
+// that creates a session during the authentication flow (see GH-755).
+type handshakeCookieFlightServer struct {
+ flight.BaseFlightServer
+
+ headerCookie string // cookie attached via SendHeader during
Handshake
+ trailerCookie string // cookie attached via SetTrailer during
Handshake
+ bearerToken string // authorization header returned during
Handshake
+ sendPayload bool // if true, server sends a HandshakeResponse
payload before closing
+ mu sync.Mutex
+ lastIncomingCook []string // incoming Cookie header values observed on
ListFlights
+}
+
+func (h *handshakeCookieFlightServer) Handshake(stream
flight.FlightService_HandshakeServer) error {
+ md := metadata.MD{}
+ if h.headerCookie != "" {
+ md.Append("set-cookie", h.headerCookie)
+ }
+ if h.bearerToken != "" {
+ md.Append("authorization", "Bearer "+h.bearerToken)
+ }
+ if len(md) > 0 {
+ if err := stream.SendHeader(md); err != nil {
+ return err
+ }
+ }
+
+ if h.trailerCookie != "" {
+ stream.SetTrailer(metadata.Pairs("set-cookie", h.trailerCookie))
+ }
+
+ if h.sendPayload {
+ if err := stream.Send(&flight.HandshakeResponse{Payload:
[]byte("handshake-ok")}); err != nil {
+ return err
+ }
+ }
+
+ // Drain the client stream until it closes.
+ for {
+ if _, err := stream.Recv(); err != nil {
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+ return err
+ }
+ }
+}
+
+func (h *handshakeCookieFlightServer) ListFlights(c *flight.Criteria, fs
flight.FlightService_ListFlightsServer) error {
+ h.mu.Lock()
+ if md, ok := metadata.FromIncomingContext(fs.Context()); ok {
+ h.lastIncomingCook = append([]string(nil), md.Get("cookie")...)
+ } else {
+ h.lastIncomingCook = nil
+ }
+ h.mu.Unlock()
+ return nil
+}
+
+func (h *handshakeCookieFlightServer) observedCookies() []string {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ return append([]string(nil), h.lastIncomingCook...)
+}
+
+// TestHandshakeCookiePropagationViaAuthenticateBasicToken is a regression
+// test for GH-755. It asserts that Set-Cookie headers returned by a
+// Handshake/DoHandshake response are captured by the cookie middleware
+// and attached to subsequent requests.
+func TestHandshakeCookiePropagationViaAuthenticateBasicToken(t *testing.T) {
+ srv := &handshakeCookieFlightServer{
+ headerCookie: "session_id=sess_header_abc",
+ bearerToken: "my-bearer-token",
+ }
+
+ s := flight.NewServerWithMiddleware(nil)
+ s.Init("localhost:0")
+ s.RegisterFlightService(srv)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+ client, err := flight.NewClientWithMiddleware(
+ s.Addr().String(),
+ nil,
+ []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+ creds,
+ )
+ require.NoError(t, err)
+ defer client.Close()
+
+ ctx, err := client.AuthenticateBasicToken(context.Background(), "user",
"pass")
+ require.NoError(t, err)
+
+ // Make a follow-up RPC. The cookie middleware must have captured
+ // Set-Cookie from the Handshake response, and StartCall should
+ // attach it as a Cookie header on this call.
+ stream, err := client.ListFlights(ctx, &flight.Criteria{})
+ require.NoError(t, err)
+ for {
+ if _, err := stream.Recv(); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ }
+ }
+
+ cookies := srv.observedCookies()
+ require.Len(t, cookies, 1, "expected exactly one Cookie header, got
%v", cookies)
+ assert.Contains(t, cookies[0], "session_id=sess_header_abc",
+ "cookie middleware should propagate Set-Cookie from Handshake
response headers")
+}
+
+// TestHandshakeCookiePropagationFromTrailers ensures cookies delivered as
+// gRPC trailers (instead of initial metadata headers) are also captured
+// by the cookie middleware during Handshake.
+func TestHandshakeCookiePropagationFromTrailers(t *testing.T) {
+ srv := &handshakeCookieFlightServer{
+ trailerCookie: "session_id=sess_trailer_xyz",
+ bearerToken: "my-bearer-token",
+ }
+
+ s := flight.NewServerWithMiddleware(nil)
+ s.Init("localhost:0")
+ s.RegisterFlightService(srv)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+ client, err := flight.NewClientWithMiddleware(
+ s.Addr().String(),
+ nil,
+ []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+ creds,
+ )
+ require.NoError(t, err)
+ defer client.Close()
+
+ ctx, err := client.AuthenticateBasicToken(context.Background(), "user",
"pass")
+ require.NoError(t, err)
+
+ stream, err := client.ListFlights(ctx, &flight.Criteria{})
+ require.NoError(t, err)
+ for {
+ if _, err := stream.Recv(); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ }
+ }
+
+ cookies := srv.observedCookies()
+ require.Len(t, cookies, 1, "expected exactly one Cookie header, got
%v", cookies)
+ assert.Contains(t, cookies[0], "session_id=sess_trailer_xyz",
+ "cookie middleware should propagate Set-Cookie from Handshake
response trailers")
+}
+
+// TestHandshakeCookiePropagationWithServerPayload is the precise scenario
+// reported in GH-755. The server attaches a Set-Cookie header AND sends
+// back a HandshakeResponse payload. AuthenticateBasicToken only calls
+// stream.Recv() once, which returns the payload (not io.EOF), so the
+// streaming finishFn that would normally invoke HeadersReceived never
+// fires. The cookie middleware must still capture the header cookie.
+func TestHandshakeCookiePropagationWithServerPayload(t *testing.T) {
+ srv := &handshakeCookieFlightServer{
+ headerCookie: "session_id=sess_with_payload",
+ bearerToken: "my-bearer-token",
+ sendPayload: true,
+ }
+
+ s := flight.NewServerWithMiddleware(nil)
+ s.Init("localhost:0")
+ s.RegisterFlightService(srv)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+ client, err := flight.NewClientWithMiddleware(
+ s.Addr().String(),
+ nil,
+ []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+ creds,
+ )
+ require.NoError(t, err)
+ defer client.Close()
+
+ ctx, err := client.AuthenticateBasicToken(context.Background(), "user",
"pass")
+ require.NoError(t, err)
+
+ stream, err := client.ListFlights(ctx, &flight.Criteria{})
+ require.NoError(t, err)
+ for {
+ if _, err := stream.Recv(); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ }
+ }
+
+ cookies := srv.observedCookies()
+ require.Len(t, cookies, 1,
+ "expected exactly one Cookie header, got %v (GH-755: cookie
lost when Handshake returns a payload)", cookies)
+ assert.Contains(t, cookies[0], "session_id=sess_with_payload")
+}
+
+// TestHandshakeCookieProcessedBeforeRecv verifies cookies are captured
+// eagerly once stream.Header() returns successfully. This models the
+// scenario where an application-level Handshake flow inspects response
+// headers and makes further RPCs before draining the stream.
+func TestHandshakeCookieProcessedBeforeRecv(t *testing.T) {
+ srv := &handshakeCookieFlightServer{
+ headerCookie: "session_id=eager_capture",
+ }
+
+ s := flight.NewServerWithMiddleware(nil)
+ s.Init("localhost:0")
+ s.RegisterFlightService(srv)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ cookies := flight.NewCookieMiddleware()
+ creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+ client, err := flight.NewClientWithMiddleware(
+ s.Addr().String(),
+ nil,
+
[]flight.ClientMiddleware{flight.CreateClientMiddleware(cookies)},
+ creds,
+ )
+ require.NoError(t, err)
+ defer client.Close()
+
+ // Drive the Handshake manually; inspect headers before calling Recv().
+ authCtx := metadata.AppendToOutgoingContext(context.Background(),
+ "Authorization", "Basic
"+base64.RawStdEncoding.EncodeToString([]byte("user:pass")))
+
+ stream, err := client.Handshake(authCtx)
+ require.NoError(t, err)
+ require.NoError(t, stream.CloseSend())
+
+ hdr, err := stream.Header()
+ require.NoError(t, err)
+ require.Contains(t, strings.Join(hdr.Get("set-cookie"), ","),
"eager_capture")
+
+ // Clone the middleware while the original Handshake stream is still
+ // open. If cookies were processed eagerly from the header, the clone
+ // should already contain the session cookie.
+ cloned := cookies.Clone()
+
+ // Using the clone, make a unary-ish request against a second client
+ // to observe the outgoing Cookie header.
+ clientB, err := flight.NewClientWithMiddleware(
+ s.Addr().String(),
+ nil,
+
[]flight.ClientMiddleware{flight.CreateClientMiddleware(cloned)},
+ creds,
+ )
+ require.NoError(t, err)
+ defer clientB.Close()
+
+ ls, err := clientB.ListFlights(context.Background(), &flight.Criteria{})
+ require.NoError(t, err)
+ for {
+ if _, err := ls.Recv(); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ }
+ }
+
+ got := srv.observedCookies()
+ require.Len(t, got, 1, "expected cloned middleware to send cookie from
eagerly captured Handshake header, got %v", got)
+ assert.Contains(t, got[0], "session_id=eager_capture")
+
+ // Clean up original stream.
+ for {
+ if _, err := stream.Recv(); err != nil {
+ break
+ }
+ }
+}