This is an automated email from the ASF dual-hosted git repository.

liujun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 72f80ae88 fix: Server return with Attachment (#2648)
72f80ae88 is described below

commit 72f80ae88cb409a63103d0ec7b6d250a5c9f5863
Author: YarBor <110281261+yar...@users.noreply.github.com>
AuthorDate: Fri Apr 26 09:42:14 2024 +0800

    fix: Server return with Attachment (#2648)
---
 client/client.go                                   |  1 +
 protocol/triple/triple_invoker.go                  | 59 +++++++++--------
 protocol/triple/triple_invoker_test.go             | 22 ++++---
 .../triple/triple_protocol/duplex_http_call.go     |  4 ++
 protocol/triple/triple_protocol/handler.go         | 28 ++++++--
 protocol/triple/triple_protocol/header.go          | 74 ++++++++++++++++------
 6 files changed, 128 insertions(+), 60 deletions(-)

diff --git a/client/client.go b/client/client.go
index 98ab66733..a2bcaef1e 100644
--- a/client/client.go
+++ b/client/client.go
@@ -122,6 +122,7 @@ func (cli *Client) dial(interfaceName string, info 
*ClientInfo, opts ...Referenc
 
        return &Connection{refOpts: newRefOpts}, nil
 }
+
 func generateInvocation(methodName string, reqs []interface{}, resp 
interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) {
        var paramsRawVals []interface{}
        for _, req := range reqs {
diff --git a/protocol/triple/triple_invoker.go 
b/protocol/triple/triple_invoker.go
index e08778429..4c1eb6868 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, 
invocation protocol.Invocat
                return &result
        }
 
-       ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), 
invocation)
+       callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), 
invocation)
        if err != nil {
                result.SetError(err)
                return &result
        }
+
+       ctx, err = mergeAttachmentToOutgoing(ctx, invocation)
+       if err != nil {
+               result.SetError(err)
+               return &result
+       }
+
        inRawLen := len(inRaw)
 
        if !ti.clientManager.isIDL {
@@ -136,16 +143,33 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, 
invocation protocol.Invocat
        return &result
 }
 
+func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) 
(context.Context, error) {
+       for key, valRaw := range inv.Attachments() {
+               if str, ok := valRaw.(string); ok {
+                       ctx = tri.AppendToOutgoingContext(ctx, key, str)
+                       continue
+               }
+               if strs, ok := valRaw.([]string); ok {
+                       for _, str := range strs {
+                               ctx = tri.AppendToOutgoingContext(ctx, key, str)
+                       }
+                       continue
+               }
+               return ctx, fmt.Errorf("triple attachments value with key = %s 
is invalid, which should be string or []string", key)
+       }
+       return ctx, nil
+}
+
 // parseInvocation retrieves information from invocation.
 // it returns ctx, callType, inRaw, method, error
-func parseInvocation(ctx context.Context, url *common.URL, invocation 
protocol.Invocation) (context.Context, string, []interface{}, string, error) {
+func parseInvocation(ctx context.Context, url *common.URL, invocation 
protocol.Invocation) (string, []interface{}, string, error) {
        callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
        if !ok {
-               return nil, "", nil, "", errors.New("miss CallType in 
invocation to invoke TripleInvoker")
+               return "", nil, "", errors.New("miss CallType in invocation to 
invoke TripleInvoker")
        }
        callType, ok := callTypeRaw.(string)
        if !ok {
-               return nil, "", nil, "", fmt.Errorf("CallType should be string, 
but got %v", callTypeRaw)
+               return "", nil, "", fmt.Errorf("CallType should be string, but 
got %v", callTypeRaw)
        }
        // please refer to methods of client.Client or code generated by new 
triple for the usage of inRaw and inRawLen
        // e.g. Client.CallUnary(... req, resp []interface, ...)
@@ -153,19 +177,16 @@ func parseInvocation(ctx context.Context, url 
*common.URL, invocation protocol.I
        inRaw := invocation.ParameterRawValues()
        method := invocation.MethodName()
        if method == "" {
-               return nil, "", nil, "", errors.New("miss MethodName in 
invocation to invoke TripleInvoker")
+               return "", nil, "", errors.New("miss MethodName in invocation 
to invoke TripleInvoker")
        }
 
-       ctx, err := parseAttachments(ctx, url, invocation)
-       if err != nil {
-               return nil, "", nil, "", err
-       }
+       parseAttachments(ctx, url, invocation)
 
-       return ctx, callType, inRaw, method, nil
+       return callType, inRaw, method, nil
 }
 
 // parseAttachments retrieves attachments from users passed-in and URL, then 
injects them into ctx
-func parseAttachments(ctx context.Context, url *common.URL, invocation 
protocol.Invocation) (context.Context, error) {
+func parseAttachments(ctx context.Context, url *common.URL, invocation 
protocol.Invocation) {
        // retrieve users passed-in attachment
        attaRaw := ctx.Value(constant.AttachmentKey)
        if attaRaw != nil {
@@ -181,22 +202,6 @@ func parseAttachments(ctx context.Context, url 
*common.URL, invocation protocol.
                        invocation.SetAttachment(key, val)
                }
        }
-       // inject attachments
-       for key, valRaw := range invocation.Attachments() {
-               if str, ok := valRaw.(string); ok {
-                       ctx = tri.AppendToOutgoingContext(ctx, key, str)
-                       continue
-               }
-               if strs, ok := valRaw.([]string); ok {
-                       for _, str := range strs {
-                               ctx = tri.AppendToOutgoingContext(ctx, key, str)
-                       }
-                       continue
-               }
-               return nil, fmt.Errorf("triple attachments value with key = %s 
is invalid, which should be string or []string", key)
-       }
-
-       return ctx, nil
 }
 
 // IsAvailable get available status
diff --git a/protocol/triple/triple_invoker_test.go 
b/protocol/triple/triple_invoker_test.go
index 7d14d4239..e9dcc968c 100644
--- a/protocol/triple/triple_invoker_test.go
+++ b/protocol/triple/triple_invoker_test.go
@@ -19,6 +19,7 @@ package triple
 
 import (
        "context"
+       "net/http"
        "testing"
 
        "dubbo.apache.org/dubbo-go/v3/common"
@@ -35,7 +36,7 @@ func Test_parseInvocation(t *testing.T) {
                ctx    func() context.Context
                url    *common.URL
                invo   func() protocol.Invocation
-               expect func(t *testing.T, ctx context.Context, callType string, 
inRaw []interface{}, methodName string, err error)
+               expect func(t *testing.T, callType string, inRaw []interface{}, 
methodName string, err error)
        }{
                {
                        desc: "miss callType",
@@ -46,7 +47,7 @@ func Test_parseInvocation(t *testing.T) {
                        invo: func() protocol.Invocation {
                                return invocation.NewRPCInvocationWithOptions()
                        },
-                       expect: func(t *testing.T, ctx context.Context, 
callType string, inRaw []interface{}, methodName string, err error) {
+                       expect: func(t *testing.T, callType string, inRaw 
[]interface{}, methodName string, err error) {
                                assert.NotNil(t, err)
                        },
                },
@@ -61,7 +62,7 @@ func Test_parseInvocation(t *testing.T) {
                                iv.SetAttribute(constant.CallTypeKey, 1)
                                return iv
                        },
-                       expect: func(t *testing.T, ctx context.Context, 
callType string, inRaw []interface{}, methodName string, err error) {
+                       expect: func(t *testing.T, callType string, inRaw 
[]interface{}, methodName string, err error) {
                                assert.NotNil(t, err)
                        },
                },
@@ -76,7 +77,7 @@ func Test_parseInvocation(t *testing.T) {
                                iv.SetAttribute(constant.CallTypeKey, 
constant.CallUnary)
                                return iv
                        },
-                       expect: func(t *testing.T, ctx context.Context, 
callType string, inRaw []interface{}, methodName string, err error) {
+                       expect: func(t *testing.T, callType string, inRaw 
[]interface{}, methodName string, err error) {
                                assert.NotNil(t, err)
                        },
                },
@@ -84,8 +85,8 @@ func Test_parseInvocation(t *testing.T) {
 
        for _, test := range tests {
                t.Run(test.desc, func(t *testing.T) {
-                       ctx, callType, inRaw, methodName, err := 
parseInvocation(test.ctx(), test.url, test.invo())
-                       test.expect(t, ctx, callType, inRaw, methodName, err)
+                       callType, inRaw, methodName, err := 
parseInvocation(test.ctx(), test.url, test.invo())
+                       test.expect(t, callType, inRaw, methodName, err)
                })
        }
 }
@@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) {
                        },
                        expect: func(t *testing.T, ctx context.Context, err 
error) {
                                assert.Nil(t, err)
-                               header := tri.ExtractFromOutgoingContext(ctx)
+                               header := 
http.Header(tri.ExtractFromOutgoingContext(ctx))
                                assert.NotNil(t, header)
                                assert.Equal(t, "interface", 
header.Get(constant.InterfaceKey))
                                assert.Equal(t, "token", 
header.Get(constant.TokenKey))
@@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) {
                        },
                        expect: func(t *testing.T, ctx context.Context, err 
error) {
                                assert.Nil(t, err)
-                               header := tri.ExtractFromOutgoingContext(ctx)
+                               header := 
http.Header(tri.ExtractFromOutgoingContext(ctx))
                                assert.NotNil(t, header)
                                assert.Equal(t, "val1", header.Get("key1"))
                                assert.Equal(t, []string{"key2_1", "key2_2"}, 
header.Values("key2"))
@@ -157,7 +158,10 @@ func Test_parseAttachments(t *testing.T) {
 
        for _, test := range tests {
                t.Run(test.desc, func(t *testing.T) {
-                       ctx, err := parseAttachments(test.ctx(), test.url, 
test.invo())
+                       ctx := test.ctx()
+                       inv := test.invo()
+                       parseAttachments(ctx, test.url, inv)
+                       ctx, err := mergeAttachmentToOutgoing(ctx, inv)
                        test.expect(t, ctx, err)
                })
        }
diff --git a/protocol/triple/triple_protocol/duplex_http_call.go 
b/protocol/triple/triple_protocol/duplex_http_call.go
index 3865a56ab..6c4c02176 100644
--- a/protocol/triple/triple_protocol/duplex_http_call.go
+++ b/protocol/triple/triple_protocol/duplex_http_call.go
@@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error {
        if err := discard(d.response.Body); err != nil {
                return wrapIfRSTError(err)
        }
+       // If outgoing data was set, expose the response trailer as incoming data (newIncomingContext writes it into the map already held by d.ctx).
+       if ExtractFromOutgoingContext(d.ctx) != nil {
+               newIncomingContext(d.ctx, d.ResponseTrailer())
+       }
        return wrapIfRSTError(d.response.Body.Close())
 }
 
diff --git a/protocol/triple/triple_protocol/handler.go 
b/protocol/triple/triple_protocol/handler.go
index 56f83b3fe..7a44f8535 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -112,6 +112,10 @@ func generateUnaryHandlerFunc(
                // merge headers
                mergeHeaders(conn.ResponseHeader(), response.Header())
                mergeHeaders(conn.ResponseTrailer(), response.Trailer())
+               // Write the server-side return attachments into the trailer so 
they are sent back to the caller.
+               if data := ExtractFromOutgoingContext(ctx); data != nil {
+                       mergeHeaders(conn.ResponseTrailer(), data)
+               }
                return conn.Send(response.Any())
        }
 
@@ -160,6 +164,9 @@ func generateClientStreamHandlerFunc(
                }
                mergeHeaders(conn.ResponseHeader(), res.header)
                mergeHeaders(conn.ResponseTrailer(), res.trailer)
+               if outgoingData := ExtractFromOutgoingContext(ctx); 
outgoingData != nil {
+                       mergeHeaders(conn.ResponseTrailer(), outgoingData)
+               }
                return conn.Send(res.Msg)
        }
        if interceptor != nil {
@@ -205,7 +212,7 @@ func generateServerStreamHandlerFunc(
                }
                // embed header in context so that user logic could process 
them via FromIncomingContext
                ctx = newIncomingContext(ctx, conn.RequestHeader())
-               return streamFunc(
+               err := streamFunc(
                        ctx,
                        &Request{
                                Msg:    req,
@@ -215,6 +222,13 @@ func generateServerStreamHandlerFunc(
                        },
                        &ServerStream{conn: conn},
                )
+               if err != nil {
+                       return err
+               }
+               if outgoingData := ExtractFromOutgoingContext(ctx); 
outgoingData != nil {
+                       mergeHeaders(conn.ResponseTrailer(), outgoingData)
+               }
+               return nil
        }
        if interceptor != nil {
                implementation = 
interceptor.WrapStreamingHandler(implementation)
@@ -253,10 +267,14 @@ func generateBidiStreamHandlerFunc(
        implementation := func(ctx context.Context, conn StreamingHandlerConn) 
error {
                // embed header in context so that user logic could process 
them via FromIncomingContext
                ctx = newIncomingContext(ctx, conn.RequestHeader())
-               return streamFunc(
-                       ctx,
-                       &BidiStream{conn: conn},
-               )
+               err := streamFunc(ctx, &BidiStream{conn: conn})
+               if err != nil {
+                       return err
+               }
+               if outgoingData := ExtractFromOutgoingContext(ctx); 
outgoingData != nil {
+                       mergeHeaders(conn.ResponseTrailer(), outgoingData)
+               }
+               return nil
        }
        if interceptor != nil {
                implementation = 
interceptor.WrapStreamingHandler(implementation)
diff --git a/protocol/triple/triple_protocol/header.go 
b/protocol/triple/triple_protocol/header.go
index 28618e0dc..791cf4a30 100644
--- a/protocol/triple/triple_protocol/header.go
+++ b/protocol/triple/triple_protocol/header.go
@@ -19,6 +19,7 @@ import (
        "encoding/base64"
        "fmt"
        "net/http"
+       "strings"
 )
 
 // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
@@ -88,20 +89,45 @@ func addHeaderCanonical(h http.Header, key, value string) {
        h[key] = append(h[key], value)
 }
 
-type headerIncomingKey struct{}
-type headerOutgoingKey struct{}
+type extraDataKey struct{}
+
+const headerIncomingKey string = "headerIncomingKey"
+const headerOutgoingKey string = "headerOutgoingKey"
+
 type handlerOutgoingKey struct{}
 
-func newIncomingContext(ctx context.Context, header http.Header) 
context.Context {
-       return context.WithValue(ctx, headerIncomingKey{}, header)
+func newIncomingContext(ctx context.Context, data http.Header) context.Context 
{
+       var header = http.Header{}
+       extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+       if !ok {
+               extraData = map[string]http.Header{}
+       }
+       if data != nil {
+               for key, vals := range data {
+                       header[strings.ToLower(key)] = vals
+               }
+       }
+       extraData[headerIncomingKey] = header
+       return context.WithValue(ctx, extraDataKey{}, extraData)
 }
 
 // NewOutgoingContext sets headers entirely. If there are existing headers, 
they would be replaced.
 // It is used for passing headers to server-side.
 // It is like grpc.NewOutgoingContext.
 // Please refer to 
https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata.
-func NewOutgoingContext(ctx context.Context, header http.Header) 
context.Context {
-       return context.WithValue(ctx, headerOutgoingKey{}, header)
+func NewOutgoingContext(ctx context.Context, data http.Header) context.Context 
{
+       var header = http.Header{}
+       if data != nil {
+               for key, vals := range data {
+                       header[strings.ToLower(key)] = vals
+               }
+       }
+       extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+       if !ok {
+               extraData = map[string]http.Header{}
+       }
+       extraData[headerOutgoingKey] = header
+       return context.WithValue(ctx, extraDataKey{}, extraData)
 }
 
 // AppendToOutgoingContext merges kv pairs from user and existing headers.
@@ -112,37 +138,47 @@ func AppendToOutgoingContext(ctx context.Context, kv 
...string) context.Context
        if len(kv)%2 == 1 {
                panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of 
input pairs for header: %d", len(kv)))
        }
-       var header http.Header
-       headerRaw := ctx.Value(headerOutgoingKey{})
-       if headerRaw == nil {
+       extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+       if !ok {
+               extraData = map[string]http.Header{}
+               ctx = context.WithValue(ctx, extraDataKey{}, extraData)
+       }
+       header, ok := extraData[headerOutgoingKey]
+       if !ok {
                header = make(http.Header)
-       } else {
-               header = headerRaw.(http.Header)
+               extraData[headerOutgoingKey] = header
        }
        for i := 0; i < len(kv); i += 2 {
                // todo(DMwangnima): think about lowering
-               header.Add(kv[i], kv[i+1])
+               header.Add(strings.ToLower(kv[i]), kv[i+1])
        }
-       return context.WithValue(ctx, headerOutgoingKey{}, header)
+       return ctx
 }
 
 func ExtractFromOutgoingContext(ctx context.Context) http.Header {
-       headerRaw := ctx.Value(headerOutgoingKey{})
-       if headerRaw == nil {
+       extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+       if !ok {
                return nil
        }
-       // since headerOutgoingKey is only used in triple_protocol package, we 
need not verify the type
-       return headerRaw.(http.Header)
+       if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok {
+               return nil
+       } else {
+               return outGoingDataHeader
+       }
 }
 
 // FromIncomingContext retrieves headers passed by client-side. It is like 
grpc.FromIncomingContext.
+// It must be called after AppendToOutgoingContext/NewOutgoingContext to return the current value.
 // Please refer to 
https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1.
 func FromIncomingContext(ctx context.Context) (http.Header, bool) {
-       header, ok := ctx.Value(headerIncomingKey{}).(http.Header)
+       data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
        if !ok {
                return nil, false
+       } else if incomingDataHeader, ok := data[headerIncomingKey]; !ok {
+               return nil, false
+       } else {
+               return incomingDataHeader, true
        }
-       return header, true
 }
 
 // SetHeader is used for setting response header in server-side. It is like 
grpc.SendHeader(ctx, header) but

Reply via email to