This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new cdc768ed fix(arrow/flight): deliver response headers eagerly in streaming client middleware (#801)
cdc768ed is described below

commit cdc768edeb2f0a4017c0d27efc27efbb1f278214
Author: Matt Topol <[email protected]>
AuthorDate: Wed May 6 15:26:48 2026 -0400

    fix(arrow/flight): deliver response headers eagerly in streaming client middleware (#801)
    
    ### Rationale for this change
    
    Fixes #755. The cookie middleware (`NewClientCookieMiddleware`) does not
    capture `Set-Cookie` headers returned in response to a streaming RPC
    like `Handshake` when the server also sends back a response payload.
    
    `ClientHeadersMiddleware.HeadersReceived` was only invoked from
    `finishFn`, which fires when `Recv()` returns `io.EOF` or a non-`io.EOF`
    error. `AuthenticateBasicToken` calls `Recv()` exactly once; if the
    server sends a `HandshakeResponse` payload (common when the Handshake
    carries auth data or a session cookie), `Recv()` returns that message
    rather than `io.EOF` and `finishFn` never fires. The cookie middleware
    never sees the response headers, so the session cookie is dropped and
    subsequent RPCs go out without it, even though the user reports cookies
    ARE delivered on other endpoints like `GetFlightInfo` (unary RPCs
    capture headers synchronously via `grpc.Header(&md)`).
    
    ### What changes are included in this PR?
    
    - `clientStream.Header()` now delivers response metadata to
      `ClientHeadersMiddleware` at most once (guarded by
      `atomic.Bool.CompareAndSwap`) the first time headers are successfully
      retrieved for a streaming RPC.
    - The existing `finishFn` path is unchanged, so:
      - trailers are still captured when the stream completes, and
      - callers that never explicitly invoke `Header()` get the exact same
        behavior as before.
    - Added four regression tests in `arrow/flight/handshake_cookie_test.go`
      covering:
      1. `Set-Cookie` in Handshake response **headers** (via
         `AuthenticateBasicToken`)
      2. `Set-Cookie` in Handshake response **trailers**
      3. `Set-Cookie` + server-sent `HandshakeResponse` payload (the precise
         scenario reported in #755 — fails without this fix)
      4. Eager capture when `stream.Header()` is inspected before draining the
         stream (also fails without this fix)
    
    ### Are these changes tested?
    
    Yes. The four new tests in `arrow/flight/handshake_cookie_test.go`
    reproduce the regression. Tests 3 and 4 fail without the fix and pass
    with it. The existing middleware/cookie tests continue to pass,
    including with `-race`.
    
    ### Are there any user-facing changes?
    
    Minor behavioral refinement of `ClientHeadersMiddleware` for streaming
    RPCs: `HeadersReceived` may now be invoked up to twice on a streaming
    RPC whose caller explicitly calls `stream.Header()` — once with just the
    response headers (from `Header()`), and again with headers+trailers
    joined (from the existing `finishFn` path at stream completion). This is
    backward compatible for the in-tree `clientCookieMiddleware` (cookie
    updates are keyed by `name+path` and idempotent). Callers that never
    explicitly call `stream.Header()` see no change in behavior.
---
 arrow/flight/client.go                       |  25 +++
 arrow/flight/flightsql/driver/driver_test.go |   5 +
 arrow/flight/handshake_cookie_test.go        | 324 +++++++++++++++++++++++++++
 3 files changed, 354 insertions(+)

diff --git a/arrow/flight/client.go b/arrow/flight/client.go
index 96eb0e6b..92699a59 100644
--- a/arrow/flight/client.go
+++ b/arrow/flight/client.go
@@ -175,6 +175,19 @@ func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware
                                desc:         desc,
                                finishFn:     finishFunc,
                        }
+                       if isHdrs {
+                               // Deliver response headers to the middleware 
as soon as they
+                               // are first retrieved via Header(), rather 
than waiting for
+                               // the stream to finish. This is necessary for 
streaming RPCs
+                               // like Handshake where the caller may inspect 
headers (e.g.
+                               // Set-Cookie) and issue subsequent RPCs before 
the stream
+                               // reaches io.EOF (e.g. when the server sends a 
response
+                               // payload that causes Recv to return a message 
instead of
+                               // EOF). See GH-755.
+                               newCS.onHeaders = func(md metadata.MD) {
+                                       hdrs.HeadersReceived(csCtx, md)
+                               }
+                       }
                        // The `ClientStream` interface allows one to omit 
calling `Recv` if it's
                        // known that the result will be `io.EOF`. See
                        // http://stackoverflow.com/q/42915337
@@ -193,12 +206,24 @@ type clientStream struct {
        grpc.ClientStream
        desc     *grpc.StreamDesc
        finishFn func(error)
+
+       // onHeaders, when non-nil, is invoked at most once (guarded by
+       // headersObserved) with the response metadata from the first call to
+       // Header() that returns without error. It lets middleware such as the
+       // cookie middleware see server headers on streaming RPCs without
+       // waiting for the stream to finish via finishFn. See GH-755.
+       onHeaders       func(md metadata.MD)
+       headersObserved atomic.Bool
 }
 
 func (cs *clientStream) Header() (metadata.MD, error) {
        md, err := cs.ClientStream.Header()
        if err != nil {
                cs.finishFn(err)
+               return md, err // error already passed to finishFn; skip onHeaders
+       }
+       if cs.onHeaders != nil && cs.headersObserved.CompareAndSwap(false, true) {
+               cs.onHeaders(md) // first successful Header() only (GH-755)
        }
        return md, err
 }
diff --git a/arrow/flight/flightsql/driver/driver_test.go b/arrow/flight/flightsql/driver/driver_test.go
index 39d9dfd9..82c66753 100644
--- a/arrow/flight/flightsql/driver/driver_test.go
+++ b/arrow/flight/flightsql/driver/driver_test.go
@@ -1819,6 +1819,11 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight
                if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
                        return nil, errors.New("parameter schema: unexpected")
                }
+               // See GH-35328: drain remaining batches before returning to 
avoid
+               // the io.EOF race between server close and client Write. The 
other
+               // success path below already does this; this branch must too.
+               for r.Next() {
+               }
                return qry.GetPreparedStatementHandle(), nil
        }
 
diff --git a/arrow/flight/handshake_cookie_test.go b/arrow/flight/handshake_cookie_test.go
new file mode 100644
index 00000000..c5c8464e
--- /dev/null
+++ b/arrow/flight/handshake_cookie_test.go
@@ -0,0 +1,324 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight_test
+
+import (
+       "context"
+       "encoding/base64"
+       "errors"
+       "io"
+       "strings"
+       "sync"
+       "testing"
+
+       "github.com/apache/arrow-go/v18/arrow/flight"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/credentials/insecure"
+       "google.golang.org/grpc/metadata"
+)
+
+// handshakeCookieFlightServer is a flight server that emits Set-Cookie
+// response headers and/or trailers during Handshake, simulating a server
+// that creates a session during the authentication flow (see GH-755).
+type handshakeCookieFlightServer struct {
+       flight.BaseFlightServer
+
+       headerCookie     string // cookie attached via SendHeader during Handshake
+       trailerCookie    string // cookie attached via SetTrailer during Handshake
+       bearerToken      string // sent back as "authorization: Bearer <token>" during Handshake
+       sendPayload      bool   // if true, server sends a HandshakeResponse payload before closing
+       mu               sync.Mutex
+       lastIncomingCook []string // guarded by mu; Cookie values seen on the latest ListFlights
+}
+
+func (h *handshakeCookieFlightServer) Handshake(stream flight.FlightService_HandshakeServer) error {
+       md := metadata.MD{}
+       if h.headerCookie != "" {
+               md.Append("set-cookie", h.headerCookie)
+       }
+       if h.bearerToken != "" {
+               md.Append("authorization", "Bearer "+h.bearerToken)
+       }
+       if len(md) > 0 {
+               if err := stream.SendHeader(md); err != nil {
+                       return err
+               }
+       }
+
+       if h.trailerCookie != "" {
+               stream.SetTrailer(metadata.Pairs("set-cookie", h.trailerCookie))
+       }
+
+       if h.sendPayload {
+               if err := stream.Send(&flight.HandshakeResponse{Payload: []byte("handshake-ok")}); err != nil {
+                       return err
+               }
+       }
+
+       // Drain the client stream until it closes; io.EOF is a clean close.
+       for {
+               if _, err := stream.Recv(); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               return nil
+                       }
+                       return err
+               }
+       }
+}
+
+func (h *handshakeCookieFlightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error {
+       h.mu.Lock() // record this call's Cookie values for observedCookies
+       if md, ok := metadata.FromIncomingContext(fs.Context()); ok {
+               h.lastIncomingCook = append([]string(nil), md.Get("cookie")...)
+       } else {
+               h.lastIncomingCook = nil
+       }
+       h.mu.Unlock()
+       return nil
+}
+
+func (h *handshakeCookieFlightServer) observedCookies() (snapshot []string) {
+       h.mu.Lock()
+       defer h.mu.Unlock()
+       return append(snapshot, h.lastIncomingCook...) // copy; callers never alias server state
+}
+
+// TestHandshakeCookiePropagationViaAuthenticateBasicToken is a regression
+// test for GH-755. It asserts that Set-Cookie values sent as Handshake
+// response headers are captured by the cookie middleware and attached to
+// subsequent requests.
+func TestHandshakeCookiePropagationViaAuthenticateBasicToken(t *testing.T) {
+       srv := &handshakeCookieFlightServer{
+               headerCookie: "session_id=sess_header_abc",
+               bearerToken:  "my-bearer-token",
+       }
+
+       s := flight.NewServerWithMiddleware(nil)
+       require.NoError(t, s.Init("localhost:0"))
+       s.RegisterFlightService(srv)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+       client, err := flight.NewClientWithMiddleware(
+               s.Addr().String(),
+               nil,
+               []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+               creds,
+       )
+       require.NoError(t, err)
+       defer client.Close()
+
+       ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
+       require.NoError(t, err)
+
+       // Make a follow-up RPC. The cookie middleware must have captured
+       // Set-Cookie from the Handshake response, and StartCall should
+       // attach it as a Cookie header on this call.
+       stream, err := client.ListFlights(ctx, &flight.Criteria{})
+       require.NoError(t, err)
+       for {
+               if _, err := stream.Recv(); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break
+                       }
+                       require.NoError(t, err)
+               }
+       }
+
+       cookies := srv.observedCookies()
+       require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies)
+       assert.Contains(t, cookies[0], "session_id=sess_header_abc",
+               "cookie middleware should propagate Set-Cookie from Handshake response headers")
+}
+
+// TestHandshakeCookiePropagationFromTrailers ensures cookies delivered as
+// gRPC trailers (instead of initial metadata headers) are also captured
+// by the cookie middleware during Handshake.
+func TestHandshakeCookiePropagationFromTrailers(t *testing.T) {
+       srv := &handshakeCookieFlightServer{
+               trailerCookie: "session_id=sess_trailer_xyz",
+               bearerToken:   "my-bearer-token",
+       }
+
+       s := flight.NewServerWithMiddleware(nil)
+       require.NoError(t, s.Init("localhost:0"))
+       s.RegisterFlightService(srv)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+       client, err := flight.NewClientWithMiddleware(
+               s.Addr().String(),
+               nil,
+               []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+               creds,
+       )
+       require.NoError(t, err)
+       defer client.Close()
+
+       ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
+       require.NoError(t, err)
+
+       stream, err := client.ListFlights(ctx, &flight.Criteria{})
+       require.NoError(t, err)
+       for {
+               if _, err := stream.Recv(); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break
+                       }
+                       require.NoError(t, err)
+               }
+       }
+
+       cookies := srv.observedCookies()
+       require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies)
+       assert.Contains(t, cookies[0], "session_id=sess_trailer_xyz",
+               "cookie middleware should propagate Set-Cookie from Handshake response trailers")
+}
+
+// TestHandshakeCookiePropagationWithServerPayload is the precise scenario
+// reported in GH-755. The server attaches a Set-Cookie header AND sends
+// back a HandshakeResponse payload. AuthenticateBasicToken only calls
+// stream.Recv() once, which returns the payload (not io.EOF), so the
+// streaming finishFn that would normally invoke HeadersReceived never
+// fires. The cookie middleware must still capture the header cookie.
+func TestHandshakeCookiePropagationWithServerPayload(t *testing.T) {
+       srv := &handshakeCookieFlightServer{
+               headerCookie: "session_id=sess_with_payload",
+               bearerToken:  "my-bearer-token",
+               sendPayload:  true,
+       }
+
+       s := flight.NewServerWithMiddleware(nil)
+       require.NoError(t, s.Init("localhost:0"))
+       s.RegisterFlightService(srv)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+       client, err := flight.NewClientWithMiddleware(
+               s.Addr().String(),
+               nil,
+               []flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
+               creds,
+       )
+       require.NoError(t, err)
+       defer client.Close()
+
+       ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
+       require.NoError(t, err)
+
+       stream, err := client.ListFlights(ctx, &flight.Criteria{})
+       require.NoError(t, err)
+       for {
+               if _, err := stream.Recv(); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break
+                       }
+                       require.NoError(t, err)
+               }
+       }
+
+       cookies := srv.observedCookies()
+       require.Len(t, cookies, 1,
+               "expected exactly one Cookie header, got %v (GH-755: cookie lost when Handshake returns a payload)", cookies)
+       assert.Contains(t, cookies[0], "session_id=sess_with_payload")
+}
+
+// TestHandshakeCookieProcessedBeforeRecv verifies cookies are captured
+// eagerly once stream.Header() returns successfully. This models the
+// scenario where an application-level Handshake flow inspects response
+// headers and makes further RPCs before draining the stream.
+func TestHandshakeCookieProcessedBeforeRecv(t *testing.T) {
+       srv := &handshakeCookieFlightServer{
+               headerCookie: "session_id=eager_capture",
+       }
+
+       s := flight.NewServerWithMiddleware(nil)
+       require.NoError(t, s.Init("localhost:0"))
+       s.RegisterFlightService(srv)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       cookies := flight.NewCookieMiddleware()
+       creds := grpc.WithTransportCredentials(insecure.NewCredentials())
+       client, err := flight.NewClientWithMiddleware(
+               s.Addr().String(),
+               nil,
+               []flight.ClientMiddleware{flight.CreateClientMiddleware(cookies)},
+               creds,
+       )
+       require.NoError(t, err)
+       defer client.Close()
+
+       // Drive the Handshake manually; inspect headers before calling Recv().
+       authCtx := metadata.AppendToOutgoingContext(context.Background(),
+               "Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:pass")))
+
+       stream, err := client.Handshake(authCtx)
+       require.NoError(t, err)
+       require.NoError(t, stream.CloseSend())
+
+       hdr, err := stream.Header()
+       require.NoError(t, err)
+       require.Contains(t, strings.Join(hdr.Get("set-cookie"), ","), "eager_capture")
+
+       // Clone the middleware while the original Handshake stream is still
+       // open. If cookies were processed eagerly from the header, the clone
+       // should already contain the session cookie.
+       cloned := cookies.Clone()
+
+       // Using the clone, make a (server-streaming) ListFlights call through
+       // a second client to observe the outgoing Cookie header.
+       clientB, err := flight.NewClientWithMiddleware(
+               s.Addr().String(),
+               nil,
+               []flight.ClientMiddleware{flight.CreateClientMiddleware(cloned)},
+               creds,
+       )
+       require.NoError(t, err)
+       defer clientB.Close()
+
+       ls, err := clientB.ListFlights(context.Background(), &flight.Criteria{})
+       require.NoError(t, err)
+       for {
+               if _, err := ls.Recv(); err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break
+                       }
+                       require.NoError(t, err)
+               }
+       }
+
+       got := srv.observedCookies()
+       require.Len(t, got, 1, "expected cloned middleware to send cookie from eagerly captured Handshake header, got %v", got)
+       assert.Contains(t, got[0], "session_id=eager_capture")
+
+       // Drain the original Handshake stream so it completes cleanly.
+       for {
+               if _, err := stream.Recv(); err != nil {
+                       break
+               }
+       }
+}

Reply via email to