This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch go-zlib-pool
in repository https://gitbox.apache.org/repos/asf/thrift.git

commit f0cf77c9bef50b5effee232126cd6e1ffee3c078
Author: Yuxuan 'fishy' Wang <yuxuan.w...@reddit.com>
AuthorDate: Wed May 28 10:54:04 2025 -0700

    go: Add a zlib reader pool
    
    We implemented a zlib writer pool for the default compression level when
    implementing THeader. This change also adds a zlib reader pool to speed
    things up when zlib is used.
    
    Also make TZlibTransport use the zlib writer pool when it uses the
    default compression level.
---
 lib/go/thrift/compact_protocol_test.go | 15 +++++--
 lib/go/thrift/header_transport.go      | 22 +---------
 lib/go/thrift/pool.go                  |  2 +-
 lib/go/thrift/zlib_pool.go             | 76 ++++++++++++++++++++++++++++++++++
 lib/go/thrift/zlib_transport.go        | 30 ++++++++++----
 5 files changed, 111 insertions(+), 34 deletions(-)

diff --git a/lib/go/thrift/compact_protocol_test.go 
b/lib/go/thrift/compact_protocol_test.go
index 65f77f2c4..0d9575992 100644
--- a/lib/go/thrift/compact_protocol_test.go
+++ b/lib/go/thrift/compact_protocol_test.go
@@ -33,9 +33,18 @@ func TestReadWriteCompactProtocol(t *testing.T) {
                NewTFramedTransport(NewTMemoryBuffer()),
        }
 
-       zlib0, _ := NewTZlibTransport(NewTMemoryBuffer(), 0)
-       zlib6, _ := NewTZlibTransport(NewTMemoryBuffer(), 6)
-       zlib9, _ := NewTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 
9)
+       newTZlibTransport := func(trans TTransport, level int) *TZlibTransport {
+               t.Helper()
+               zlibTrans, err := NewTZlibTransport(trans, level)
+               if err != nil {
+                       t.Fatalf("NewTZlibTransport returned error: %v", err)
+               }
+               return zlibTrans
+       }
+
+       zlib0 := newTZlibTransport(NewTMemoryBuffer(), 0)
+       zlib6 := newTZlibTransport(NewTMemoryBuffer(), 6)
+       zlib9 := newTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 9)
        transports = append(transports, zlib0, zlib6, zlib9)
 
        for _, trans := range transports {
diff --git a/lib/go/thrift/header_transport.go 
b/lib/go/thrift/header_transport.go
index d6d64160a..0e5c7ec1d 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -22,7 +22,6 @@ package thrift
 import (
        "bufio"
        "bytes"
-       "compress/zlib"
        "context"
        "encoding/binary"
        "errors"
@@ -166,7 +165,7 @@ func (tr *TransformReader) AddTransform(id 
THeaderTransformID) error {
        case TransformNone:
                // no-op
        case TransformZlib:
-               readCloser, err := zlib.NewReader(tr.Reader)
+               readCloser, err := newZlibReader(tr.Reader)
                if err != nil {
                        return err
                }
@@ -211,25 +210,6 @@ func (tw *TransformWriter) Close() error {
        return nil
 }
 
-var zlibDefaultLevelWriterPool = newPool(
-       func() *zlib.Writer {
-               return zlib.NewWriter(nil)
-       },
-       nil,
-)
-
-type zlibPoolCloser struct {
-       writer *zlib.Writer
-}
-
-func (z *zlibPoolCloser) Close() error {
-       defer func() {
-               z.writer.Reset(nil)
-               zlibDefaultLevelWriterPool.put(&z.writer)
-       }()
-       return z.writer.Close()
-}
-
 // AddTransform adds a transform.
 func (tw *TransformWriter) AddTransform(id THeaderTransformID) error {
        switch id {
diff --git a/lib/go/thrift/pool.go b/lib/go/thrift/pool.go
index 1d623d422..6912f3ea5 100644
--- a/lib/go/thrift/pool.go
+++ b/lib/go/thrift/pool.go
@@ -43,7 +43,7 @@ func newPool[T any](generate func() *T, reset func(*T)) 
*pool[T] {
        }
        return &pool[T]{
                pool: sync.Pool{
-                       New: func() interface{} {
+                       New: func() any {
                                return generate()
                        },
                },
diff --git a/lib/go/thrift/zlib_pool.go b/lib/go/thrift/zlib_pool.go
new file mode 100644
index 000000000..c38e56850
--- /dev/null
+++ b/lib/go/thrift/zlib_pool.go
@@ -0,0 +1,76 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+ */
+
+package thrift
+
+import (
+       "compress/zlib"
+       "io"
+       "sync"
+)
+
+// zlibReader is the method set of the readers produced by zlib.NewReader:
+// io.ReadCloser plus zlib.Resetter, which lets a reader be pointed at a new
+// compressed stream and reused.
+type zlibReader interface {
+       zlib.Resetter
+       io.ReadCloser
+}
+
+// zlibReaderPool holds *wrappedZlibReader values recycled by
+// wrappedZlibReader.Close. It has no New func, so Get returns nil when the
+// pool is empty and newZlibReader then allocates a fresh reader.
+var zlibReaderPool sync.Pool
+
+// newZlibReader returns a zlib decompressor that reads from r. It reuses a
+// reader from zlibReaderPool when one is available and allocates a new one
+// otherwise. The zlib header is consumed from r before this returns, so a
+// malformed or truncated header is reported as an error here. The caller
+// must Close the returned reader exactly once; Close recycles it into
+// zlibReaderPool.
+func newZlibReader(r io.Reader) (io.ReadCloser, error) {
+       if reader, _ := zlibReaderPool.Get().(*wrappedZlibReader); reader != nil {
+               // Reset has already consumed bytes from r (the header, and
+               // possibly more via its internal buffering). A failure must
+               // therefore be returned as-is: retrying with zlib.NewReader(r)
+               // would parse the wrong bytes or block waiting for more input.
+               // The failed reader is simply dropped instead of being pooled.
+               if err := reader.Reset(r, nil); err != nil {
+                       return nil, err
+               }
+               return reader, nil
+       }
+       reader, err := zlib.NewReader(r)
+       if err != nil {
+               return nil, err
+       }
+       // compress/zlib documents that the ReadCloser returned by NewReader
+       // also implements zlib.Resetter, so this assertion cannot fail.
+       return &wrappedZlibReader{reader.(zlibReader)}, nil
+}
+
+// wrappedZlibReader embeds a zlibReader and overrides Close so that the
+// reader is recycled through zlibReaderPool instead of being discarded.
+type wrappedZlibReader struct {
+       zlibReader
+}
+
+// Close closes the underlying zlib reader and then puts wr back into
+// zlibReaderPool. The put happens even when the underlying Close returns an
+// error. After Close, wr must not be read from or closed again, because
+// newZlibReader may already have handed it to another caller.
+func (wr *wrappedZlibReader) Close() error {
+       defer zlibReaderPool.Put(wr)
+       return wr.zlibReader.Close()
+}
+
+// zlibDefaultLevelWriterPool recycles zlib writers created at
+// zlib.DefaultCompression. Pooled writers have a nil destination, so each
+// one must be Reset onto a real io.Writer after it is taken from the pool.
+var zlibDefaultLevelWriterPool = newPool(
+       func() *zlib.Writer { return zlib.NewWriter(nil) },
+       // No reset hook: zlibPoolCloser.Close calls Reset(nil) itself before put.
+       nil,
+)
+
+// zlibPoolCloser wraps a writer borrowed from zlibDefaultLevelWriterPool as
+// an io.Closer. Closing it finishes the compressed stream and then gives the
+// writer back to the pool.
+type zlibPoolCloser struct {
+       writer *zlib.Writer
+}
+
+// Close closes the zlib stream, then detaches the writer from its
+// destination and returns it to zlibDefaultLevelWriterPool. The writer is
+// recycled even when writer.Close fails, and that Close error is what gets
+// returned. Call Close at most once: afterwards the writer may belong to
+// another user of the pool.
+func (z *zlibPoolCloser) Close() error {
+       recycle := func() {
+               // Reset(nil) drops the reference to the old destination so the
+               // pooled writer does not keep it reachable.
+               z.writer.Reset(nil)
+               zlibDefaultLevelWriterPool.put(&z.writer)
+       }
+       defer recycle()
+       return z.writer.Close()
+}
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index cefe1f994..c3863696d 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -33,9 +33,10 @@ type TZlibTransportFactory struct {
 
 // TZlibTransport is a TTransport implementation that makes use of zlib 
compression.
 type TZlibTransport struct {
-       reader    io.ReadCloser
-       transport TTransport
-       writer    *zlib.Writer
+       reader      io.ReadCloser
+       transport   TTransport
+       writer      *zlib.Writer
+       writeCloser io.Closer // closes writer; for DefaultCompression it is a *zlibPoolCloser that also returns writer to the pool
 }
 
 // GetTransport constructs a new instance of NewTZlibTransport
@@ -64,14 +65,25 @@ func NewTZlibTransportFactoryWithFactory(level int, factory 
TTransportFactory) *
 
 // NewTZlibTransport constructs a new instance of TZlibTransport
 func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) {
-       w, err := zlib.NewWriterLevel(trans, level)
-       if err != nil {
-               return nil, err
+       var w *zlib.Writer
+       // writeCloser closes w; on the pooled path it also returns w to the pool.
+       var writeCloser io.Closer
+       if level == zlib.DefaultCompression {
+               // Reuse a pooled default-level writer instead of allocating a new
+               // one. Pooled writers have no destination, so point it at trans.
+               w = zlibDefaultLevelWriterPool.get()
+               w.Reset(trans)
+               writeCloser = &zlibPoolCloser{writer: w}
+       } else {
+               // Other levels are not pooled; NewWriterLevel rejects invalid levels.
+               writer, err := zlib.NewWriterLevel(trans, level)
+               if err != nil {
+                       return nil, err
+               }
+               w = writer
+               writeCloser = writer
        }
 
        return &TZlibTransport{
-               writer:    w,
-               transport: trans,
+               writer:      w,
+               writeCloser: writeCloser,
+               transport:   trans,
        }, nil
 }
 
@@ -109,7 +121,7 @@ func (z *TZlibTransport) Open() error {
 
 func (z *TZlibTransport) Read(p []byte) (int, error) {
        if z.reader == nil {
-               r, err := zlib.NewReader(z.transport)
+               r, err := newZlibReader(z.transport)
                if err != nil {
                        return 0, NewTTransportExceptionFromError(err)
                }

Reply via email to