This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new b1002a7  THRIFT-4914: Add THeader to context for server reads
b1002a7 is described below

commit b1002a71fb3838039d3442420c604999551311e9
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Mon Aug 5 13:03:02 2019 -0700

    THRIFT-4914: Add THeader to context for server reads
    
    Client: go
    
    This is the first part of THRIFT-4914, which handles the server-side
    reading of requests (client -> server direction).
    
    In TSimpleServer, when the protocol is THeaderProtocol, automatically
    add all present headers to the context object before passing it to
    the processor, so that processor code can access the headers from
    the context directly, using the new helper functions added in
    header_context.go.
    
    This closes #1840.
---
 lib/go/thrift/header_context.go        | 81 ++++++++++++++++++++++++++++
 lib/go/thrift/header_context_test.go   | 97 ++++++++++++++++++++++++++++++++++
 lib/go/thrift/header_protocol.go       |  5 ++
 lib/go/thrift/header_transport_test.go | 12 +++++
 lib/go/thrift/simple_server.go         | 19 ++++++-
 5 files changed, 212 insertions(+), 2 deletions(-)

diff --git a/lib/go/thrift/header_context.go b/lib/go/thrift/header_context.go
new file mode 100644
index 0000000..5d9104b
--- /dev/null
+++ b/lib/go/thrift/header_context.go
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "context"
+)
+
+// See https://godoc.org/context#WithValue for why we need the unexported typedefs.
+type (
+       headerKey     string
+       headerKeyList int
+)
+
+// Values for headerKeyList.
+const (
+       headerKeyListRead headerKeyList = iota
+)
+
+// SetHeader sets a header in the context.
+func SetHeader(ctx context.Context, key, value string) context.Context {
+       return context.WithValue(
+               ctx,
+               headerKey(key),
+               value,
+       )
+}
+
+// GetHeader returns the value of the given header from the context.
+func GetHeader(ctx context.Context, key string) (value string, ok bool) {
+       if v := ctx.Value(headerKey(key)); v != nil {
+               value, ok = v.(string)
+       }
+       return
+}
+
+// SetReadHeaderList sets the key list of read THeaders in the context.
+func SetReadHeaderList(ctx context.Context, keys []string) context.Context {
+       return context.WithValue(
+               ctx,
+               headerKeyListRead,
+               keys,
+       )
+}
+
+// GetReadHeaderList returns the key list of read THeaders from the context.
+func GetReadHeaderList(ctx context.Context) []string {
+       if v := ctx.Value(headerKeyListRead); v != nil {
+               if value, ok := v.([]string); ok {
+                       return value
+               }
+       }
+       return nil
+}
+
+// AddReadTHeaderToContext adds every header in headers, plus the list of their keys, to the context.
+func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) 
context.Context {
+       keys := make([]string, 0, len(headers))
+       for key, value := range headers {
+               ctx = SetHeader(ctx, key, value)
+               keys = append(keys, key)
+       }
+       return SetReadHeaderList(ctx, keys)
+}
diff --git a/lib/go/thrift/header_context_test.go 
b/lib/go/thrift/header_context_test.go
new file mode 100644
index 0000000..33ac4ec
--- /dev/null
+++ b/lib/go/thrift/header_context_test.go
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "context"
+       "reflect"
+       "testing"
+)
+
+func TestSetGetHeader(t *testing.T) {
+       const (
+               key   = "foo"
+               value = "bar"
+       )
+       ctx := context.Background()
+
+       ctx = SetHeader(ctx, key, value)
+
+       checkGet := func(t *testing.T, ctx context.Context) {
+               t.Helper()
+               got, ok := GetHeader(ctx, key)
+               if !ok {
+                       t.Fatalf("Cannot get header %q back after setting it.", 
key)
+               }
+               if got != value {
+                       t.Fatalf("Header value expected %q, got %q instead", 
value, got)
+               }
+       }
+
+       checkGet(t, ctx)
+
+       t.Run(
+               "NoConflicts",
+               func(t *testing.T) {
+                       type otherType string
+                       const otherValue = "bar2"
+
+                       ctx = context.WithValue(ctx, otherType(key), otherValue)
+                       checkGet(t, ctx)
+               },
+       )
+
+       t.Run(
+               "GetHeaderOnNonExistKey",
+               func(t *testing.T) {
+                       const otherKey = "foo2"
+
+                       if _, ok := GetHeader(ctx, otherKey); ok {
+                               t.Errorf("GetHeader returned ok on non-existing 
key %q", otherKey)
+                       }
+               },
+       )
+}
+
+func TestKeyList(t *testing.T) {
+       headers := THeaderMap{
+               "key1": "value1",
+               "key2": "value2",
+       }
+       ctx := context.Background()
+
+       ctx = AddReadTHeaderToContext(ctx, headers)
+
+       got := make(THeaderMap)
+       keys := GetReadHeaderList(ctx)
+       t.Logf("keys: %+v", keys)
+       for _, key := range keys {
+               value, ok := GetHeader(ctx, key)
+               if ok {
+                       got[key] = value
+               } else {
+                       t.Errorf("Cannot get key %q from context", key)
+               }
+       }
+
+       if !reflect.DeepEqual(headers, got) {
+               t.Errorf("Expected header map %+v, got %+v", headers, got)
+       }
+}
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 0cf48f7..46205b2 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -188,6 +188,11 @@ func (p *THeaderProtocol) WriteBinary(value []byte) error {
        return p.protocol.WriteBinary(value)
 }
 
+// ReadFrame calls underlying THeaderTransport's ReadFrame function.
+func (p *THeaderProtocol) ReadFrame() error {
+       return p.transport.ReadFrame()
+}
+
 func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID 
TMessageType, seqID int32, err error) {
        if err = p.transport.ReadFrame(); err != nil {
                return
diff --git a/lib/go/thrift/header_transport_test.go 
b/lib/go/thrift/header_transport_test.go
index 7462dd5..e304768 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -21,6 +21,7 @@ package thrift
 
 import (
        "context"
+       "io"
        "io/ioutil"
        "testing"
 )
@@ -73,10 +74,21 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
        }
 
        // Read
+
+       // Make sure multiple calls to ReadFrame are fine.
+       if err := reader.ReadFrame(); err != nil {
+               t.Errorf("reader.ReadFrame returned error: %v", err)
+       }
+       if err := reader.ReadFrame(); err != nil {
+               t.Errorf("reader.ReadFrame returned error: %v", err)
+       }
        read, err := ioutil.ReadAll(reader)
        if err != nil {
                t.Errorf("Read returned error: %v", err)
        }
+       if err := reader.ReadFrame(); err != nil && err != io.EOF {
+               t.Errorf("reader.ReadFrame returned error: %v", err)
+       }
        if string(read) != payload1+payload2 {
                t.Errorf(
                        "Read content expected %q, got %q",
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 7db36c2..9155cfb 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -194,7 +194,8 @@ func (p *TSimpleServer) processRequests(client TTransport) 
error {
        // for THeaderProtocol, we must use the same protocol instance for
        // input and output so that the response is in the same dialect that
        // the server detected the request was in.
-       if _, ok := inputProtocol.(*THeaderProtocol); ok {
+       headerProtocol, ok := inputProtocol.(*THeaderProtocol)
+       if ok {
                outputProtocol = inputProtocol
        } else {
                oTrans, err := p.outputTransportFactory.GetTransport(client)
@@ -222,7 +223,21 @@ func (p *TSimpleServer) processRequests(client TTransport) 
error {
                        return nil
                }
 
-               ok, err := processor.Process(defaultCtx, inputProtocol, 
outputProtocol)
+               ctx := defaultCtx
+               if headerProtocol != nil {
+                       // We need to call ReadFrame here, otherwise we won't
+                       // get any headers on the AddReadTHeaderToContext call.
+                       //
+                       // ReadFrame is safe to be called multiple times so it
+                       // won't break when it's called again later when we
+                       // actually start to read the message.
+                       if err := headerProtocol.ReadFrame(); err != nil {
+                               return err
+                       }
+                       ctx = AddReadTHeaderToContext(defaultCtx, 
headerProtocol.GetReadHeaders())
+               }
+
+               ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
                if err, ok := err.(TTransportException); ok && err.TypeId() == 
END_OF_FILE {
                        return nil
                } else if err != nil {

Reply via email to