This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new ee68159bc feat(triple): support generic call for Triple protocol
(#3154)
ee68159bc is described below
commit ee68159bc607e786ba267410edb5cd4a57b2b3a7
Author: Tsukikage <[email protected]>
AuthorDate: Thu Jan 15 18:02:26 2026 +0800
feat(triple): support generic call for Triple protocol (#3154)
Add generic call support for Triple protocol in non-IDL mode, enabling
Go clients to invoke Java services without pre-generated stubs.
Changes:
- Add WrapperCodec interface and extend protoWrapperCodec (hessian2 data wrapped in protobuf) to use application/proto on the wire and unwrap responses
- Support protoBinaryCodec unmarshaling wrapped hessian2 requests and responses
- Set CallType and parameterRawValues in generic filter for Triple
- Add NewGenericService API and WithGenericType option in client package
---
client/client.go | 30 +
client/options.go | 8 +
common/constant/key.go | 6 +-
filter/generic/filter.go | 52 +-
protocol/triple/client.go | 20 +-
protocol/triple/server.go | 122 ++--
protocol/triple/server_test.go | 57 ++
protocol/triple/triple.go | 25 +-
protocol/triple/triple_protocol/codec.go | 157 +++++-
.../triple/triple_protocol/codec_wrapper_test.go | 625 +++++++++++++++++++++
protocol/triple/triple_protocol/protocol_grpc.go | 2 +-
protocol/triple/triple_protocol/protocol_triple.go | 2 +-
protocol/triple/triple_test.go | 38 ++
server/server.go | 93 +++
server/server_test.go | 78 +++
15 files changed, 1247 insertions(+), 68 deletions(-)
diff --git a/client/client.go b/client/client.go
index 546f79ba0..074e5a1e1 100644
--- a/client/client.go
+++ b/client/client.go
@@ -26,6 +26,7 @@ import (
import (
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
+ "dubbo.apache.org/dubbo-go/v3/config/generic"
"dubbo.apache.org/dubbo-go/v3/metadata"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
@@ -130,6 +131,35 @@ func (cli *Client) NewService(service any, opts
...ReferenceOption) (*Connection
return cli.DialWithService(interfaceName, service, finalOpts...)
}
+// NewGenericService creates a GenericService for making generic calls without
pre-generated stubs.
+// The referenceStr parameter specifies the service interface name (e.g.,
"org.apache.dubbo.samples.UserProvider").
+//
+// Example usage:
+//
+// genericService, err :=
cli.NewGenericService("org.apache.dubbo.samples.UserProvider",
+// client.WithURL("tri://127.0.0.1:50052"),
+// )
+// if err != nil {
+// panic(err)
+// }
+// result, err := genericService.Invoke(ctx, "QueryUser",
[]string{"org.apache.dubbo.samples.User"}, []hessian.Object{user})
+func (cli *Client) NewGenericService(referenceStr string, opts
...ReferenceOption) (*generic.GenericService, error) {
+ finalOpts := []ReferenceOption{
+ WithIDL(constant.NONIDL),
+ WithGeneric(),
+ WithSerialization(constant.Hessian2Serialization),
+ }
+ finalOpts = append(finalOpts, opts...)
+
+ genericService := generic.NewGenericService(referenceStr)
+ _, err := cli.DialWithService(referenceStr, genericService,
finalOpts...)
+ if err != nil {
+ return nil, err
+ }
+
+ return genericService, nil
+}
+
func (cli *Client) Dial(interfaceName string, opts ...ReferenceOption)
(*Connection, error) {
return cli.dial(interfaceName, nil, nil, opts...)
}
diff --git a/client/options.go b/client/options.go
index c1b9388de..f837b1501 100644
--- a/client/options.go
+++ b/client/options.go
@@ -382,6 +382,14 @@ func WithGeneric() ReferenceOption {
}
}
+// WithGenericType sets the generic serialization type for generic call
+// Valid values: "true" (default), "gson", "protobuf", "protobuf-json"
+func WithGenericType(genericType string) ReferenceOption {
+ return func(opts *ReferenceOptions) {
+ opts.Reference.Generic = genericType
+ }
+}
+
func WithSticky() ReferenceOption {
return func(opts *ReferenceOptions) {
opts.Reference.Sticky = true
diff --git a/common/constant/key.go b/common/constant/key.go
index a5312b7ff..2c6928492 100644
--- a/common/constant/key.go
+++ b/common/constant/key.go
@@ -413,8 +413,10 @@ const (
// Generic Filter
const (
- GenericSerializationDefault = "true"
- GenericSerializationGson = "gson"
+ GenericSerializationDefault = "true"
+ GenericSerializationGson = "gson"
+ GenericSerializationProtobuf = "protobuf"
+ GenericSerializationProtobufJson = "protobuf-json"
)
// AdaptiveService Filter
diff --git a/filter/generic/filter.go b/filter/generic/filter.go
index c3a0c653e..977f450c9 100644
--- a/filter/generic/filter.go
+++ b/filter/generic/filter.go
@@ -95,13 +95,59 @@ func (f *genericFilter) Invoke(ctx context.Context, invoker
base.Invoker, inv ba
types,
args,
}
- newIvc := invocation.NewRPCInvocation(constant.Generic,
newArgs, inv.Attachments())
- newIvc.SetReply(inv.Reply())
+
+ // For Triple protocol non-IDL mode, we need to set
parameterRawValues
+ // The format is [param1, param2, ..., paramN, reply] where the
last element is the reply placeholder
+ // Triple invoker slices as: request = inRaw[0:len-1], reply =
inRaw[len-1]
+ // So for generic call, we need [methodName, types, args,
reply] to get request = [methodName, types, args]
+ reply := inv.Reply()
+ parameterRawValues := []any{mtdName, types, args, reply}
+
+ newIvc := invocation.NewRPCInvocationWithOptions(
+ invocation.WithMethodName(constant.Generic),
+ invocation.WithArguments(newArgs),
+ invocation.WithParameterRawValues(parameterRawValues),
+ invocation.WithAttachments(inv.Attachments()),
+ invocation.WithReply(reply),
+ )
newIvc.Attachments()[constant.GenericKey] =
invoker.GetURL().GetParam(constant.GenericKey, "")
+ // Copy CallType attribute from original invocation for Triple
protocol support
+ // If not present, set default to CallUnary for generic calls
+ if callType, ok := inv.GetAttribute(constant.CallTypeKey); ok {
+ newIvc.SetAttribute(constant.CallTypeKey, callType)
+ } else {
+ newIvc.SetAttribute(constant.CallTypeKey,
constant.CallUnary)
+ }
+
return invoker.Invoke(ctx, newIvc)
} else if isMakingAGenericCall(invoker, inv) {
- inv.Attachments()[constant.GenericKey] =
invoker.GetURL().GetParam(constant.GenericKey, "")
+ // Arguments format: [methodName string, types []string, args
[]hessian.Object]
+ oldArgs := inv.Arguments()
+ reply := inv.Reply()
+
+ // For Triple protocol non-IDL mode, we need to set
parameterRawValues
+ // parameterRawValues format: [methodName, types, args, reply]
+ // Triple invoker slices as: request = inRaw[0:len-1], reply =
inRaw[len-1]
+ parameterRawValues := []any{oldArgs[0], oldArgs[1], oldArgs[2],
reply}
+
+ newIvc := invocation.NewRPCInvocationWithOptions(
+ invocation.WithMethodName(inv.MethodName()),
+ invocation.WithArguments(oldArgs),
+ invocation.WithParameterRawValues(parameterRawValues),
+ invocation.WithAttachments(inv.Attachments()),
+ invocation.WithReply(reply),
+ )
+ newIvc.Attachments()[constant.GenericKey] =
invoker.GetURL().GetParam(constant.GenericKey, "")
+
+ // Set CallType for Triple protocol support
+ if callType, ok := inv.GetAttribute(constant.CallTypeKey); ok {
+ newIvc.SetAttribute(constant.CallTypeKey, callType)
+ } else {
+ newIvc.SetAttribute(constant.CallTypeKey,
constant.CallUnary)
+ }
+
+ return invoker.Invoke(ctx, newIvc)
}
return invoker.Invoke(ctx, inv)
}
diff --git a/protocol/triple/client.go b/protocol/triple/client.go
index 3378a55e7..2dab272d2 100644
--- a/protocol/triple/client.go
+++ b/protocol/triple/client.go
@@ -284,7 +284,18 @@ func newClientManager(url *common.URL) (*clientManager,
error) {
triClients := make(map[string]*tri.Client)
- if len(url.Methods) != 0 {
+ // Check if this is a generic call - for generic call, we only need
$invoke method
+ generic := url.GetParam(constant.GenericKey, "")
+ isGeneric := isGenericCall(generic)
+
+ if isGeneric {
+ // For generic call, only register $invoke method
+ invokeURL, err := joinPath(baseTriURL, url.Interface(),
constant.Generic)
+ if err != nil {
+ return nil, fmt.Errorf("JoinPath failed for base %s,
interface %s, method %s", baseTriURL, url.Interface(), constant.Generic)
+ }
+ triClients[constant.Generic] = tri.NewClient(httpClient,
invokeURL, cliOpts...)
+ } else if len(url.Methods) != 0 {
for _, method := range url.Methods {
triURL, err := joinPath(baseTriURL, url.Interface(),
method)
if err != nil {
@@ -312,6 +323,13 @@ func newClientManager(url *common.URL) (*clientManager,
error) {
triClient := tri.NewClient(httpClient, triURL,
cliOpts...)
triClients[methodName] = triClient
}
+
+ // Register $invoke method for generic call support in non-IDL
mode
+ invokeURL, err := joinPath(baseTriURL, url.Interface(),
constant.Generic)
+ if err != nil {
+ return nil, fmt.Errorf("JoinPath failed for base %s,
interface %s, method %s", baseTriURL, url.Interface(), constant.Generic)
+ }
+ triClients[constant.Generic] = tri.NewClient(httpClient,
invokeURL, cliOpts...)
}
return &clientManager{
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index 419d9e83a..a8e63e98a 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -503,51 +503,101 @@ func createServiceInfoWithReflection(svc
common.RPCService) *common.ServiceInfo
if methodType.Name == "Reference" {
continue
}
- paramsNum := methodType.Type.NumIn()
- // the first param is receiver itself, the second param is ctx
- // just ignore them
- if paramsNum < 2 {
- logger.Fatalf("TRIPLE does not support %s method that
does not have any parameter", methodType.Name)
- continue
- }
- paramsTypes := make([]reflect.Type, paramsNum-2)
- for j := 2; j < paramsNum; j++ {
- paramsTypes[j-2] = methodType.Type.In(j)
- }
- methodInfo := common.MethodInfo{
- Name: methodType.Name,
- // only support Unary invocation now
- Type: constant.CallUnary,
- ReqInitFunc: func() any {
- params := make([]any, len(paramsTypes))
- for k, paramType := range paramsTypes {
- params[k] =
reflect.New(paramType).Interface()
- }
- return params
- },
+ methodInfo := buildMethodInfoWithReflection(methodType)
+ if methodInfo != nil {
+ methodInfos = append(methodInfos, *methodInfo)
}
- methodInfos = append(methodInfos, methodInfo)
}
- // only support no-idl mod call unary
- genericMethodInfo := common.MethodInfo{
- Name: "$invoke",
- Type: constant.CallUnary,
+ // Add $invoke method for generic call support
+ methodInfos = append(methodInfos, buildGenericMethodInfo())
+
+ info.Methods = methodInfos
+ return &info
+}
+
+// buildMethodInfoWithReflection creates MethodInfo for a single method using
reflection.
+func buildMethodInfoWithReflection(methodType reflect.Method)
*common.MethodInfo {
+ paramsNum := methodType.Type.NumIn()
+ // the first param is receiver itself, the second param is ctx
+ if paramsNum < 2 {
+ logger.Fatalf("TRIPLE does not support %s method that does not
have any parameter", methodType.Name)
+ return nil
+ }
+
+ // Extract parameter types (skip receiver and context)
+ paramsTypes := make([]reflect.Type, paramsNum-2)
+ for j := 2; j < paramsNum; j++ {
+ paramsTypes[j-2] = methodType.Type.In(j)
+ }
+
+ // Capture method for closure
+ method := methodType
+ return &common.MethodInfo{
+ Name: methodType.Name,
+ Type: constant.CallUnary, // only support Unary invocation now
ReqInitFunc: func() any {
- params := make([]any, 3)
- // params must be pointer
- params[0] = func(s string) *string { return &s
}("methodName") // methodName *string
- params[1] = &[]string{}
// argv type *[]string
- params[2] = &[]hessian.Object{}
// argv *[]hessian.Object
+ params := make([]any, len(paramsTypes))
+ for k, paramType := range paramsTypes {
+ params[k] = reflect.New(paramType).Interface()
+ }
return params
},
+ MethodFunc: func(ctx context.Context, args []any, handler any)
(any, error) {
+ in := []reflect.Value{reflect.ValueOf(handler)}
+ in = append(in, reflect.ValueOf(ctx))
+ for _, arg := range args {
+ in = append(in, reflect.ValueOf(arg))
+ }
+ returnValues := method.Func.Call(in)
+ if len(returnValues) == 1 {
+ if isReflectValueNil(returnValues[0]) {
+ return nil, nil
+ }
+ if err, ok :=
returnValues[0].Interface().(error); ok {
+ return nil, err
+ }
+ return nil, nil
+ }
+ var result any
+ var err error
+ if !isReflectValueNil(returnValues[0]) {
+ result = returnValues[0].Interface()
+ }
+ if len(returnValues) > 1 &&
!isReflectValueNil(returnValues[1]) {
+ if e, ok :=
returnValues[1].Interface().(error); ok {
+ err = e
+ }
+ }
+ return result, err
+ },
}
+}
- methodInfos = append(methodInfos, genericMethodInfo)
-
- info.Methods = methodInfos
+// buildGenericMethodInfo creates MethodInfo for $invoke generic call method.
+func buildGenericMethodInfo() common.MethodInfo {
+ return common.MethodInfo{
+ Name: constant.Generic,
+ Type: constant.CallUnary,
+ ReqInitFunc: func() any {
+ return []any{
+ func(s string) *string { return &s }(""), //
methodName *string
+ &[]string{}, //
types *[]string
+ &[]hessian.Object{}, //
args *[]hessian.Object
+ }
+ },
+ }
+}
- return &info
+// isReflectValueNil safely checks if a reflect.Value is nil.
+// It first checks if the value's kind supports nil checking to avoid panic.
+func isReflectValueNil(v reflect.Value) bool {
+ switch v.Kind() {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
+ return v.IsNil()
+ default:
+ return false
+ }
}
// generateAttachments transfer http.Header to map[string]any and make all
keys lowercase
diff --git a/protocol/triple/server_test.go b/protocol/triple/server_test.go
index 32f1d8a4a..f6d801954 100644
--- a/protocol/triple/server_test.go
+++ b/protocol/triple/server_test.go
@@ -21,8 +21,10 @@ import (
"context"
"fmt"
"net/http"
+ "reflect"
"sync"
"testing"
+ "unsafe"
)
import (
@@ -485,3 +487,58 @@ func Test_createServiceInfoWithReflection(t *testing.T) {
assert.Len(t, paramsSlice, 3) // methodName, argv types, argv
})
}
+
+// Test isReflectValueNil safely checks if a reflect.Value is nil
+func Test_isReflectValueNil(t *testing.T) {
+ tests := []struct {
+ name string
+ value any
+ expected bool
+ }{
+ // nil nillable types
+ {"nil chan", (chan int)(nil), true},
+ {"nil map", (map[string]int)(nil), true},
+ {"nil slice", ([]int)(nil), true},
+ {"nil func", (func())(nil), true},
+ {"nil pointer", (*int)(nil), true},
+ {"nil unsafe.Pointer", unsafe.Pointer(nil), true},
+
+ // non-nil nillable types
+ {"non-nil chan", make(chan int), false},
+ {"non-nil map", map[string]int{"a": 1}, false},
+ {"non-nil slice", []int{1, 2, 3}, false},
+ {"non-nil func", func() {}, false},
+ {"non-nil pointer", new(int), false},
+
+ // non-nillable types (should return false, not panic)
+ {"int", 42, false},
+ {"string", "hello", false},
+ {"bool", true, false},
+ {"float64", 3.14, false},
+ {"struct", struct{ Name string }{"test"}, false},
+ {"array", [3]int{1, 2, 3}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ v := reflect.ValueOf(tt.value)
+ // should not panic
+ result := isReflectValueNil(v)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// Test isReflectValueNil with UnsafePointer specifically
+func Test_isReflectValueNil_UnsafePointer(t *testing.T) {
+ t.Run("nil unsafe.Pointer", func(t *testing.T) {
+ v := reflect.ValueOf(unsafe.Pointer(nil))
+ assert.True(t, isReflectValueNil(v))
+ })
+
+ t.Run("non-nil unsafe.Pointer", func(t *testing.T) {
+ x := 42
+ v := reflect.ValueOf(unsafe.Pointer(&x))
+ assert.False(t, isReflectValueNil(v))
+ })
+}
diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go
index 4e18ff6d1..731dcce41 100644
--- a/protocol/triple/triple.go
+++ b/protocol/triple/triple.go
@@ -19,6 +19,7 @@ package triple
import (
"context"
+ "strings"
"sync"
)
@@ -110,9 +111,16 @@ func (tp *TripleProtocol) Refer(url *common.URL)
base.Invoker {
IDLMode := url.GetParam(constant.IDLMode, "")
// for now, we do not need to use this info
_, ok := url.GetAttribute(constant.ClientInfoKey)
- // isIDL is NONIDL means new triple non-IDL mode
- if ok || IDLMode == constant.NONIDL {
- // stub code generated by new protoc-gen-go-triple
+ // Check if this is a generic call
+ generic := url.GetParam(constant.GenericKey, "")
+ isGenericCall := isGenericCall(generic)
+
+ // Use NewTripleInvoker for:
+ // 1. New protoc-gen-go-triple stub code (has ClientInfoKey)
+ // 2. Non-IDL mode (IDLMode == NONIDL)
+ // 3. Generic call (generic=true/gson/protobuf/protobuf-json)
+ if ok || IDLMode == constant.NONIDL || isGenericCall {
+ // new triple invoker supporting $invoke for generic calls
invoker, err = NewTripleInvoker(url)
} else {
// stub code generated by old protoc-gen-go-triple
@@ -141,6 +149,17 @@ func (tp *TripleProtocol) Destroy() {
tp.BaseProtocol.Destroy()
}
+// isGenericCall checks if the generic parameter indicates a generic call
+func isGenericCall(generic string) bool {
+ if generic == "" {
+ return false
+ }
+ return strings.EqualFold(generic, constant.GenericSerializationDefault)
||
+ strings.EqualFold(generic, constant.GenericSerializationGson) ||
+ strings.EqualFold(generic,
constant.GenericSerializationProtobuf) ||
+ strings.EqualFold(generic,
constant.GenericSerializationProtobufJson)
+}
+
func NewTripleProtocol() *TripleProtocol {
return &TripleProtocol{
BaseProtocol: base.NewBaseProtocol(),
diff --git a/protocol/triple/triple_protocol/codec.go
b/protocol/triple/triple_protocol/codec.go
index 0e3852162..9a3261e7a 100644
--- a/protocol/triple/triple_protocol/codec.go
+++ b/protocol/triple/triple_protocol/codec.go
@@ -97,6 +97,10 @@ type stableCodec interface {
IsBinary() bool
}
+// protoBinaryCodec handles standard protobuf binary serialization.
+// It also supports Java Dubbo Triple generic calls when the message is not a
proto.Message.
+// This dual functionality is needed because the server receives wrapped
generic calls
+// with Content-Type "application/proto", so this codec must handle both cases.
type protoBinaryCodec struct{}
var _ Codec = (*protoBinaryCodec)(nil)
@@ -114,11 +118,61 @@ func (c *protoBinaryCodec) Marshal(message any) ([]byte,
error) {
func (c *protoBinaryCodec) Unmarshal(data []byte, message any) error {
protoMessage, ok := message.(proto.Message)
if !ok {
- return errNotProto(message)
+ // Non-proto types indicate a generic call - try to unwrap from
wrapper format.
+ // This is used by the server when receiving Java/Go generic
calls.
+ return c.unmarshalWrappedMessage(data, message)
}
return proto.Unmarshal(data, protoMessage)
}
+// unmarshalWrappedMessage handles both TripleResponseWrapper and
TripleRequestWrapper formats.
+// It determines the format by checking if message is a slice (request) or not
(response).
+func (c *protoBinaryCodec) unmarshalWrappedMessage(data []byte, message any)
error {
+ hessianCodec := &hessian2Codec{}
+
+ // Check if message is a slice - if so, it's a request with multiple
args
+ if params, isSlice := message.([]any); isSlice {
+ // Request format: TripleRequestWrapper with multiple args
+ var reqWrapper interoperability.TripleRequestWrapper
+ if err := proto.Unmarshal(data, &reqWrapper); err != nil {
+ return fmt.Errorf("unmarshal wrapped request: %w", err)
+ }
+ if len(reqWrapper.Args) != len(params) {
+ return fmt.Errorf("unmarshal wrapped request: expected
%d params, got %d args", len(params), len(reqWrapper.Args))
+ }
+
+ for i, arg := range reqWrapper.Args {
+ if err := hessianCodec.Unmarshal(arg, params[i]); err
!= nil {
+ return fmt.Errorf("unmarshal wrapped request
arg[%d]: %w", i, err)
+ }
+ }
+ return nil
+ }
+
+ // Response format: TripleResponseWrapper with single data field
+ var respWrapper interoperability.TripleResponseWrapper
+ if err := proto.Unmarshal(data, &respWrapper); err == nil {
+ // Check if it's a valid response wrapper (has serializeType or
non-empty data)
+ if len(respWrapper.Data) > 0 {
+ return hessianCodec.Unmarshal(respWrapper.Data, message)
+ }
+ // Empty Data with serializeType indicates a null/void
response, which is valid
+ if respWrapper.SerializeType != "" {
+ return nil
+ }
+ }
+
+ // Fallback: try as single-arg request (not a response wrapper)
+ var reqWrapper interoperability.TripleRequestWrapper
+ if err := proto.Unmarshal(data, &reqWrapper); err != nil {
+ return fmt.Errorf("unmarshal wrapped message: %T is not a
proto.Message and data is not a valid wrapper", message)
+ }
+ if len(reqWrapper.Args) != 1 {
+ return fmt.Errorf("unmarshal wrapped message: expected 1 arg
for single param, got %d", len(reqWrapper.Args))
+ }
+ return hessianCodec.Unmarshal(reqWrapper.Args[0], message)
+}
+
func (c *protoBinaryCodec) MarshalStable(message any) ([]byte, error) {
protoMessage, ok := message.(proto.Message)
if !ok {
@@ -189,19 +243,59 @@ func (c *protoJSONCodec) IsBinary() bool {
return false
}
-// todo(DMwangnima): add unit tests
+// WrapperCodec is an interface for codecs that use a protobuf wrapper format
+// (TripleRequestWrapper/TripleResponseWrapper) on the wire. This is required
for
+// interoperability with Java Dubbo Triple protocol in non-IDL mode.
+//
+// Codecs implementing this interface:
+// - Use protobuf as the wire format (Content-Type: application/proto)
+// - Wrap data in TripleRequestWrapper (for requests) or TripleResponseWrapper
(for responses)
+// - Use an inner codec (e.g., hessian2) for the actual data serialization
+type WrapperCodec interface {
+ Codec
+ // WireCodecName returns "proto" because the wire format is protobuf.
+ WireCodecName() string
+}
+
+// getWireCodecName returns the codec name to use for Content-Type on the wire.
+// If the codec implements WrapperCodec, its WireCodecName() is used.
+// Otherwise, the codec's Name() is used.
+func getWireCodecName(codec Codec) string {
+ if wrapper, ok := codec.(WrapperCodec); ok {
+ return wrapper.WireCodecName()
+ }
+ return codec.Name()
+}
+
+// protoWrapperCodec wraps an inner codec (e.g., hessian2) in protobuf wrapper
format.
+// This is used for interoperability with Java Dubbo Triple protocol in
non-IDL mode.
+//
+// Wire format:
+// - Requests use TripleRequestWrapper (multiple args, argTypes)
+// - Responses use TripleResponseWrapper (single data field)
+//
+// The Content-Type is "application/proto" because the outer format is
protobuf.
+// The inner serialization type (e.g., "hessian2") is stored in the wrapper's
serializeType field.
type protoWrapperCodec struct {
innerCodec Codec
}
+var _ WrapperCodec = (*protoWrapperCodec)(nil)
+
+// Name returns the inner codec name (e.g., "hessian2") for codec registration
and lookup.
func (c *protoWrapperCodec) Name() string {
return c.innerCodec.Name()
}
+// WireCodecName returns "proto" because the wire format is protobuf.
+// This ensures the correct Content-Type (application/proto) is used.
+func (c *protoWrapperCodec) WireCodecName() string {
+ return codecNameProto
+}
+
+// Marshal wraps the message in TripleRequestWrapper format for requests.
func (c *protoWrapperCodec) Marshal(message any) ([]byte, error) {
- var reqs []any
- var ok bool
- reqs, ok = message.([]any)
+ reqs, ok := message.([]any)
if !ok {
reqs = []any{message}
}
@@ -227,29 +321,50 @@ func (c *protoWrapperCodec) Marshal(message any) ([]byte,
error) {
return proto.Marshal(wrapperReq)
}
+// Unmarshal handles both TripleResponseWrapper (for responses) and
TripleRequestWrapper (for requests).
+// It determines the format by checking if message is a slice (request) or not
(response).
func (c *protoWrapperCodec) Unmarshal(binary []byte, message any) error {
- var params []any
- var ok bool
- params, ok = message.([]any)
- if !ok {
- params = []any{message}
- }
+ // Check if message is a slice - if so, it's a request with multiple
args
+ if params, isSlice := message.([]any); isSlice {
+ // Request format: TripleRequestWrapper with multiple args
+ var wrapperReq interoperability.TripleRequestWrapper
+ if err := proto.Unmarshal(binary, &wrapperReq); err != nil {
+ return err
+ }
+ if len(wrapperReq.Args) != len(params) {
+ return fmt.Errorf("wrapper codec: expected %d params,
got %d args", len(params), len(wrapperReq.Args))
+ }
- var wrapperReq interoperability.TripleRequestWrapper
- if err := proto.Unmarshal(binary, &wrapperReq); err != nil {
- return err
- }
- if len(wrapperReq.Args) != len(params) {
- return fmt.Errorf("error, request params len is %d, but has %d
actually", len(wrapperReq.Args), len(params))
+ for i, arg := range wrapperReq.Args {
+ if err := c.innerCodec.Unmarshal(arg, params[i]); err
!= nil {
+ return err
+ }
+ }
+ return nil
}
- for i, arg := range wrapperReq.Args {
- if err := c.innerCodec.Unmarshal(arg, params[i]); err != nil {
- return err
+ // Response format: TripleResponseWrapper with single data field
+ var wrapperResp interoperability.TripleResponseWrapper
+ if err := proto.Unmarshal(binary, &wrapperResp); err == nil {
+ // Check if it's a valid response wrapper (has serializeType or
non-empty data)
+ if len(wrapperResp.Data) > 0 {
+ return c.innerCodec.Unmarshal(wrapperResp.Data, message)
+ }
+ // Empty Data with serializeType indicates a null/void
response, which is valid
+ if wrapperResp.SerializeType != "" {
+ return nil
}
}
- return nil
+ // Fallback: try as single-arg request (not a response wrapper)
+ var wrapperReq interoperability.TripleRequestWrapper
+ if err := proto.Unmarshal(binary, &wrapperReq); err != nil {
+ return fmt.Errorf("wrapper codec: failed to unmarshal as
request or response wrapper")
+ }
+ if len(wrapperReq.Args) != 1 {
+ return fmt.Errorf("wrapper codec: expected 1 arg for single
param, got %d", len(wrapperReq.Args))
+ }
+ return c.innerCodec.Unmarshal(wrapperReq.Args[0], message)
}
func newProtoWrapperCodec(innerCodec Codec) *protoWrapperCodec {
diff --git a/protocol/triple/triple_protocol/codec_wrapper_test.go
b/protocol/triple/triple_protocol/codec_wrapper_test.go
new file mode 100644
index 000000000..a3bb46df6
--- /dev/null
+++ b/protocol/triple/triple_protocol/codec_wrapper_test.go
@@ -0,0 +1,625 @@
+// Copyright 2021-2023 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package triple_protocol
+
+import (
+ "testing"
+ "time"
+)
+
+import (
+ hessian "github.com/apache/dubbo-go-hessian2"
+
+ "google.golang.org/protobuf/proto"
+)
+
+import (
+
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/assert"
+
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/interoperability"
+)
+
+// TestUser is a test POJO for hessian2 serialization
+type TestUser struct {
+ ID string
+ Name string
+ Age int32
+}
+
+func (u *TestUser) JavaClassName() string {
+ return "org.apache.dubbo.samples.User"
+}
+
+func init() {
+ hessian.RegisterPOJO(&TestUser{})
+}
+
+//
=============================================================================
+// protoWrapperCodec Tests
+//
=============================================================================
+
+func TestProtoWrapperCodec_Name(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+ assert.Equal(t, codec.Name(), codecNameHessian2)
+}
+
+func TestProtoWrapperCodec_WireCodecName(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+ assert.Equal(t, codec.WireCodecName(), codecNameProto)
+}
+
+func TestProtoWrapperCodec_ImplementsWrapperCodec(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+ var _ WrapperCodec = codec // Compile-time check
+}
+
+func TestProtoWrapperCodec_MarshalRequest_SingleArg(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+
+ // Marshal a single string argument
+ data, err := codec.Marshal([]any{"hello"})
+ assert.Nil(t, err)
+ assert.True(t, len(data) > 0)
+
+ // Verify it's a valid TripleRequestWrapper
+ var wrapper interoperability.TripleRequestWrapper
+ err = proto.Unmarshal(data, &wrapper)
+ assert.Nil(t, err)
+ assert.Equal(t, wrapper.SerializeType, codecNameHessian2)
+ assert.Equal(t, len(wrapper.Args), 1)
+ assert.Equal(t, len(wrapper.ArgTypes), 1)
+ assert.Equal(t, wrapper.ArgTypes[0], "java.lang.String")
+}
+
+func TestProtoWrapperCodec_MarshalRequest_MultipleArgs(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+
+ // Marshal multiple arguments
+ data, err := codec.Marshal([]any{"hello", int32(42), true})
+ assert.Nil(t, err)
+
+ var wrapper interoperability.TripleRequestWrapper
+ err = proto.Unmarshal(data, &wrapper)
+ assert.Nil(t, err)
+ assert.Equal(t, len(wrapper.Args), 3)
+ assert.Equal(t, wrapper.ArgTypes[0], "java.lang.String")
+ assert.Equal(t, wrapper.ArgTypes[1], "int")
+ assert.Equal(t, wrapper.ArgTypes[2], "boolean")
+}
+
+func TestProtoWrapperCodec_MarshalRequest_POJO(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+
+ user := &TestUser{ID: "001", Name: "test", Age: 25}
+ data, err := codec.Marshal([]any{user})
+ assert.Nil(t, err)
+
+ var wrapper interoperability.TripleRequestWrapper
+ err = proto.Unmarshal(data, &wrapper)
+ assert.Nil(t, err)
+ assert.Equal(t, len(wrapper.Args), 1)
+ assert.Equal(t, wrapper.ArgTypes[0], "org.apache.dubbo.samples.User")
+}
+
+func TestProtoWrapperCodec_UnmarshalRequest(t *testing.T) {
+ t.Parallel()
+
+ codec := newProtoWrapperCodec(&hessian2Codec{})
+
+ // Create a TripleRequestWrapper
+ hessianCodec := &hessian2Codec{}
+ arg1, _ := hessianCodec.Marshal("hello")
+ arg2, _ := hessianCodec.Marshal(int32(42))
+
+ wrapper := &interoperability.TripleRequestWrapper{
+ SerializeType: codecNameHessian2,
+ Args: [][]byte{arg1, arg2},
+ ArgTypes: []string{"java.lang.String", "int"},
+ }
+ data, _ := proto.Marshal(wrapper)
+
+ // Unmarshal - use interface pointers that hessian2 can fill
+ results := make([]any, 2)
+ for i := range results {
+ var v any
+ results[i] = &v
+ }
+ err := codec.Unmarshal(data, results)
+ assert.Nil(t, err)
+
+ // Verify the unmarshaled values
+ val0 := *(results[0].(*any))
+ val1 := *(results[1].(*any))
+ assert.Equal(t, val0, "hello")
+ assert.Equal(t, val1, int32(42))
+}
+
+// TestProtoWrapperCodec_UnmarshalResponse decodes a TripleResponseWrapper
+// whose payload is a hessian2-encoded map.
+func TestProtoWrapperCodec_UnmarshalResponse(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+
+	// Build a TripleResponseWrapper around a hessian2-encoded map.
+	hessianCodec := &hessian2Codec{}
+	respData, err := hessianCodec.Marshal(map[string]any{
+		"id":   "001",
+		"name": "test",
+		"age":  25,
+	})
+	assert.NoError(t, err)
+
+	wrapper := &interoperability.TripleResponseWrapper{
+		SerializeType: codecNameHessian2,
+		Data:          respData,
+		Type:          "java.util.Map",
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	var result any
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.NotNil(t, result)
+
+	// An untyped hessian2 map is expected to decode as map[any]any.
+	resultMap, ok := result.(map[any]any)
+	assert.True(t, ok)
+	assert.Equal(t, "001", resultMap["id"])
+	assert.Equal(t, "test", resultMap["name"])
+}
+
+// TestProtoWrapperCodec_RoundTrip_Request marshals arguments with the wrapper
+// codec and checks that the wire bytes form a TripleRequestWrapper whose
+// hessian2-encoded arguments decode back to the originals.
+func TestProtoWrapperCodec_RoundTrip_Request(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+
+	original := []any{"hello", int32(42)}
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	// The payload must parse as a TripleRequestWrapper, one entry per argument.
+	var wrapper interoperability.TripleRequestWrapper
+	err = proto.Unmarshal(data, &wrapper)
+	assert.NoError(t, err)
+	if !assert.Len(t, wrapper.Args, len(original)) {
+		return // indexing Args below would panic
+	}
+
+	// Decode each argument the way a receiving server would.
+	hessianCodec := &hessian2Codec{}
+	var str string
+	var num int32
+	err = hessianCodec.Unmarshal(wrapper.Args[0], &str)
+	assert.NoError(t, err)
+	err = hessianCodec.Unmarshal(wrapper.Args[1], &num)
+	assert.NoError(t, err)
+
+	assert.Equal(t, "hello", str)
+	assert.Equal(t, int32(42), num)
+}
+
+//
=============================================================================
+// protoBinaryCodec Wrapper Tests
+//
=============================================================================
+
+// TestProtoBinaryCodec_MarshalNonProtoReturnsError checks that marshaling a
+// value that is not a proto message fails.
+func TestProtoBinaryCodec_MarshalNonProtoReturnsError(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+
+	result := map[string]any{"id": "001", "name": "test"}
+	_, err := codec.Marshal(result)
+	assert.Error(t, err)
+}
+
+// TestProtoBinaryCodec_UnmarshalWrappedResponse decodes a TripleResponseWrapper
+// carrying a hessian2-encoded string.
+func TestProtoBinaryCodec_UnmarshalWrappedResponse(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+
+	hessianCodec := &hessian2Codec{}
+	respData, err := hessianCodec.Marshal("hello world")
+	assert.NoError(t, err)
+
+	wrapper := &interoperability.TripleResponseWrapper{
+		SerializeType: codecNameHessian2,
+		Data:          respData,
+		Type:          "java.lang.String",
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	var result any
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.Equal(t, "hello world", result)
+}
+
+// TestProtoBinaryCodec_UnmarshalWrappedRequest decodes a TripleRequestWrapper
+// with two hessian2-encoded arguments into a slice of *any targets.
+func TestProtoBinaryCodec_UnmarshalWrappedRequest(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+
+	hessianCodec := &hessian2Codec{}
+	arg1, err := hessianCodec.Marshal("arg1")
+	assert.NoError(t, err)
+	arg2, err := hessianCodec.Marshal(int64(123))
+	assert.NoError(t, err)
+
+	wrapper := &interoperability.TripleRequestWrapper{
+		SerializeType: codecNameHessian2,
+		Args:          [][]byte{arg1, arg2},
+		ArgTypes:      []string{"java.lang.String", "long"},
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	// *any targets let hessian2 fill in whatever type it decodes.
+	results := []any{new(any), new(any)}
+	err = codec.Unmarshal(data, results)
+	assert.NoError(t, err)
+
+	assert.Equal(t, "arg1", *(results[0].(*any)))
+	assert.Equal(t, int64(123), *(results[1].(*any)))
+}
+
+// TestProtoBinaryCodec_ResponseThenRequestFallback feeds a request wrapper to
+// protoBinaryCodec. The codec is expected to try TripleResponseWrapper first
+// and then fall back to TripleRequestWrapper.
+func TestProtoBinaryCodec_ResponseThenRequestFallback(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+
+	hessianCodec := &hessian2Codec{}
+	arg1, err := hessianCodec.Marshal("test")
+	assert.NoError(t, err)
+
+	wrapper := &interoperability.TripleRequestWrapper{
+		SerializeType: codecNameHessian2,
+		Args:          [][]byte{arg1},
+		ArgTypes:      []string{"java.lang.String"},
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	// Decoding must still succeed via the request-wrapper path.
+	results := []any{new(any)}
+	err = codec.Unmarshal(data, results)
+	assert.NoError(t, err)
+	assert.Equal(t, "test", *(results[0].(*any)))
+}
+
+//
=============================================================================
+// WrapperCodec Interface Tests
+//
=============================================================================
+
+// TestGetWireCodecName_WrapperCodec: a protoWrapperCodec is sent as protobuf.
+func TestGetWireCodecName_WrapperCodec(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+	assert.Equal(t, codecNameProto, getWireCodecName(codec))
+}
+
+// TestGetWireCodecName_RegularCodec: a plain codec reports its own name.
+func TestGetWireCodecName_RegularCodec(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+	assert.Equal(t, codecNameProto, getWireCodecName(codec))
+}
+
+// TestGetWireCodecName_Hessian2Codec: an unwrapped hessian2 codec is sent as
+// hessian2.
+func TestGetWireCodecName_Hessian2Codec(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+	assert.Equal(t, codecNameHessian2, getWireCodecName(codec))
+}
+
+//
=============================================================================
+// hessian2Codec Tests
+//
=============================================================================
+
+func TestHessian2Codec_Name(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+	assert.Equal(t, codecNameHessian2, codec.Name())
+}
+
+func TestHessian2Codec_RoundTrip_String(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+
+	original := "hello world"
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	var result string
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.Equal(t, original, result)
+}
+
+func TestHessian2Codec_RoundTrip_Int(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+
+	original := int32(12345)
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	var result int32
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.Equal(t, original, result)
+}
+
+func TestHessian2Codec_RoundTrip_Map(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+
+	original := map[string]any{"key1": "value1", "key2": int64(42)}
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	var result any
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+
+	// Decoding into *any is expected to yield map[any]any.
+	resultMap, ok := result.(map[any]any)
+	assert.True(t, ok)
+	assert.Equal(t, "value1", resultMap["key1"])
+	assert.Equal(t, int64(42), resultMap["key2"])
+}
+
+func TestHessian2Codec_RoundTrip_Slice(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+
+	original := []string{"a", "b", "c"}
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	var result any
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.NotNil(t, result)
+}
+
+func TestHessian2Codec_RoundTrip_POJO(t *testing.T) {
+	t.Parallel()
+
+	codec := &hessian2Codec{}
+
+	original := &TestUser{ID: "001", Name: "test", Age: 25}
+	data, err := codec.Marshal(original)
+	assert.NoError(t, err)
+
+	var result any
+	err = codec.Unmarshal(data, &result)
+	assert.NoError(t, err)
+	assert.NotNil(t, result)
+}
+
+//
=============================================================================
+// getArgType Tests
+//
=============================================================================
+
+func TestGetArgType_Nil(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "V", getArgType(nil))
+}
+
+func TestGetArgType_Bool(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "boolean", getArgType(true))
+	assert.Equal(t, "boolean", getArgType(false))
+}
+
+func TestGetArgType_BoolSlice(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "[Z", getArgType([]bool{true, false}))
+}
+
+func TestGetArgType_Byte(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "byte", getArgType(byte(1)))
+}
+
+func TestGetArgType_ByteSlice(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "[B", getArgType([]byte{1, 2, 3}))
+}
+
+func TestGetArgType_Int8(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "byte", getArgType(int8(1)))
+}
+
+func TestGetArgType_Int16(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "short", getArgType(int16(1)))
+}
+
+func TestGetArgType_Int32(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "int", getArgType(int32(1)))
+}
+
+func TestGetArgType_Int64(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "long", getArgType(int64(1)))
+}
+
+func TestGetArgType_Int(t *testing.T) {
+	t.Parallel()
+	// Go int is mapped to Java long.
+	assert.Equal(t, "long", getArgType(int(1)))
+}
+
+func TestGetArgType_Float32(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "float", getArgType(float32(1.0)))
+}
+
+func TestGetArgType_Float64(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "double", getArgType(float64(1.0)))
+}
+
+func TestGetArgType_String(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "java.lang.String", getArgType("hello"))
+}
+
+func TestGetArgType_StringSlice(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "[Ljava.lang.String;", getArgType([]string{"a", "b"}))
+}
+
+func TestGetArgType_Time(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "java.util.Date", getArgType(time.Now()))
+}
+
+func TestGetArgType_Map(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, "java.util.Map", getArgType(map[any]any{}))
+	assert.Equal(t, "java.util.Map", getArgType(map[string]int{}))
+}
+
+func TestGetArgType_Slice(t *testing.T) {
+	t.Parallel()
+	// []int maps to [J (Java long[]), consistent with int -> long above.
+	assert.Equal(t, "[J", getArgType([]int{1, 2, 3}))
+	assert.Equal(t, "[I", getArgType([]int32{1, 2, 3}))
+	assert.Equal(t, "[J", getArgType([]int64{1, 2, 3}))
+	assert.Equal(t, "[D", getArgType([]float64{1.0, 2.0}))
+}
+
+func TestGetArgType_POJO(t *testing.T) {
+	t.Parallel()
+	user := &TestUser{ID: "001", Name: "test", Age: 25}
+	assert.Equal(t, "org.apache.dubbo.samples.User", getArgType(user))
+}
+
+//
=============================================================================
+// Edge Cases and Error Handling
+//
=============================================================================
+
+// TestProtoWrapperCodec_UnmarshalRequest_ArgCountMismatch checks that a
+// request carrying two arguments cannot be decoded into a single target.
+func TestProtoWrapperCodec_UnmarshalRequest_ArgCountMismatch(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+
+	hessianCodec := &hessian2Codec{}
+	arg1, err := hessianCodec.Marshal("hello")
+	assert.NoError(t, err)
+	arg2, err := hessianCodec.Marshal(int32(42))
+	assert.NoError(t, err)
+
+	wrapper := &interoperability.TripleRequestWrapper{
+		SerializeType: codecNameHessian2,
+		Args:          [][]byte{arg1, arg2},
+		ArgTypes:      []string{"java.lang.String", "int"},
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	var str string
+	err = codec.Unmarshal(data, []any{&str})
+	assert.Error(t, err)
+}
+
+// TestProtoBinaryCodec_Unmarshal_InvalidData checks that bytes which parse as
+// neither wrapper are rejected when the target is not a proto message.
+func TestProtoBinaryCodec_Unmarshal_InvalidData(t *testing.T) {
+	t.Parallel()
+
+	codec := &protoBinaryCodec{}
+
+	invalidData := []byte{0x01, 0x02, 0x03}
+	var result any
+	err := codec.Unmarshal(invalidData, &result)
+	assert.Error(t, err)
+}
+
+// TestProtoWrapperCodec_Marshal_EmptyArgs checks that a no-argument call still
+// produces a valid TripleRequestWrapper with no Args.
+func TestProtoWrapperCodec_Marshal_EmptyArgs(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+
+	data, err := codec.Marshal([]any{})
+	assert.NoError(t, err)
+
+	var wrapper interoperability.TripleRequestWrapper
+	err = proto.Unmarshal(data, &wrapper)
+	assert.NoError(t, err)
+	assert.Empty(t, wrapper.Args)
+}
+
+// TestProtoWrapperCodec_Unmarshal_EmptyRequest checks that an argument-less
+// request decodes into an empty parameter list.
+func TestProtoWrapperCodec_Unmarshal_EmptyRequest(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&hessian2Codec{})
+
+	wrapper := &interoperability.TripleRequestWrapper{
+		SerializeType: codecNameHessian2,
+		Args:          [][]byte{},
+		ArgTypes:      []string{},
+	}
+	data, err := proto.Marshal(wrapper)
+	assert.NoError(t, err)
+
+	err = codec.Unmarshal(data, []any{})
+	assert.NoError(t, err)
+}
+
+//
=============================================================================
+// Msgpack Wrapper Tests
+//
=============================================================================
+
+// TestProtoWrapperCodec_Msgpack checks that a wrapped msgpack codec reports
+// msgpack as its name while using protobuf as its wire codec.
+func TestProtoWrapperCodec_Msgpack(t *testing.T) {
+	t.Parallel()
+
+	codec := newProtoWrapperCodec(&msgpackCodec{})
+	assert.Equal(t, codecNameMsgPack, codec.Name())
+	assert.Equal(t, codecNameProto, codec.WireCodecName())
+}
diff --git a/protocol/triple/triple_protocol/protocol_grpc.go
b/protocol/triple/triple_protocol/protocol_grpc.go
index c49e5885b..153761b8e 100644
--- a/protocol/triple/triple_protocol/protocol_grpc.go
+++ b/protocol/triple/triple_protocol/protocol_grpc.go
@@ -249,7 +249,7 @@ func (g *grpcClient) WriteRequestHeader(_ StreamType,
header http.Header) {
if getHeaderCanonical(header, headerUserAgent) == "" {
header[headerUserAgent] = []string{defaultGrpcUserAgent}
}
- header[headerContentType] =
[]string{grpcContentTypeFromCodecName(g.Codec.Name())}
+ header[headerContentType] =
[]string{grpcContentTypeFromCodecName(getWireCodecName(g.Codec))}
// gRPC handles compression on a per-message basis, so we don't want to
// compress the whole stream. By default, http.Client will ask the
server
// to gzip the stream if we don't set Accept-Encoding.
diff --git a/protocol/triple/triple_protocol/protocol_triple.go
b/protocol/triple/triple_protocol/protocol_triple.go
index c9532e553..a3742b4e2 100644
--- a/protocol/triple/triple_protocol/protocol_triple.go
+++ b/protocol/triple/triple_protocol/protocol_triple.go
@@ -238,7 +238,7 @@ func (c *tripleClient) WriteRequestHeader(streamType
StreamType, header http.Hea
}
header[tripleHeaderProtocolVersion] = []string{tripleProtocolVersion}
header[headerContentType] = []string{
- tripleContentTypeFromCodecName(streamType, c.Codec.Name()),
+ tripleContentTypeFromCodecName(streamType,
getWireCodecName(c.Codec)),
}
if acceptCompression := c.CompressionPools.CommaSeparatedNames();
acceptCompression != "" {
header[tripleUnaryHeaderAcceptCompression] =
[]string{acceptCompression}
diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go
index 80c73aa21..3fa918a79 100644
--- a/protocol/triple/triple_test.go
+++ b/protocol/triple/triple_test.go
@@ -67,3 +67,41 @@ func TestTripleProtocol_Destroy_EmptyServerMap(t *testing.T)
{
tp.Destroy()
})
}
+
+// Test_isGenericCall pins which values of the generic parameter isGenericCall
+// accepts as marking a generic call.
+func Test_isGenericCall(t *testing.T) {
+	cases := []struct {
+		name    string
+		generic string
+		want    bool
+	}{
+		// accepted generic serialization types, in lower, upper and title case
+		{"true", "true", true},
+		{"TRUE", "TRUE", true},
+		{"True", "True", true},
+		{"gson", "gson", true},
+		{"GSON", "GSON", true},
+		{"Gson", "Gson", true},
+		{"protobuf", "protobuf", true},
+		{"PROTOBUF", "PROTOBUF", true},
+		{"Protobuf", "Protobuf", true},
+		{"protobuf-json", "protobuf-json", true},
+		{"PROTOBUF-JSON", "PROTOBUF-JSON", true},
+		{"Protobuf-Json", "Protobuf-Json", true},
+
+		// rejected values, including the empty string
+		{"empty string", "", false},
+		{"false", "false", false},
+		{"random", "random", false},
+		{"json", "json", false},
+		{"xml", "xml", false},
+		{"hessian", "hessian", false},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := isGenericCall(tc.generic)
+			assert.Equal(t, tc.want, got)
+		})
+	}
+}
diff --git a/server/server.go b/server/server.go
index f2e19bbe2..271db85f4 100644
--- a/server/server.go
+++ b/server/server.go
@@ -19,8 +19,11 @@
package server
import (
+ "context"
+ "reflect"
"sort"
"strconv"
+ "strings"
"sync"
)
@@ -176,6 +179,61 @@ func (s *Server) genSvcOpts(handler any, info
*common.ServiceInfo, opts ...Servi
return newSvcOpts, nil
}
+// isNillable reports whether v has one of the kinds on which
+// reflect.Value.IsNil may be called. The zero Value (kind Invalid) is not
+// nillable.
+func isNillable(v reflect.Value) bool {
+	kind := v.Kind()
+	switch kind {
+	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
+		reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
+		return true
+	}
+	return false
+}
+
+// isReflectNil reports whether v holds nil. Kinds that can never be nil, and
+// the zero Value, yield false rather than reaching v.IsNil.
+func isReflectNil(v reflect.Value) bool {
+	if !isNillable(v) {
+		return false
+	}
+	return v.IsNil()
+}
+
+// CallMethodByReflection invokes method on handler via reflection, passing ctx
+// followed by args, and converts the method's results to (any, error).
+// It is shared by server/server.go and protocol/triple/server.go.
+//
+// method is expected to come from reflect.Type.Method, so its first parameter
+// is the receiver (handler) and its second is ctx.
+//
+// Results are interpreted as follows:
+//   - no results: (nil, nil);
+//   - one result: it is treated as the error; a non-nil error is returned as
+//     (nil, err), anything else yields (nil, nil);
+//   - two or more results: the first is the reply (nil when it holds a nil
+//     chan/func/interface/map/pointer/slice) and the second is returned as
+//     the error when it is a non-nil error; further results are ignored.
+//
+// A nil element of args is passed as the nil value of the matching parameter
+// type when that type can hold nil. reflect.ValueOf(nil) yields the zero
+// Value, which reflect.Value.Call rejects with a panic.
+func CallMethodByReflection(ctx context.Context, method reflect.Method, handler any, args []any) (any, error) {
+	methodType := method.Type
+	in := make([]reflect.Value, 0, len(args)+2)
+	in = append(in, reflect.ValueOf(handler), reflect.ValueOf(ctx))
+	for i, arg := range args {
+		v := reflect.ValueOf(arg)
+		if !v.IsValid() && methodType != nil {
+			// Look up the declared parameter type; i+2 skips the receiver and
+			// ctx. Trailing arguments of a variadic method correspond to the
+			// element type of its final slice parameter.
+			idx := i + 2
+			var paramType reflect.Type
+			switch {
+			case methodType.IsVariadic() && idx >= methodType.NumIn()-1:
+				paramType = methodType.In(methodType.NumIn() - 1).Elem()
+			case idx < methodType.NumIn():
+				paramType = methodType.In(idx)
+			}
+			if paramType != nil {
+				if zero := reflect.Zero(paramType); isNillable(zero) {
+					v = zero
+				}
+			}
+		}
+		in = append(in, v)
+	}
+	returnValues := method.Func.Call(in)
+
+	switch len(returnValues) {
+	case 0:
+		// Nothing to report; indexing returnValues would panic.
+		return nil, nil
+	case 1:
+		// A lone result is only meaningful as an error.
+		if isReflectNil(returnValues[0]) {
+			return nil, nil
+		}
+		if err, ok := returnValues[0].Interface().(error); ok {
+			return nil, err
+		}
+		return nil, nil
+	}
+
+	var result any
+	if !isReflectNil(returnValues[0]) {
+		result = returnValues[0].Interface()
+	}
+	var err error
+	if !isReflectNil(returnValues[1]) {
+		if e, ok := returnValues[1].Interface().(error); ok {
+			err = e
+		}
+	}
+	return result, err
+}
+
+// createReflectionMethodFunc adapts method to the MethodFunc signature
+// (ctx, args, handler) by forwarding each call to CallMethodByReflection.
+func createReflectionMethodFunc(method reflect.Method) func(ctx context.Context, args []any, handler any) (any, error) {
+	invoke := func(ctx context.Context, args []any, handler any) (any, error) {
+		return CallMethodByReflection(ctx, method, handler, args)
+	}
+	return invoke
+}
+
// Add a method with a name of a different first-letter case
// to achieve interoperability with java
// TODO: The method name case sensitivity in Dubbo-java should be addressed.
@@ -184,13 +242,48 @@ func enhanceServiceInfo(info *common.ServiceInfo)
*common.ServiceInfo {
if info == nil {
return info
}
+
+ // Get service type for reflection-based method calls
+ var svcType reflect.Type
+ if info.ServiceType != nil {
+ svcType = reflect.TypeOf(info.ServiceType)
+ }
+
+ // Build method map for reflection lookup
+ methodMap := make(map[string]reflect.Method)
+ if svcType != nil {
+ for i := 0; i < svcType.NumMethod(); i++ {
+ m := svcType.Method(i)
+ methodMap[m.Name] = m
+ methodMap[strings.ToLower(m.Name)] = m
+ }
+ }
+
+ // Add MethodFunc to methods that don't have it
+ for i := range info.Methods {
+ if info.Methods[i].MethodFunc == nil && svcType != nil {
+ if reflectMethod, ok :=
methodMap[info.Methods[i].Name]; ok {
+ info.Methods[i].MethodFunc =
createReflectionMethodFunc(reflectMethod)
+ }
+ }
+ }
+
+ // Create additional methods with swapped-case names for Java
interoperability
var additionalMethods []common.MethodInfo
for _, method := range info.Methods {
newMethod := method
newMethod.Name = dubboutil.SwapCaseFirstRune(method.Name)
+ if method.MethodFunc != nil {
+ newMethod.MethodFunc = method.MethodFunc
+ } else if svcType != nil {
+ if reflectMethod, ok :=
methodMap[dubboutil.SwapCaseFirstRune(method.Name)]; ok {
+ newMethod.MethodFunc =
createReflectionMethodFunc(reflectMethod)
+ }
+ }
additionalMethods = append(additionalMethods, newMethod)
}
info.Methods = append(info.Methods, additionalMethods...)
+
return info
}
diff --git a/server/server_test.go b/server/server_test.go
index b5a3093c2..a12daa36c 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -18,9 +18,11 @@
package server
import (
+ "reflect"
"strconv"
"sync"
"testing"
+ "unsafe"
)
import (
@@ -480,3 +482,79 @@ func TestNewServerWithCustomGroup(t *testing.T) {
assert.NotNil(t, svcOpts)
assert.Equal(t, "test", svcOpts.Service.Group)
}
+
+// Test isNillable checks if a reflect.Value's kind supports nil checking
+func TestIsNillable(t *testing.T) {
+ tests := []struct {
+ name string
+ value any
+ expected bool
+ }{
+ // nillable types
+ {"chan", make(chan int), true},
+ {"func", func() {}, true},
+ {"interface", (*error)(nil), true},
+ {"map", map[string]int{}, true},
+ {"pointer", new(int), true},
+ {"slice", []int{}, true},
+ {"unsafe.Pointer", unsafe.Pointer(nil), true},
+ {"nil chan", (chan int)(nil), true},
+ {"nil map", (map[string]int)(nil), true},
+ {"nil slice", ([]int)(nil), true},
+
+ // non-nillable types
+ {"int", 42, false},
+ {"string", "hello", false},
+ {"bool", true, false},
+ {"float64", 3.14, false},
+ {"struct", struct{}{}, false},
+ {"array", [3]int{1, 2, 3}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ v := reflect.ValueOf(tt.value)
+ result := isNillable(v)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// Test isReflectNil safely checks if a reflect.Value is nil
+func TestIsReflectNil(t *testing.T) {
+ tests := []struct {
+ name string
+ value any
+ expected bool
+ }{
+ // nil nillable types
+ {"nil chan", (chan int)(nil), true},
+ {"nil map", (map[string]int)(nil), true},
+ {"nil slice", ([]int)(nil), true},
+ {"nil func", (func())(nil), true},
+ {"nil pointer", (*int)(nil), true},
+
+ // non-nil nillable types
+ {"non-nil chan", make(chan int), false},
+ {"non-nil map", map[string]int{"a": 1}, false},
+ {"non-nil slice", []int{1, 2, 3}, false},
+ {"non-nil func", func() {}, false},
+ {"non-nil pointer", new(int), false},
+
+ // non-nillable types (should return false, not panic)
+ {"int", 42, false},
+ {"string", "hello", false},
+ {"bool", true, false},
+ {"float64", 3.14, false},
+ {"struct", struct{ Name string }{"test"}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ v := reflect.ValueOf(tt.value)
+ // should not panic
+ result := isReflectNil(v)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}