This is an automated email from the ASF dual-hosted git repository.

xiazcy pushed a commit to branch go-http-fix
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 7da0e97fa2faa5000e983f53e3753840e488eb12
Author: Yang Xia <[email protected]>
AuthorDate: Tue Mar 24 22:45:32 2026 -0700

    Run request interceptors before request serialization
---
 gremlin-go/driver/auth.go             |  17 +-
 gremlin-go/driver/auth_test.go        |  26 +--
 gremlin-go/driver/connection.go       | 121 +++++------
 gremlin-go/driver/connection_test.go  |  33 +++
 gremlin-go/driver/interceptor.go      | 113 ++++++++++
 gremlin-go/driver/interceptor_test.go | 380 ++++++++++++++++++++++++++++++++++
 6 files changed, 602 insertions(+), 88 deletions(-)

diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go
index 74ca43a444..2f6f8b9a4e 100644
--- a/gremlin-go/driver/auth.go
+++ b/gremlin-go/driver/auth.go
@@ -22,6 +22,7 @@ package gremlingo
 import (
        "context"
        "encoding/base64"
+       "fmt"
        "sync"
        "time"
 
@@ -39,17 +40,17 @@ func BasicAuth(username, password string) 
RequestInterceptor {
        }
 }
 
-// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
+// SigV4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
 // It uses the default AWS credential chain (env vars, shared config, IAM 
role, etc.)
-func Sigv4Auth(region, service string) RequestInterceptor {
-       return Sigv4AuthWithCredentials(region, service, nil)
+// The request Body must already be serialized to []byte when this interceptor
+// runs (for example, add SerializeRequest() earlier in the interceptor chain);
+// otherwise the interceptor returns an error instead of signing.
+func SigV4Auth(region, service string) RequestInterceptor {
+       return SigV4AuthWithCredentials(region, service, nil)
 }
 
-// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests 
using AWS SigV4
+// SigV4AuthWithCredentials returns a RequestInterceptor that signs requests 
using AWS SigV4
 // with the provided credentials provider. If provider is nil, uses default 
credential chain.
 //
 // Caches the signer and credentials provider for efficiency.
-func Sigv4AuthWithCredentials(region, service string, credentialsProvider 
aws.CredentialsProvider) RequestInterceptor {
+func SigV4AuthWithCredentials(region, service string, credentialsProvider 
aws.CredentialsProvider) RequestInterceptor {
        // Create signer once - it's stateless and safe to reuse
        signer := v4.NewSigner()
 
@@ -59,6 +60,12 @@ func Sigv4AuthWithCredentials(region, service string, 
credentialsProvider aws.Cr
        var providerErr error
 
        return func(req *HttpRequest) error {
+               // SigV4 requires serialized body bytes to compute the payload 
hash.
+               if _, ok := req.Body.([]byte); !ok {
+                       return fmt.Errorf("SigV4 signing requires serialized 
body bytes ([]byte); got %T. "+
+                               "Place SigV4Auth after serialization in the 
interceptor chain", req.Body)
+               }
+
                ctx := context.Background()
 
                // Resolve credentials provider once if not provided
diff --git a/gremlin-go/driver/auth_test.go b/gremlin-go/driver/auth_test.go
index ba60f6e6c5..7ec4079b6e 100644
--- a/gremlin-go/driver/auth_test.go
+++ b/gremlin-go/driver/auth_test.go
@@ -30,7 +30,7 @@ import (
 )
 
 func createMockRequest() *HttpRequest {
-       req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin";)
+       req, _ := NewHttpRequest("POST", "https://test_url:8182/gremlin";)
        req.Headers.Set("Content-Type", graphBinaryMimeType)
        req.Headers.Set("Accept", graphBinaryMimeType)
        req.Body = []byte(`{"gremlin":"g.V()"}`)
@@ -72,24 +72,24 @@ func (m *mockCredentialsProvider) Retrieve(ctx 
context.Context) (aws.Credentials
        }, nil
 }
 
-func TestSigv4Auth(t *testing.T) {
+func TestSigV4Auth(t *testing.T) {
        t.Run("adds signed headers", func(t *testing.T) {
                req := createMockRequest()
                assert.Empty(t, req.Headers.Get("Authorization"))
                assert.Empty(t, req.Headers.Get("X-Amz-Date"))
 
                provider := &mockCredentialsProvider{
-                       accessKey: "MOCK_ACCESS_KEY",
-                       secretKey: "MOCK_SECRET_KEY",
+                       accessKey: "MOCK_ID",
+                       secretKey: "MOCK_KEY",
                }
-               interceptor := Sigv4AuthWithCredentials("us-west-2", 
"neptune-db", provider)
+               interceptor := SigV4AuthWithCredentials("gremlin-east-1", 
"tinkerpop-sigv4", provider)
                err := interceptor(req)
 
                assert.NoError(t, err)
                assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"))
                authHeader := req.Headers.Get("Authorization")
-               assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 
Credential=MOCK_ACCESS_KEY"))
-               assert.Contains(t, authHeader, 
"us-west-2/neptune-db/aws4_request")
+               assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 
Credential=MOCK_ID"))
+               assert.Contains(t, authHeader, 
"gremlin-east-1/tinkerpop-sigv4/aws4_request")
                assert.Contains(t, authHeader, "Signature=")
        })
 
@@ -98,17 +98,17 @@ func TestSigv4Auth(t *testing.T) {
                assert.Empty(t, req.Headers.Get("X-Amz-Security-Token"))
 
                provider := &mockCredentialsProvider{
-                       accessKey:    "MOCK_ACCESS_KEY",
-                       secretKey:    "MOCK_SECRET_KEY",
-                       sessionToken: "MOCK_SESSION_TOKEN",
+                       accessKey:    "MOCK_ID",
+                       secretKey:    "MOCK_KEY",
+                       sessionToken: "MOCK_TOKEN",
                }
-               interceptor := Sigv4AuthWithCredentials("us-west-2", 
"neptune-db", provider)
+               interceptor := SigV4AuthWithCredentials("gremlin-east-1", 
"tinkerpop-sigv4", provider)
                err := interceptor(req)
 
                assert.NoError(t, err)
-               assert.Equal(t, "MOCK_SESSION_TOKEN", 
req.Headers.Get("X-Amz-Security-Token"))
+               assert.Equal(t, "MOCK_TOKEN", 
req.Headers.Get("X-Amz-Security-Token"))
                authHeader := req.Headers.Get("Authorization")
                assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 
Credential="))
-               assert.Contains(t, authHeader, "Signature=")
+               assert.Contains(t, authHeader, 
"gremlin-east-1/tinkerpop-sigv4/aws4_request")
        })
 }
diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go
index 5086965a8d..562bad80a3 100644
--- a/gremlin-go/driver/connection.go
+++ b/gremlin-go/driver/connection.go
@@ -32,59 +32,6 @@ import (
        "time"
 )
 
-// Common HTTP header keys
-const (
-       HeaderContentType    = "Content-Type"
-       HeaderAccept         = "Accept"
-       HeaderUserAgent      = "User-Agent"
-       HeaderAcceptEncoding = "Accept-Encoding"
-       HeaderAuthorization  = "Authorization"
-)
-
-// HttpRequest represents an HTTP request that can be modified by interceptors.
-type HttpRequest struct {
-       Method  string
-       URL     *url.URL
-       Headers http.Header
-       Body    []byte
-}
-
-// NewHttpRequest creates a new HttpRequest with the given method and URL.
-func NewHttpRequest(method, rawURL string) (*HttpRequest, error) {
-       u, err := url.Parse(rawURL)
-       if err != nil {
-               return nil, err
-       }
-       return &HttpRequest{
-               Method:  method,
-               URL:     u,
-               Headers: make(http.Header),
-       }, nil
-}
-
-// ToStdRequest converts HttpRequest to a standard http.Request for signing.
-// Returns nil if the request cannot be created (invalid method or URL).
-func (r *HttpRequest) ToStdRequest() (*http.Request, error) {
-       req, err := http.NewRequest(r.Method, r.URL.String(), 
bytes.NewReader(r.Body))
-       if err != nil {
-               return nil, err
-       }
-       req.Header = r.Headers
-       return req, nil
-}
-
-// PayloadHash returns the SHA256 hash of the request body for SigV4 signing.
-func (r *HttpRequest) PayloadHash() string {
-       if len(r.Body) == 0 {
-               return 
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of 
empty string
-       }
-       h := sha256.Sum256(r.Body)
-       return hex.EncodeToString(h[:])
-}
-
-// RequestInterceptor is a function that modifies an HTTP request before it is 
sent.
-type RequestInterceptor func(*HttpRequest) error
-
 // connectionSettings holds configuration for the connection.
 type connectionSettings struct {
        tlsConfig                *tls.Config
@@ -174,18 +121,12 @@ func (c *connection) AddInterceptor(interceptor 
RequestInterceptor) {
 func (c *connection) submit(req *request) (ResultSet, error) {
        rs := newChannelResultSet()
 
-       data, err := c.serializer.SerializeMessage(req)
-       if err != nil {
-               rs.Close()
-               return rs, err
-       }
-
-       go c.executeAndStream(data, rs)
+       // Serialization now happens inside executeAndStream, after the
+       // interceptors have run on the raw *request. Serialization errors are
+       // therefore reported through the ResultSet rather than returned here.
+       go c.executeAndStream(req, rs)
 
        return rs, nil
 }
 
-func (c *connection) executeAndStream(data []byte, rs ResultSet) {
+func (c *connection) executeAndStream(req *request, rs ResultSet) {
        defer rs.Close()
 
        // Create HttpRequest for interceptors
@@ -195,12 +136,15 @@ func (c *connection) executeAndStream(data []byte, rs 
ResultSet) {
                rs.setError(err)
                return
        }
-       httpReq.Body = data
 
        // Set default headers before interceptors
        c.setHttpRequestHeaders(httpReq)
 
-       // Apply interceptors
+       // Set Body to the raw *request so interceptors can inspect/modify it
+       httpReq.Body = req
+
+       // Apply interceptors — they see *request in Body (pre-serialization).
+       // Interceptors may replace Body with []byte, io.Reader, or 
*http.Request.
        for _, interceptor := range c.interceptors {
                if err := interceptor(httpReq); err != nil {
                        c.logHandler.logf(Error, failedToSendRequest, 
err.Error())
@@ -209,16 +153,53 @@ func (c *connection) executeAndStream(data []byte, rs 
ResultSet) {
                }
        }
 
-       // Create actual http.Request from HttpRequest
-       req, err := http.NewRequest(httpReq.Method, httpReq.URL.String(), 
bytes.NewReader(httpReq.Body))
-       if err != nil {
-               c.logHandler.logf(Error, failedToSendRequest, err.Error())
-               rs.setError(err)
+       // After interceptors, serialize if Body is still *request
+       if r, ok := httpReq.Body.(*request); ok {
+               if c.serializer != nil {
+                       data, err := c.serializer.SerializeMessage(r)
+                       if err != nil {
+                               c.logHandler.logf(Error, failedToSendRequest, 
err.Error())
+                               rs.setError(err)
+                               return
+                       }
+                       httpReq.Body = data
+               } else {
+                       errMsg := "request body was not serialized; either 
provide a serializer or add an interceptor that serializes the request"
+                       c.logHandler.logf(Error, failedToSendRequest, errMsg)
+                       rs.setError(fmt.Errorf("%s", errMsg))
+                       return
+               }
+       }
+
+       // Create actual http.Request from HttpRequest based on Body type
+       var httpGoReq *http.Request
+       switch body := httpReq.Body.(type) {
+       case []byte:
+               httpGoReq, err = http.NewRequest(httpReq.Method, 
httpReq.URL.String(), bytes.NewReader(body))
+               if err != nil {
+                       c.logHandler.logf(Error, failedToSendRequest, 
err.Error())
+                       rs.setError(err)
+                       return
+               }
+               httpGoReq.Header = httpReq.Headers
+       case io.Reader:
+               httpGoReq, err = http.NewRequest(httpReq.Method, 
httpReq.URL.String(), body)
+               if err != nil {
+                       c.logHandler.logf(Error, failedToSendRequest, 
err.Error())
+                       rs.setError(err)
+                       return
+               }
+               httpGoReq.Header = httpReq.Headers
+       case *http.Request:
+               httpGoReq = body
+       default:
+               errMsg := fmt.Sprintf("unsupported body type after 
interceptors: %T", body)
+               c.logHandler.logf(Error, failedToSendRequest, errMsg)
+               rs.setError(fmt.Errorf("%s", errMsg))
                return
        }
-       req.Header = httpReq.Headers
 
-       resp, err := c.httpClient.Do(req)
+       resp, err := c.httpClient.Do(httpGoReq)
        if err != nil {
                c.logHandler.logf(Error, failedToSendRequest, err.Error())
                rs.setError(err)
diff --git a/gremlin-go/driver/connection_test.go 
b/gremlin-go/driver/connection_test.go
index a0e8414223..4bb2de5cce 100644
--- a/gremlin-go/driver/connection_test.go
+++ b/gremlin-go/driver/connection_test.go
@@ -1261,3 +1261,36 @@ func TestDriverRemoteConnectionSettingsWiring(t 
*testing.T) {
                assert.Equal(t, 180*time.Second, transport.IdleConnTimeout)
        })
 }
+
+// TestConnectionWithMockServer_BasicAuth verifies that the BasicAuth interceptor sets the
+// expected Authorization header and that the body the server receives is still
+// serialized GraphBinary bytes.
+func TestConnectionWithMockServer_BasicAuth(t *testing.T) {
+       var capturedAuthHeader string
+       var capturedBody []byte
+
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               capturedAuthHeader = r.Header.Get("Authorization")
+               body, err := io.ReadAll(r.Body)
+               if err == nil {
+                       capturedBody = body
+               }
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+       conn.AddInterceptor(BasicAuth("testuser", "testpass"))
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       // The mock server answers with an empty 200, so any result/decoding error is
+       // irrelevant here; draining only waits for the request to complete.
+       _, _ = rs.All()
+
+       // base64("testuser:testpass") == "dGVzdHVzZXI6dGVzdHBhc3M="
+       assert.Equal(t, "Basic dGVzdHVzZXI6dGVzdHBhc3M=", capturedAuthHeader,
+               "Authorization header should be Basic base64(testuser:testpass)")
+
+       // require (not assert): an empty body must fail the test rather than
+       // panic on the index below.
+       require.NotEmpty(t, capturedBody, "serialized body should be non-empty with BasicAuth")
+       assert.Equal(t, byte(0x81), capturedBody[0],
+               "body should start with GraphBinary version byte 0x81")
+}
diff --git a/gremlin-go/driver/interceptor.go b/gremlin-go/driver/interceptor.go
new file mode 100644
index 0000000000..a5d63a31be
--- /dev/null
+++ b/gremlin-go/driver/interceptor.go
@@ -0,0 +1,113 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+       "bytes"
+       "crypto/sha256"
+       "encoding/hex"
+       "io"
+       "net/http"
+       "net/url"
+)
+
+// Common HTTP header keys. They can be used with HttpRequest.Headers (an
+// http.Header) when an interceptor reads or sets request headers.
+const (
+       HeaderContentType    = "Content-Type"
+       HeaderAccept         = "Accept"
+       HeaderUserAgent      = "User-Agent"
+       HeaderAcceptEncoding = "Accept-Encoding"
+       HeaderAuthorization  = "Authorization"
+)
+
+// HttpRequest represents an HTTP request that can be modified by interceptors.
+//
+// Body holds the payload. When the interceptor chain starts, it contains the
+// driver's unserialized request value. Interceptors may replace it with one of:
+//
+//   - []byte: sent as the HTTP body (SigV4Auth requires this form, and
+//     SerializeRequest produces it);
+//   - io.Reader: used as the HTTP body;
+//   - *http.Request: sent as-is; the Method, URL and Headers of this
+//     HttpRequest are then not applied.
+//
+// If Body still holds the unserialized request after the last interceptor, the
+// connection serializes it with its own serializer, or fails if it has none.
+// Any other Body type makes the request fail.
+type HttpRequest struct {
+       Method  string
+       URL     *url.URL
+       Headers http.Header
+       Body    any
+}
+
+// NewHttpRequest parses rawURL and returns an HttpRequest for method with an
+// empty (non-nil) header map and a nil Body. A URL parse failure is returned
+// unchanged, together with a nil request.
+func NewHttpRequest(method, rawURL string) (*HttpRequest, error) {
+       parsed, parseErr := url.Parse(rawURL)
+       if parseErr != nil {
+               return nil, parseErr
+       }
+       hr := &HttpRequest{Method: method, URL: parsed}
+       hr.Headers = http.Header{}
+       return hr, nil
+}
+
+// ToStdRequest converts HttpRequest to a standard http.Request for signing.
+// Returns nil if the request cannot be created (invalid method or URL).
+func (r *HttpRequest) ToStdRequest() (*http.Request, error) {
+       var body io.Reader
+       switch b := r.Body.(type) {
+       case []byte:
+               body = bytes.NewReader(b)
+       default:
+               body = http.NoBody
+       }
+       req, err := http.NewRequest(r.Method, r.URL.String(), body)
+       if err != nil {
+               return nil, err
+       }
+       req.Header = r.Headers
+       return req, nil
+}
+
+// PayloadHash returns the SHA256 hash of the request body for SigV4 signing.
+func (r *HttpRequest) PayloadHash() string {
+       switch b := r.Body.(type) {
+       case []byte:
+               if len(b) == 0 {
+                       return 
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of 
empty string
+               }
+               h := sha256.Sum256(b)
+               return hex.EncodeToString(h[:])
+       default:
+               return 
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of 
empty string
+       }
+}
+
+// RequestInterceptor is a function that modifies an HTTP request before it is
+// sent. Interceptors added to a connection run in the order they were added,
+// and an error returned by an interceptor is reported through the request's
+// ResultSet. Interceptors that need the serialized payload (such as SigV4Auth)
+// must run after SerializeRequest.
+type RequestInterceptor func(*HttpRequest) error
+
+// SerializeRequest returns a RequestInterceptor that replaces an unserialized
+// request Body with its GraphBinary encoding ([]byte). Register it before any
+// interceptor that needs the serialized bytes, such as SigV4Auth.
+//
+// A Body that is not an unserialized request (for example []byte produced by an
+// earlier interceptor) is left untouched. GraphBinary is always used, regardless
+// of the serializer configured on the connection. If serialization fails, the
+// error is returned and Body is left unchanged.
+//
+// NOTE(review): one serializer instance is created per SerializeRequest() call
+// and shared by every request passing through the returned interceptor; confirm
+// it is safe for concurrent use.
+func SerializeRequest() RequestInterceptor {
+       gb := newGraphBinarySerializer(nil)
+       return func(hr *HttpRequest) error {
+               raw, isRaw := hr.Body.(*request)
+               if !isRaw {
+                       // Already serialized or replaced by another interceptor.
+                       return nil
+               }
+               payload, err := gb.SerializeMessage(raw)
+               if err == nil {
+                       hr.Body = payload
+               }
+               return err
+       }
+}
diff --git a/gremlin-go/driver/interceptor_test.go 
b/gremlin-go/driver/interceptor_test.go
new file mode 100644
index 0000000000..78e36a0b95
--- /dev/null
+++ b/gremlin-go/driver/interceptor_test.go
@@ -0,0 +1,380 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+       "bytes"
+       "fmt"
+       "io"
+       "net/http"
+       "net/http/httptest"
+       "reflect"
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+)
+
+// TestInterceptorReceivesRawRequest checks that, while interceptors run,
+// HttpRequest.Body still holds the unserialized *request rather than []byte.
+func TestInterceptorReceivesRawRequest(t *testing.T) {
+       // The response is irrelevant here; the server just acknowledges the call.
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       // newConnection installs a non-nil serializer by default.
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+       var (
+               seenBody interface{}
+               seenType reflect.Type
+       )
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               seenBody, seenType = req.Body, reflect.TypeOf(req.Body)
+               return nil
+       })
+
+       const query = "g.V().count()"
+       rs, err := conn.submit(&request{gremlin: query, fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // wait for the asynchronous send to finish
+
+       wantType := reflect.TypeOf((*request)(nil))
+       assert.Equal(t, wantType, seenType,
+               "interceptor should receive *request in Body, got %v", seenType)
+
+       raw, isRaw := seenBody.(*request)
+       assert.True(t, isRaw, "interceptor should be able to type-assert Body to *request")
+       if !isRaw {
+               return
+       }
+       assert.Equal(t, query, raw.gremlin,
+               "interceptor should be able to read the gremlin field from the raw request")
+}
+
+// TestSigV4AuthWithSerializeInterceptor verifies that SerializeRequest() followed by
+// SigV4Auth works as a chain: SerializeRequest converts *request to []byte, and
+// SigV4Auth can then sign the serialized body.
+func TestSigV4AuthWithSerializeInterceptor(t *testing.T) {
+       var capturedHeaders http.Header
+       var capturedBody []byte
+
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               capturedHeaders = r.Header.Clone()
+               body, err := io.ReadAll(r.Body)
+               if err == nil {
+                       capturedBody = body
+               }
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+       mockProvider := &mockCredentialsProvider{
+               accessKey: "MOCK_ID",
+               secretKey: "MOCK_KEY",
+       }
+
+       // Order matters: SigV4 needs the serialized bytes to compute the payload hash.
+       conn.AddInterceptor(SerializeRequest())
+       conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider))
+
+       rs, err := conn.submit(&request{gremlin: "g.V().count()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // drain
+
+       // SigV4 should have added Authorization and X-Amz-Date headers
+       authHeader := capturedHeaders.Get("Authorization")
+       assert.NotEmpty(t, authHeader,
+               "SigV4Auth should set Authorization header after SerializeRequest")
+       assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"),
+               "SigV4Auth should set X-Amz-Date header")
+       assert.Contains(t, authHeader, "AWS4-HMAC-SHA256",
+               "Authorization header should use AWS4-HMAC-SHA256 signing algorithm")
+       assert.Contains(t, authHeader, "gremlin-east-1/tinkerpop-sigv4/aws4_request",
+               "Authorization header should carry the configured region/service scope")
+
+       // Body should be valid serialized bytes. require (not assert) so an empty
+       // body fails the test instead of panicking on the index below.
+       require.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes")
+       assert.Equal(t, byte(0x81), capturedBody[0],
+               "body should start with GraphBinary version byte 0x81")
+}
+
+// TestMultipleInterceptors_SerializeThenAuth verifies that a custom interceptor can
+// modify the raw request, then SerializeRequest serializes it, then BasicAuth adds
+// headers.
+func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) {
+       var capturedAuthHeader string
+       var capturedBody []byte
+
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               capturedAuthHeader = r.Header.Get("Authorization")
+               body, err := io.ReadAll(r.Body)
+               if err == nil {
+                       capturedBody = body
+               }
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+       // Custom interceptor that modifies the raw request fields
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               r, ok := req.Body.(*request)
+               if !ok {
+                       return fmt.Errorf("expected *request, got %T", req.Body)
+               }
+               // Add a custom field to the request
+               r.fields["customField"] = "customValue"
+               return nil
+       })
+
+       // SerializeRequest converts the modified *request to []byte
+       conn.AddInterceptor(SerializeRequest())
+
+       // BasicAuth adds the Authorization header after serialization
+       conn.AddInterceptor(BasicAuth("admin", "secret"))
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // drain
+
+       // BasicAuth should have set the Authorization header
+       assert.Equal(t, "Basic YWRtaW46c2VjcmV0", capturedAuthHeader,
+               "Authorization header should be Basic base64(admin:secret)")
+
+       // Body should be valid serialized bytes (from SerializeRequest). require
+       // (not assert): if an interceptor failed and nothing was sent, the test
+       // should fail cleanly rather than panic on the index below.
+       require.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes")
+       assert.Equal(t, byte(0x81), capturedBody[0],
+               "body should start with GraphBinary version byte 0x81")
+       // NOTE(review): this does not yet assert that customField reached the
+       // serialized body; consider deserializing capturedBody to check it.
+}
+
+// TestInterceptor_IoReaderBody checks that an interceptor may swap Body for an
+// io.Reader and that the reader's contents are exactly what the server receives.
+func TestInterceptor_IoReaderBody(t *testing.T) {
+       var received []byte
+
+       handler := func(w http.ResponseWriter, r *http.Request) {
+               if body, err := io.ReadAll(r.Body); err == nil {
+                       received = body
+               }
+               w.WriteHeader(http.StatusOK)
+       }
+       server := httptest.NewServer(http.HandlerFunc(handler))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+       payload := []byte("custom binary payload")
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               // Replace the unserialized request with a raw reader.
+               req.Body = bytes.NewReader(payload)
+               return nil
+       })
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // block until the request has been sent
+
+       assert.Equal(t, payload, received,
+               "server should receive the custom payload set via io.Reader")
+}
+
+// TestInterceptor_NilSerializerNoSerialization verifies that when serializer 
is nil
+// and no interceptor serializes, the correct error message is produced.
+func TestInterceptor_NilSerializerNoSerialization(t *testing.T) {
+       server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, 
&connectionSettings{})
+       conn.serializer = nil // explicitly nil serializer
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: 
map[string]interface{}{}})
+       require.NoError(t, err)
+
+       _, _ = rs.All() // drain — this triggers the async executeAndStream
+       rsErr := rs.GetError()
+       require.Error(t, rsErr, "should get an error when serializer is nil and 
no interceptor serializes")
+       assert.Contains(t, rsErr.Error(), "request body was not serialized",
+               "error message should indicate the body was not serialized")
+}
+
+// TestInterceptor_HttpRequestBody checks that when an interceptor sets Body to a
+// fully built *http.Request, the driver sends that request as-is: its headers and
+// body reach the server, and HttpRequest.Headers is ignored.
+func TestInterceptor_HttpRequestBody(t *testing.T) {
+       var (
+               gotHeaders http.Header
+               gotBody    []byte
+       )
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               gotHeaders = r.Header.Clone()
+               if b, err := io.ReadAll(r.Body); err == nil {
+                       gotBody = b
+               }
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+       payload := []byte("custom-http-request-body")
+
+       // First interceptor: replace Body with a complete *http.Request.
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               stdReq, err := http.NewRequest(http.MethodPost, req.URL.String(), bytes.NewReader(payload))
+               if err != nil {
+                       return err
+               }
+               stdReq.Header.Set("X-Custom-Header", "custom-value")
+               stdReq.Header.Set("Content-Type", "application/octet-stream")
+               req.Body = stdReq
+               return nil
+       })
+       // Second interceptor: set a header on HttpRequest.Headers. It must NOT be
+       // sent, because an *http.Request Body bypasses HttpRequest.Headers.
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               req.Headers.Set("X-Should-Not-Appear", "ignored")
+               return nil
+       })
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // drain
+
+       assert.Equal(t, "custom-value", gotHeaders.Get("X-Custom-Header"),
+               "server should receive custom header from *http.Request")
+       assert.Equal(t, "application/octet-stream", gotHeaders.Get("Content-Type"),
+               "server should receive Content-Type from *http.Request")
+       assert.Empty(t, gotHeaders.Get("X-Should-Not-Appear"),
+               "headers set on HttpRequest.Headers should not appear when Body is *http.Request")
+       assert.Equal(t, payload, gotBody,
+               "server should receive body from the *http.Request")
+}
+
+// TestInterceptor_ErrorPropagation verifies that when an interceptor returns 
an error,
+// it is propagated to the ResultSet.
+func TestInterceptor_ErrorPropagation(t *testing.T) {
+       server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, 
&connectionSettings{})
+
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               return fmt.Errorf("interceptor failed")
+       })
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: 
map[string]interface{}{}})
+       require.NoError(t, err)
+
+       _, _ = rs.All() // drain — triggers async executeAndStream
+       rsErr := rs.GetError()
+       require.Error(t, rsErr, "interceptor error should propagate to 
ResultSet")
+       assert.Contains(t, rsErr.Error(), "interceptor failed",
+               "ResultSet error should contain the interceptor's error 
message")
+}
+
+// TestInterceptor_UnsupportedBodyType verifies that setting Body to an 
unsupported type
+// (e.g., an int) produces the "unsupported body type" error.
+func TestInterceptor_UnsupportedBodyType(t *testing.T) {
+       server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, 
&connectionSettings{})
+
+       // Interceptor sets Body to an unsupported type
+       conn.AddInterceptor(func(req *HttpRequest) error {
+               req.Body = 42
+               return nil
+       })
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: 
map[string]interface{}{}})
+       require.NoError(t, err)
+
+       _, _ = rs.All() // drain
+       rsErr := rs.GetError()
+       require.Error(t, rsErr, "unsupported body type should produce an error")
+       assert.Contains(t, rsErr.Error(), "unsupported body type",
+               "error message should indicate unsupported body type")
+}
+
+// TestInterceptor_ChainOrder checks that interceptors execute in registration order.
+func TestInterceptor_ChainOrder(t *testing.T) {
+       server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       }))
+       defer server.Close()
+
+       conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+       var order []int
+       for i := 1; i <= 3; i++ {
+               step := i // per-iteration copy (required before Go 1.22)
+               conn.AddInterceptor(func(*HttpRequest) error {
+                       order = append(order, step)
+                       return nil
+               })
+       }
+
+       rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+       require.NoError(t, err)
+       _, _ = rs.All() // drain
+
+       assert.Equal(t, []int{1, 2, 3}, order,
+               "interceptors should run in the order they were added")
+}
+
+// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error 
when Body
+// is not []byte (e.g., an unserialized *request).
+func TestSigV4Auth_RejectsNonByteBody(t *testing.T) {
+       provider := &mockCredentialsProvider{
+               accessKey: "MOCK_ID",
+               secretKey: "MOCK_KEY",
+       }
+       interceptor := SigV4AuthWithCredentials("gremlin-east-1", 
"tinkerpop-sigv4", provider)
+
+       req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin";)
+       require.NoError(t, err)
+       req.Headers.Set("Content-Type", graphBinaryMimeType)
+       req.Headers.Set("Accept", graphBinaryMimeType)
+
+       // Set Body to *request (not []byte) — SigV4Auth should reject this
+       req.Body = &request{gremlin: "g.V()", fields: map[string]interface{}{}}
+
+       err = interceptor(req)
+       require.Error(t, err, "SigV4Auth should reject non-[]byte body")
+       assert.Contains(t, err.Error(), "SigV4 signing requires serialized body 
bytes",
+               "error message should indicate SigV4 requires serialized body 
bytes")
+}

Reply via email to