This is an automated email from the ASF dual-hosted git repository.

baze pushed a commit to branch 1.4
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git

commit bcb01004337eb1a3baf4390f905f649e4477ad85
Author: Patrick <dreamlike....@foxmail.com>
AuthorDate: Wed Apr 1 19:05:58 2020 +0800

    optimize header transmit in RestClient and RestServer
---
 protocol/rest/client/client_impl/resty_client.go   | 11 +++---
 protocol/rest/client/rest_client.go                |  5 +--
 protocol/rest/rest_invoker.go                      | 23 ++++++++---
 protocol/rest/server/rest_server.go                |  4 ++
 .../rest/server/server_impl/go_restful_server.go   | 44 +++++++++++++++++++++-
 5 files changed, 72 insertions(+), 15 deletions(-)

diff --git a/protocol/rest/client/client_impl/resty_client.go b/protocol/rest/client/client_impl/resty_client.go
index af9637e..9e0c80c 100644
--- a/protocol/rest/client/client_impl/resty_client.go
+++ b/protocol/rest/client/client_impl/resty_client.go
@@ -66,20 +66,19 @@ func NewRestyClient(restOption *client.RestOptions) client.RestClient {
 }
 
 func (rc *RestyClient) Do(restRequest *client.RestClientRequest, res 
interface{}) error {
-       r, err := rc.client.R().
-               SetHeader("Content-Type", restRequest.Consumes).
-               SetHeader("Accept", restRequest.Produces).
+       req := rc.client.R()
+       req.Header = restRequest.Header
+       resp, err := req.
                SetPathParams(restRequest.PathParams).
                SetQueryParams(restRequest.QueryParams).
-               SetHeaders(restRequest.Headers).
                SetBody(restRequest.Body).
                SetResult(res).
                Execute(restRequest.Method, 
"http://"+path.Join(restRequest.Location, restRequest.Path))
        if err != nil {
                return perrors.WithStack(err)
        }
-       if r.IsError() {
-               return perrors.New(r.String())
+       if resp.IsError() {
+               return perrors.New(resp.String())
        }
        return nil
 }
diff --git a/protocol/rest/client/rest_client.go b/protocol/rest/client/rest_client.go
index 3acccb5..5be4bb3 100644
--- a/protocol/rest/client/rest_client.go
+++ b/protocol/rest/client/rest_client.go
@@ -18,6 +18,7 @@
 package client
 
 import (
+       "net/http"
        "time"
 )
 
@@ -27,15 +28,13 @@ type RestOptions struct {
 }
 
 type RestClientRequest struct {
+       Header      http.Header
        Location    string
        Path        string
-       Produces    string
-       Consumes    string
        Method      string
        PathParams  map[string]string
        QueryParams map[string]string
        Body        interface{}
-       Headers     map[string]string
 }
 
 type RestClient interface {
diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go
index c8e3fea..121d121 100644
--- a/protocol/rest/rest_invoker.go
+++ b/protocol/rest/rest_invoker.go
@@ -20,6 +20,7 @@ package rest
 import (
        "context"
        "fmt"
+       "net/http"
 )
 
 import (
@@ -56,7 +57,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio
                body        interface{}
                pathParams  map[string]string
                queryParams map[string]string
-               headers     map[string]string
+               header      http.Header
                err         error
        )
        if methodConfig == nil {
@@ -71,7 +72,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio
                result.Err = err
                return &result
        }
-       if headers, err = restStringMapTransform(methodConfig.HeadersMap, 
inv.Arguments()); err != nil {
+       if header, err = getRestHttpHeader(methodConfig, inv.Arguments()); err 
!= nil {
                result.Err = err
                return &result
        }
@@ -80,14 +81,12 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio
        }
        req := &client.RestClientRequest{
                Location:    ri.GetUrl().Location,
-               Produces:    methodConfig.Produces,
-               Consumes:    methodConfig.Consumes,
                Method:      methodConfig.MethodType,
                Path:        methodConfig.Path,
                PathParams:  pathParams,
                QueryParams: queryParams,
                Body:        body,
-               Headers:     headers,
+               Header:      header,
        }
        result.Err = ri.client.Do(req, inv.Reply())
        if result.Err == nil {
@@ -106,3 +105,17 @@ func restStringMapTransform(paramsMap map[int]string, args []interface{}) (map[s
        }
        return resMap, nil
 }
+
+func getRestHttpHeader(methodConfig *config.RestMethodConfig, args 
[]interface{}) (http.Header, error) {
+       header := http.Header{}
+       headersMap := methodConfig.HeadersMap
+       header.Set("Content-Type", methodConfig.Consumes)
+       header.Set("Accept", methodConfig.Produces)
+       for k, v := range headersMap {
+               if k >= len(args) || k < 0 {
+                       return nil, perrors.Errorf("[Rest Invoke] Index %v is 
out of bundle", k)
+               }
+               header.Set(v, fmt.Sprint(args[k]))
+       }
+       return header, nil
+}
diff --git a/protocol/rest/server/rest_server.go b/protocol/rest/server/rest_server.go
index b7eb555..7fb0560 100644
--- a/protocol/rest/server/rest_server.go
+++ b/protocol/rest/server/rest_server.go
@@ -46,6 +46,7 @@ type RestServer interface {
 
 // RestServerRequest interface
 type RestServerRequest interface {
+       RawRequest() *http.Request
        PathParameter(name string) string
        PathParameters() map[string]string
        QueryParameter(name string) string
@@ -57,6 +58,9 @@ type RestServerRequest interface {
 
 // RestServerResponse interface
 type RestServerResponse interface {
+       Header() http.Header
+       Write([]byte) (int, error)
+       WriteHeader(statusCode int)
        WriteError(httpStatus int, err error) (writeErr error)
        WriteEntity(value interface{}) error
 }
diff --git a/protocol/rest/server/server_impl/go_restful_server.go b/protocol/rest/server/server_impl/go_restful_server.go
index 81043c8..9163d3a 100644
--- a/protocol/rest/server/server_impl/go_restful_server.go
+++ b/protocol/rest/server/server_impl/go_restful_server.go
@@ -79,7 +79,7 @@ func (grs *GoRestfulServer) Start(url common.URL) {
 func (grs *GoRestfulServer) Deploy(restMethodConfig *config.RestMethodConfig, 
routeFunc func(request server.RestServerRequest, response 
server.RestServerResponse)) {
        ws := new(restful.WebService)
        rf := func(req *restful.Request, resp *restful.Response) {
-               routeFunc(req, resp)
+               routeFunc(NewGoRestfulRequestAdapter(req), resp)
        }
        ws.Path(restMethodConfig.Path).
                Produces(strings.Split(restMethodConfig.Produces, ",")...).
@@ -116,3 +116,45 @@ func GetNewGoRestfulServer() server.RestServer {
 func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) {
        filterSlice = append(filterSlice, filterFuc)
 }
+
+// Adapter about RestServerRequest
+type GoRestfulRequestAdapter struct {
+       server.RestServerRequest
+       request *restful.Request
+}
+
+func NewGoRestfulRequestAdapter(request *restful.Request) 
*GoRestfulRequestAdapter {
+       return &GoRestfulRequestAdapter{request: request}
+}
+
+func (grra *GoRestfulRequestAdapter) RawRequest() *http.Request {
+       return grra.request.Request
+}
+
+func (grra *GoRestfulRequestAdapter) PathParameter(name string) string {
+       return grra.request.PathParameter(name)
+}
+
+func (grra *GoRestfulRequestAdapter) PathParameters() map[string]string {
+       return grra.request.PathParameters()
+}
+
+func (grra *GoRestfulRequestAdapter) QueryParameter(name string) string {
+       return grra.request.QueryParameter(name)
+}
+
+func (grra *GoRestfulRequestAdapter) QueryParameters(name string) []string {
+       return grra.request.QueryParameters(name)
+}
+
+func (grra *GoRestfulRequestAdapter) BodyParameter(name string) (string, 
error) {
+       return grra.request.BodyParameter(name)
+}
+
+func (grra *GoRestfulRequestAdapter) HeaderParameter(name string) string {
+       return grra.request.HeaderParameter(name)
+}
+
+func (grra *GoRestfulRequestAdapter) ReadEntity(entityPointer interface{}) 
error {
+       return grra.request.ReadEntity(entityPointer)
+}

Reply via email to