[ 
https://issues.apache.org/jira/browse/THRIFT-4621?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16599450#comment-16599450
 ] 

ASF GitHub Bot commented on THRIFT-4621:
----------------------------------------

nsuke closed pull request #1583: THRIFT-4621 Add THeader for Python
URL: https://github.com/apache/thrift/pull/1583
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it once the PR is merged:

diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py
index 41bcf353d8..0e8271dc19 100644
--- a/lib/py/src/compat.py
+++ b/lib/py/src/compat.py
@@ -29,6 +29,9 @@ def binary_to_str(bin_val):
     def str_to_binary(str_val):
         return str_val
 
+    def byte_index(bytes_val, i):
+        return ord(bytes_val[i])
+
 else:
 
     from io import BytesIO as BufferIO  # noqa
@@ -38,3 +41,6 @@ def binary_to_str(bin_val):
 
     def str_to_binary(str_val):
         return bytes(str_val, 'utf8')
+
+    def byte_index(bytes_val, i):
+        return bytes_val[i]
diff --git a/lib/py/src/protocol/THeaderProtocol.py b/lib/py/src/protocol/THeaderProtocol.py
new file mode 100644
index 0000000000..b27a749953
--- /dev/null
+++ b/lib/py/src/protocol/THeaderProtocol.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
+from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
+from thrift.protocol.TProtocol import TProtocolBase, TProtocolException
+from thrift.Thrift import TApplicationException, TMessageType
+from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
+
+
+PROTOCOLS_BY_ID = {
+    THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
+    THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
+}
+
+
+class THeaderProtocol(TProtocolBase):
+    """A framed protocol with headers and payload transforms.
+
+    THeaderProtocol frames other Thrift protocols and adds support for optional
+    out-of-band headers. The currently supported subprotocols are
+    TBinaryProtocol and TCompactProtocol.
+
+    It's also possible to apply transforms to the encoded message payload. The
+    only transform currently supported is to gzip.
+
+    When used in a server, THeaderProtocol can accept messages from
+    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
+    includes framed and unframed transports and both TBinaryProtocol and
+    TCompactProtocol. The server will respond in the appropriate dialect for
+    the connected client. HTTP clients are not currently supported.
+
+    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
+    or TProcessPoolServer.
+
+    See doc/specs/HeaderFormat.md for details of the wire format.
+
+    """
+
+    def __init__(self, transport, allowed_client_types):
+        # much of the actual work for THeaderProtocol happens down in
+        # THeaderTransport since we need to do low-level shenanigans to detect
+        # if the client is sending us headers or one of the headerless formats
+        # we support. this wraps the real transport with the one that does all
+        # the magic.
+        if not isinstance(transport, THeaderTransport):
+            transport = THeaderTransport(transport, allowed_client_types)
+        super(THeaderProtocol, self).__init__(transport)
+        self._set_protocol()
+
+    def get_headers(self):
+        return self.trans.get_headers()
+
+    def set_header(self, key, value):
+        self.trans.set_header(key, value)
+
+    def clear_headers(self):
+        self.trans.clear_headers()
+
+    def add_transform(self, transform_id):
+        self.trans.add_transform(transform_id)
+
+    def writeMessageBegin(self, name, ttype, seqid):
+        self.trans.sequence_id = seqid
+        return self._protocol.writeMessageBegin(name, ttype, seqid)
+
+    def writeMessageEnd(self):
+        return self._protocol.writeMessageEnd()
+
+    def writeStructBegin(self, name):
+        return self._protocol.writeStructBegin(name)
+
+    def writeStructEnd(self):
+        return self._protocol.writeStructEnd()
+
+    def writeFieldBegin(self, name, ttype, fid):
+        return self._protocol.writeFieldBegin(name, ttype, fid)
+
+    def writeFieldEnd(self):
+        return self._protocol.writeFieldEnd()
+
+    def writeFieldStop(self):
+        return self._protocol.writeFieldStop()
+
+    def writeMapBegin(self, ktype, vtype, size):
+        return self._protocol.writeMapBegin(ktype, vtype, size)
+
+    def writeMapEnd(self):
+        return self._protocol.writeMapEnd()
+
+    def writeListBegin(self, etype, size):
+        return self._protocol.writeListBegin(etype, size)
+
+    def writeListEnd(self):
+        return self._protocol.writeListEnd()
+
+    def writeSetBegin(self, etype, size):
+        return self._protocol.writeSetBegin(etype, size)
+
+    def writeSetEnd(self):
+        return self._protocol.writeSetEnd()
+
+    def writeBool(self, bool_val):
+        return self._protocol.writeBool(bool_val)
+
+    def writeByte(self, byte):
+        return self._protocol.writeByte(byte)
+
+    def writeI16(self, i16):
+        return self._protocol.writeI16(i16)
+
+    def writeI32(self, i32):
+        return self._protocol.writeI32(i32)
+
+    def writeI64(self, i64):
+        return self._protocol.writeI64(i64)
+
+    def writeDouble(self, dub):
+        return self._protocol.writeDouble(dub)
+
+    def writeBinary(self, str_val):
+        return self._protocol.writeBinary(str_val)
+
+    def _set_protocol(self):
+        try:
+            protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
+        except KeyError:
+            raise TApplicationException(
+                TProtocolException.INVALID_PROTOCOL,
+                "Unknown protocol requested.",
+            )
+
+        self._protocol = protocol_cls(self.trans)
+        self._fast_encode = self._protocol._fast_encode
+        self._fast_decode = self._protocol._fast_decode
+
+    def readMessageBegin(self):
+        try:
+            self.trans.readFrame(0)
+            self._set_protocol()
+        except TApplicationException as exc:
+            self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
+            exc.write(self._protocol)
+            self._protocol.writeMessageEnd()
+            self.trans.flush()
+
+        return self._protocol.readMessageBegin()
+
+    def readMessageEnd(self):
+        return self._protocol.readMessageEnd()
+
+    def readStructBegin(self):
+        return self._protocol.readStructBegin()
+
+    def readStructEnd(self):
+        return self._protocol.readStructEnd()
+
+    def readFieldBegin(self):
+        return self._protocol.readFieldBegin()
+
+    def readFieldEnd(self):
+        return self._protocol.readFieldEnd()
+
+    def readMapBegin(self):
+        return self._protocol.readMapBegin()
+
+    def readMapEnd(self):
+        return self._protocol.readMapEnd()
+
+    def readListBegin(self):
+        return self._protocol.readListBegin()
+
+    def readListEnd(self):
+        return self._protocol.readListEnd()
+
+    def readSetBegin(self):
+        return self._protocol.readSetBegin()
+
+    def readSetEnd(self):
+        return self._protocol.readSetEnd()
+
+    def readBool(self):
+        return self._protocol.readBool()
+
+    def readByte(self):
+        return self._protocol.readByte()
+
+    def readI16(self):
+        return self._protocol.readI16()
+
+    def readI32(self):
+        return self._protocol.readI32()
+
+    def readI64(self):
+        return self._protocol.readI64()
+
+    def readDouble(self):
+        return self._protocol.readDouble()
+
+    def readBinary(self):
+        return self._protocol.readBinary()
+
+
+class THeaderProtocolFactory(object):
+    def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
+        self.allowed_client_types = allowed_client_types
+
+    def getProtocol(self, trans):
+        return THeaderProtocol(trans, self.allowed_client_types)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index fd20cb7906..8314cf69df 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -37,6 +37,7 @@ class TProtocolException(TException):
     BAD_VERSION = 4
     NOT_IMPLEMENTED = 5
     DEPTH_LIMIT = 6
+    INVALID_PROTOCOL = 7
 
     def __init__(self, type=UNKNOWN, message=None):
         TException.__init__(self, message)
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index d5d9c98a93..df2a7bb93d 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -23,6 +23,7 @@
 import threading
 
 from thrift.protocol import TBinaryProtocol
+from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
 from thrift.transport import TTransport
 
 logger = logging.getLogger(__name__)
@@ -60,6 +61,12 @@ def __initArgs__(self, processor, serverTransport,
         self.inputProtocolFactory = inputProtocolFactory
         self.outputProtocolFactory = outputProtocolFactory
 
+        input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
+        output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
+        if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
+            raise ValueError("THeaderProtocol servers require that both the input and "
+                             "output protocols are THeaderProtocol.")
+
     def serve(self):
         pass
 
@@ -76,10 +83,20 @@ def serve(self):
             client = self.serverTransport.accept()
             if not client:
                 continue
+
             itrans = self.inputTransportFactory.getTransport(client)
-            otrans = self.outputTransportFactory.getTransport(client)
             iprot = self.inputProtocolFactory.getProtocol(itrans)
-            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+            # for THeaderProtocol, we must use the same protocol instance for
+            # input and output so that the response is in the same dialect that
+            # the server detected the request was in.
+            if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+                otrans = None
+                oprot = iprot
+            else:
+                otrans = self.outputTransportFactory.getTransport(client)
+                oprot = self.outputProtocolFactory.getProtocol(otrans)
+
             try:
                 while True:
                     self.processor.process(iprot, oprot)
@@ -89,7 +106,8 @@ def serve(self):
                 logger.exception(x)
 
             itrans.close()
-            otrans.close()
+            if otrans:
+                otrans.close()
 
 
 class TThreadedServer(TServer):
@@ -116,9 +134,18 @@ def serve(self):
 
     def handle(self, client):
         itrans = self.inputTransportFactory.getTransport(client)
-        otrans = self.outputTransportFactory.getTransport(client)
         iprot = self.inputProtocolFactory.getProtocol(itrans)
-        oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+        # for THeaderProtocol, we must use the same protocol instance for input
+        # and output so that the response is in the same dialect that the
+        # server detected the request was in.
+        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+            otrans = None
+            oprot = iprot
+        else:
+            otrans = self.outputTransportFactory.getTransport(client)
+            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
         try:
             while True:
                 self.processor.process(iprot, oprot)
@@ -128,7 +155,8 @@ def handle(self, client):
             logger.exception(x)
 
         itrans.close()
-        otrans.close()
+        if otrans:
+            otrans.close()
 
 
 class TThreadPoolServer(TServer):
@@ -156,9 +184,18 @@ def serveThread(self):
     def serveClient(self, client):
         """Process input/output from a client for as long as possible"""
         itrans = self.inputTransportFactory.getTransport(client)
-        otrans = self.outputTransportFactory.getTransport(client)
         iprot = self.inputProtocolFactory.getProtocol(itrans)
-        oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+        # for THeaderProtocol, we must use the same protocol instance for input
+        # and output so that the response is in the same dialect that the
+        # server detected the request was in.
+        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+            otrans = None
+            oprot = iprot
+        else:
+            otrans = self.outputTransportFactory.getTransport(client)
+            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
         try:
             while True:
                 self.processor.process(iprot, oprot)
@@ -168,7 +205,8 @@ def serveClient(self, client):
             logger.exception(x)
 
         itrans.close()
-        otrans.close()
+        if otrans:
+            otrans.close()
 
     def serve(self):
         """Start a fixed number of worker threads and put client into a queue"""
@@ -237,10 +275,18 @@ def try_close(file):
                     try_close(otrans)
                 else:
                     itrans = self.inputTransportFactory.getTransport(client)
-                    otrans = self.outputTransportFactory.getTransport(client)
-
                     iprot = self.inputProtocolFactory.getProtocol(itrans)
-                    oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+                    # for THeaderProtocol, we must use the same protocol
+                    # instance for input and output so that the response is in
+                    # the same dialect that the server detected the request was
+                    # in.
+                    if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+                        otrans = None
+                        oprot = iprot
+                    else:
+                        otrans = self.outputTransportFactory.getTransport(client)
+                        oprot = self.outputProtocolFactory.getProtocol(otrans)
 
                     ecode = 0
                     try:
@@ -254,7 +300,8 @@ def try_close(file):
                             ecode = 1
                     finally:
                         try_close(itrans)
-                        try_close(otrans)
+                        if otrans:
+                            try_close(otrans)
 
                     os._exit(ecode)
 
diff --git a/lib/py/src/transport/THeaderTransport.py b/lib/py/src/transport/THeaderTransport.py
new file mode 100644
index 0000000000..c0d5640122
--- /dev/null
+++ b/lib/py/src/transport/THeaderTransport.py
@@ -0,0 +1,352 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import struct
+import zlib
+
+from thrift.compat import BufferIO, byte_index
+from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
+from thrift.Thrift import TApplicationException
+from thrift.transport.TTransport import (
+    CReadableTransport,
+    TMemoryBuffer,
+    TTransportBase,
+    TTransportException,
+)
+
+
+U16 = struct.Struct("!H")
+I32 = struct.Struct("!i")
+HEADER_MAGIC = 0x0FFF
+HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
+
+
+class THeaderClientType(object):
+    HEADERS = 0x00
+
+    FRAMED_BINARY = 0x01
+    UNFRAMED_BINARY = 0x02
+
+    FRAMED_COMPACT = 0x03
+    UNFRAMED_COMPACT = 0x04
+
+
+class THeaderSubprotocolID(object):
+    BINARY = 0x00
+    COMPACT = 0x02
+
+
+class TInfoHeaderType(object):
+    KEY_VALUE = 0x01
+
+
+class THeaderTransformID(object):
+    ZLIB = 0x01
+
+
+READ_TRANSFORMS_BY_ID = {
+    THeaderTransformID.ZLIB: zlib.decompress,
+}
+
+
+WRITE_TRANSFORMS_BY_ID = {
+    THeaderTransformID.ZLIB: zlib.compress,
+}
+
+
+def _readString(trans):
+    size = readVarint(trans)
+    if size < 0:
+        raise TTransportException(
+            TTransportException.NEGATIVE_SIZE,
+            "Negative length"
+        )
+    return trans.read(size)
+
+
+def _writeString(trans, value):
+    writeVarint(trans, len(value))
+    trans.write(value)
+
+
+class THeaderTransport(TTransportBase, CReadableTransport):
+    def __init__(self, transport, allowed_client_types):
+        self._transport = transport
+        self._client_type = THeaderClientType.HEADERS
+        self._allowed_client_types = allowed_client_types
+
+        self._read_buffer = BufferIO(b"")
+        self._read_headers = {}
+
+        self._write_buffer = BufferIO()
+        self._write_headers = {}
+        self._write_transforms = []
+
+        self.flags = 0
+        self.sequence_id = 0
+        self._protocol_id = THeaderSubprotocolID.BINARY
+        self._max_frame_size = HARD_MAX_FRAME_SIZE
+
+    def isOpen(self):
+        return self._transport.isOpen()
+
+    def open(self):
+        return self._transport.open()
+
+    def close(self):
+        return self._transport.close()
+
+    def get_headers(self):
+        return self._read_headers
+
+    def set_header(self, key, value):
+        if not isinstance(key, bytes):
+            raise ValueError("header names must be bytes")
+        if not isinstance(value, bytes):
+            raise ValueError("header values must be bytes")
+        self._write_headers[key] = value
+
+    def clear_headers(self):
+        self._write_headers.clear()
+
+    def add_transform(self, transform_id):
+        if transform_id not in WRITE_TRANSFORMS_BY_ID:
+            raise ValueError("unknown transform")
+        self._write_transforms.append(transform_id)
+
+    def set_max_frame_size(self, size):
+        if not 0 < size < HARD_MAX_FRAME_SIZE:
+            raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
+        self._max_frame_size = size
+
+    @property
+    def protocol_id(self):
+        if self._client_type == THeaderClientType.HEADERS:
+            return self._protocol_id
+        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
+            return THeaderSubprotocolID.BINARY
+        elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
+            return THeaderSubprotocolID.COMPACT
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Protocol ID not know for client type %d" % self._client_type,
+            )
+
+    def read(self, sz):
+        # if there are bytes left in the buffer, produce those first.
+        bytes_read = self._read_buffer.read(sz)
+        bytes_left_to_read = sz - len(bytes_read)
+        if bytes_left_to_read == 0:
+            return bytes_read
+
+        # if we've determined this is an unframed client, just pass the read
+        # through to the underlying transport until we're reset again at the
+        # beginning of the next message.
+        if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+            return bytes_read + self._transport.read(bytes_left_to_read)
+
+        # we're empty and (maybe) framed. fill the buffers with the next frame.
+        self.readFrame(bytes_left_to_read)
+        return bytes_read + self._read_buffer.read(bytes_left_to_read)
+
+    def _set_client_type(self, client_type):
+        if client_type not in self._allowed_client_types:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Client type %d not allowed by server." % client_type,
+            )
+        self._client_type = client_type
+
+    def readFrame(self, req_sz):
+        # the first word could either be the length field of a framed message
+        # or the first bytes of an unframed message.
+        first_word = self._transport.readAll(I32.size)
+        frame_size, = I32.unpack(first_word)
+        is_unframed = False
+        if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+            self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
+            is_unframed = True
+        elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
+              byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+            self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
+            is_unframed = True
+
+        if is_unframed:
+            bytes_left_to_read = req_sz - I32.size
+            if bytes_left_to_read > 0:
+                rest = self._transport.read(bytes_left_to_read)
+            else:
+                rest = b""
+            self._read_buffer = BufferIO(first_word + rest)
+            return
+
+        # ok, we're still here so we're framed.
+        if frame_size > self._max_frame_size:
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Frame was too large.",
+            )
+        read_buffer = BufferIO(self._transport.readAll(frame_size))
+
+        # the next word is either going to be the version field of a
+        # binary/compact protocol message or the magic value + flags of a
+        # header protocol message.
+        second_word = read_buffer.read(I32.size)
+        version, = I32.unpack(second_word)
+        read_buffer.seek(0)
+        if version >> 16 == HEADER_MAGIC:
+            self._set_client_type(THeaderClientType.HEADERS)
+            self._read_buffer = self._parse_header_format(read_buffer)
+        elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+            self._set_client_type(THeaderClientType.FRAMED_BINARY)
+            self._read_buffer = read_buffer
+        elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
+              byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+            self._set_client_type(THeaderClientType.FRAMED_COMPACT)
+            self._read_buffer = read_buffer
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Could not detect client transport type.",
+            )
+
+    def _parse_header_format(self, buffer):
+        # make BufferIO look like TTransport for varint helpers
+        buffer_transport = TMemoryBuffer()
+        buffer_transport._buffer = buffer
+
+        buffer.read(2)  # discard the magic bytes
+        self.flags, = U16.unpack(buffer.read(U16.size))
+        self.sequence_id, = I32.unpack(buffer.read(I32.size))
+
+        header_length = U16.unpack(buffer.read(U16.size))[0] * 4
+        end_of_headers = buffer.tell() + header_length
+        if end_of_headers > len(buffer.getvalue()):
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Header size is larger than whole frame.",
+            )
+
+        self._protocol_id = readVarint(buffer_transport)
+
+        transforms = []
+        transform_count = readVarint(buffer_transport)
+        for _ in range(transform_count):
+            transform_id = readVarint(buffer_transport)
+            if transform_id not in READ_TRANSFORMS_BY_ID:
+                raise TApplicationException(
+                    TApplicationException.INVALID_TRANSFORM,
+                    "Unknown transform: %d" % transform_id,
+                )
+            transforms.append(transform_id)
+        transforms.reverse()
+
+        headers = {}
+        while buffer.tell() < end_of_headers:
+            header_type = readVarint(buffer_transport)
+            if header_type == TInfoHeaderType.KEY_VALUE:
+                count = readVarint(buffer_transport)
+                for _ in range(count):
+                    key = _readString(buffer_transport)
+                    value = _readString(buffer_transport)
+                    headers[key] = value
+            else:
+                break  # ignore unknown headers
+        self._read_headers = headers
+
+        # skip padding / anything we didn't understand
+        buffer.seek(end_of_headers)
+
+        payload = buffer.read()
+        for transform_id in transforms:
+            transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
+            payload = transform_fn(payload)
+        return BufferIO(payload)
+
+    def write(self, buf):
+        self._write_buffer.write(buf)
+
+    def flush(self):
+        payload = self._write_buffer.getvalue()
+        self._write_buffer = BufferIO()
+
+        buffer = BufferIO()
+        if self._client_type == THeaderClientType.HEADERS:
+            for transform_id in self._write_transforms:
+                transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
+                payload = transform_fn(payload)
+
+            headers = BufferIO()
+            writeVarint(headers, self._protocol_id)
+            writeVarint(headers, len(self._write_transforms))
+            for transform_id in self._write_transforms:
+                writeVarint(headers, transform_id)
+            if self._write_headers:
+                writeVarint(headers, TInfoHeaderType.KEY_VALUE)
+                writeVarint(headers, len(self._write_headers))
+                for key, value in self._write_headers.items():
+                    _writeString(headers, key)
+                    _writeString(headers, value)
+                self._write_headers = {}
+            padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
+            headers.write(b"\x00" * padding_needed)
+            header_bytes = headers.getvalue()
+
+            buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
+            buffer.write(U16.pack(HEADER_MAGIC))
+            buffer.write(U16.pack(self.flags))
+            buffer.write(I32.pack(self.sequence_id))
+            buffer.write(U16.pack(len(header_bytes) // 4))
+            buffer.write(header_bytes)
+            buffer.write(payload)
+        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
+            buffer.write(I32.pack(len(payload)))
+            buffer.write(payload)
+        elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+            buffer.write(payload)
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Unknown client type.",
+            )
+
+        # the frame length field doesn't count towards the frame payload size
+        frame_bytes = buffer.getvalue()
+        frame_payload_size = len(frame_bytes) - 4
+        if frame_payload_size > self._max_frame_size:
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Attempting to send frame that is too large.",
+            )
+
+        self._transport.write(frame_bytes)
+        self._transport.flush()
+
+    @property
+    def cstringio_buf(self):
+        return self._read_buffer
+
+    def cstringio_refill(self, partialread, reqlen):
+        result = bytearray(partialread)
+        while len(result) < reqlen:
+            result += self.read(reqlen - len(result))
+        self._read_buffer = BufferIO(result)
+        return self._read_buffer
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index c8855ca7a9..d13060f810 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -32,6 +32,7 @@ class TTransportException(TException):
     END_OF_FILE = 4
     NEGATIVE_SIZE = 5
     SIZE_LIMIT = 6
+    INVALID_CLIENT_TYPE = 7
 
     def __init__(self, type=UNKNOWN, message=None):
         TException.__init__(self, message)
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index e5230854aa..9d6d54bcf3 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -83,6 +83,8 @@
   "cpp-py3_compact-accelc_http-ip-ssl",
   "cpp-py3_compact_http-ip",
   "cpp-py3_compact_http-ip-ssl",
+  "cpp-py3_header_http-ip",
+  "cpp-py3_header_http-ip-ssl",
   "cpp-py3_json_http-ip",
   "cpp-py3_json_http-ip-ssl",
   "cpp-py3_multi-accel_http-ip",
@@ -101,6 +103,8 @@
   "cpp-py3_multic-multiac_http-ip-ssl",
   "cpp-py3_multic_http-ip",
   "cpp-py3_multic_http-ip-ssl",
+  "cpp-py3_multih-header_http-ip",
+  "cpp-py3_multih-header_http-ip-ssl",
   "cpp-py3_multij-json_http-ip",
   "cpp-py3_multij-json_http-ip-ssl",
   "cpp-py3_multij_http-ip",
@@ -113,6 +117,8 @@
   "cpp-py_compact-accelc_http-ip-ssl",
   "cpp-py_compact_http-ip",
   "cpp-py_compact_http-ip-ssl",
+  "cpp-py_header_http-ip",
+  "cpp-py_header_http-ip-ssl",
   "cpp-py_json_http-ip",
   "cpp-py_json_http-ip-ssl",
   "cpp-py_multi-accel_http-ip",
@@ -131,6 +137,8 @@
   "cpp-py_multic-multiac_http-ip-ssl",
   "cpp-py_multic_http-ip",
   "cpp-py_multic_http-ip-ssl",
+  "cpp-py_multih-header_http-ip",
+  "cpp-py_multih-header_http-ip-ssl",
   "cpp-py_multij-json_http-ip",
   "cpp-py_multij-json_http-ip-ssl",
   "cpp-py_multij_http-ip",
@@ -375,6 +383,8 @@
   "py-cpp_binary_http-ip-ssl",
   "py-cpp_compact_http-ip",
   "py-cpp_compact_http-ip-ssl",
+  "py-cpp_header_http-ip",
+  "py-cpp_header_http-ip-ssl",
   "py-cpp_json_http-ip",
   "py-cpp_json_http-ip-ssl",
   "py-d_accel-binary_http-ip",
@@ -396,6 +406,7 @@
   "py-hs_accelc-compact_http-ip",
   "py-hs_binary_http-ip",
   "py-hs_compact_http-ip",
+  "py-hs_header_http-ip",
   "py-hs_json_http-ip",
   "py-java_accel-binary_http-ip",
   "py-java_accel-binary_http-ip-ssl",
@@ -420,6 +431,8 @@
   "py3-cpp_binary_http-ip-ssl",
   "py3-cpp_compact_http-ip",
   "py3-cpp_compact_http-ip-ssl",
+  "py3-cpp_header_http-ip",
+  "py3-cpp_header_http-ip-ssl",
   "py3-cpp_json_http-ip",
   "py3-cpp_json_http-ip-ssl",
   "py3-d_accel-binary_http-ip",
@@ -441,6 +454,7 @@
   "py3-hs_accelc-compact_http-ip",
   "py3-hs_binary_http-ip",
   "py3-hs_compact_http-ip",
+  "py3-hs_header_http-ip",
   "py3-hs_json_http-ip",
   "py3-java_accel-binary_http-ip",
   "py3-java_accel-binary_http-ip-ssl",
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index b213d1acc5..56a408e60a 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -56,6 +56,7 @@
     'binary',
     'compact',
     'json',
+    'header',
 ]
 
 
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 2164162280..ddcce8db0a 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -348,6 +348,12 @@ def get_protocol2(self, transport):
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
+class HeaderTest(MultiplexedOptionalTest):
+    def get_protocol(self, transport):
+        factory = THeaderProtocol.THeaderProtocolFactory()
+        return factory.getProtocol(transport)
+
+
 def suite():
     suite = unittest.TestSuite()
     loader = unittest.TestLoader()
@@ -359,6 +365,8 @@ def suite():
         suite.addTest(loader.loadTestsFromTestCase(AcceleratedCompactTest))
     elif options.proto == 'compact':
         suite.addTest(loader.loadTestsFromTestCase(CompactTest))
+    elif options.proto == 'header':
+        suite.addTest(loader.loadTestsFromTestCase(HeaderTest))
     elif options.proto == 'json':
         suite.addTest(loader.loadTestsFromTestCase(JSONTest))
     elif options.proto == 'multi':
@@ -408,7 +416,7 @@ def parseArgs(self, argv):
                       dest="verbose", const=0,
                       help="minimal output")
     parser.add_option('--protocol', dest="proto", type="string",
-                      help="protocol to use, one of: accel, accelc, binary, compact, json, multi, multia, multiac, multic, multij")
+                      help="protocol to use, one of: accel, accelc, binary, compact, header, json, multi, multia, multiac, multic, multij")
     parser.add_option('--transport', dest="trans", type="string",
                       help="transport to use, one of: buffered, framed, http")
     parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
@@ -431,6 +439,7 @@ def parseArgs(self, argv):
     from thrift.transport import TZlibTransport
     from thrift.protocol import TBinaryProtocol
     from thrift.protocol import TCompactProtocol
+    from thrift.protocol import THeaderProtocol
     from thrift.protocol import TJSONProtocol
     from thrift.protocol import TMultiplexedProtocol
 
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 4dc4c0744e..aba0d42988 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -181,16 +181,22 @@ def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
 def main(options):
     # set up the protocol factory form the --protocol option
     prot_factories = {
-        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
-        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory,
-        'binary': TBinaryProtocol.TBinaryProtocolFactory,
-        'compact': TCompactProtocol.TCompactProtocolFactory,
-        'json': TJSONProtocol.TJSONProtocolFactory
+        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
+        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
+        'binary': TBinaryProtocol.TBinaryProtocolFactory(),
+        'compact': TCompactProtocol.TCompactProtocolFactory(),
+        'header': THeaderProtocol.THeaderProtocolFactory(allowed_client_types=[
+            THeaderTransport.THeaderClientType.HEADERS,
+            THeaderTransport.THeaderClientType.FRAMED_BINARY,
+            THeaderTransport.THeaderClientType.UNFRAMED_BINARY,
+            THeaderTransport.THeaderClientType.FRAMED_COMPACT,
+            THeaderTransport.THeaderClientType.UNFRAMED_COMPACT,
+        ]),
+        'json': TJSONProtocol.TJSONProtocolFactory(),
     }
-    pfactory_cls = prot_factories.get(options.proto, None)
-    if pfactory_cls is None:
+    pfactory = prot_factories.get(options.proto, None)
+    if pfactory is None:
         raise AssertionError('Unknown --protocol option: %s' % options.proto)
-    pfactory = pfactory_cls()
     try:
         pfactory.string_length_limit = options.string_limit
         pfactory.container_length_limit = options.container_limit
@@ -323,11 +329,13 @@ def exit_gracefully(signum, frame):
     from ThriftTest import ThriftTest
     from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
     from thrift.Thrift import TException
+    from thrift.transport import THeaderTransport
     from thrift.transport import TTransport
     from thrift.transport import TSocket
     from thrift.transport import TZlibTransport
     from thrift.protocol import TBinaryProtocol
     from thrift.protocol import TCompactProtocol
+    from thrift.protocol import THeaderProtocol
     from thrift.protocol import TJSONProtocol
     from thrift.server import TServer, TNonblockingServer, THttpServer
 
diff --git a/test/tests.json b/test/tests.json
index 72790acc9e..85a0c0797a 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -273,7 +273,8 @@
       "binary",
       "json",
       "binary:accel",
-      "compact:accelc"
+      "compact:accelc",
+      "header"
     ],
     "workdir": "py"
   },
@@ -319,7 +320,8 @@
       "binary",
       "json",
       "binary:accel",
-      "compact:accelc"
+      "compact:accelc",
+      "header"
     ],
     "workdir": "py"
   },


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


> THeader for Python
> ------------------
>
>                 Key: THRIFT-4621
>                 URL: https://issues.apache.org/jira/browse/THRIFT-4621
>             Project: Thrift
>          Issue Type: New Feature
>          Components: Python - Library
>            Reporter: Neil Williams
>            Priority: Minor
>
> I'm interested in porting THeader to the Python library. If that sounds OK, 
> I'll have a PR ready before too long.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to