This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch 0.19.0
in repository https://gitbox.apache.org/repos/asf/thrift.git

commit ee1a7ea35b72ab95445106410343088cf66ac173
Author: Yuxuan 'fishy' Wang <yuxuan.w...@reddit.com>
AuthorDate: Wed Aug 9 15:06:37 2023 -0700

    THRIFT-5731: Handle ErrAbandonRequest automatically
    
    Also add a test to verify the behavior.
    
    The test helped me find a bug in TSimpleServer, which didn't handle
    the ErrAbandonRequest case correctly, so fix the bug as well.
    
    client: go
---
 compiler/cpp/src/thrift/generate/t_go_generator.cc | 19 +++--
 lib/go/README.md                                   | 12 ++-
 .../test/tests/server_connectivity_check_test.go   | 88 ++++++++++++++++++++++
 lib/go/thrift/simple_server.go                     |  9 ++-
 4 files changed, 119 insertions(+), 9 deletions(-)

diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc 
b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 90353ce9b..54422c826 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -2797,13 +2797,13 @@ void 
t_go_generator::generate_process_function(t_service* tservice, t_function*
     f_types_ << indent() << "if thrift.ServerConnectivityCheckInterval > 0 {" 
<< endl;
 
     indent_up();
-    f_types_ << indent() << "var cancel context.CancelFunc" << endl;
-    f_types_ << indent() << "ctx, cancel = context.WithCancel(ctx)" << endl;
-    f_types_ << indent() << "defer cancel()" << endl;
+    f_types_ << indent() << "var cancel context.CancelCauseFunc" << endl;
+    f_types_ << indent() << "ctx, cancel = context.WithCancelCause(ctx)" << 
endl;
+    f_types_ << indent() << "defer cancel(nil)" << endl;
     f_types_ << indent() << "var tickerCtx context.Context" << endl;
     f_types_ << indent() << "tickerCtx, tickerCancel = 
context.WithCancel(context.Background())" << endl;
     f_types_ << indent() << "defer tickerCancel()" << endl;
-    f_types_ << indent() << "go func(ctx context.Context, cancel 
context.CancelFunc) {" << endl;
+    f_types_ << indent() << "go func(ctx context.Context, cancel 
context.CancelCauseFunc) {" << endl;
 
     indent_up();
     f_types_ << indent() << "ticker := 
time.NewTicker(thrift.ServerConnectivityCheckInterval)" << endl;
@@ -2821,7 +2821,7 @@ void t_go_generator::generate_process_function(t_service* 
tservice, t_function*
     indent_up();
     f_types_ << indent() << "if !iprot.Transport().IsOpen() {" << endl;
     indent_up();
-    f_types_ << indent() << "cancel()" << endl;
+    f_types_ << indent() << "cancel(thrift.ErrAbandonRequest)" << endl;
     f_types_ << indent() << "return" << endl;
     indent_down();
     f_types_ << indent() << "}" << endl;
@@ -2901,6 +2901,15 @@ void 
t_go_generator::generate_process_function(t_service* tservice, t_function*
     f_types_ << indent() << "return false, thrift.WrapTException(err2)" << 
endl;
     indent_down();
     f_types_ << indent() << "}" << endl;
+    f_types_ << indent() << "if errors.Is(err2, context.Canceled) {" << endl;
+    indent_up();
+    f_types_ << indent() << "if err := context.Cause(ctx); errors.Is(err, 
thrift.ErrAbandonRequest) {" << endl;
+    indent_up();
+    f_types_ << indent() << "return false, thrift.WrapTException(err)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
 
     string exc(tmp("_exc"));
     f_types_ << indent() << exc << " := 
thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
diff --git a/lib/go/README.md b/lib/go/README.md
index b2cf1df12..0aa4f1bc6 100644
--- a/lib/go/README.md
+++ b/lib/go/README.md
@@ -108,13 +108,19 @@ The context object passed into the server handler 
function will be canceled when
 the client closes the connection (this is a best effort check, not a guarantee
 -- there's no guarantee that the context object is always canceled when client
 closes the connection, but when it's canceled you can always assume the client
-closed the connection). When implementing Go Thrift server, you can take
-advantage of that to abandon requests that's no longer needed:
+closed the connection). The cause of the cancellation (via 
`context.Cause(ctx)`)
+would also be set to `thrift.ErrAbandonRequest`.
+
+When implementing a Go Thrift server, you can take advantage of that to abandon
+requests that are no longer needed by returning `thrift.ErrAbandonRequest`:
 
     func MyEndpoint(ctx context.Context, req *thriftRequestType) 
(*thriftResponseType, error) {
         ...
         if ctx.Err() == context.Canceled {
             return nil, thrift.ErrAbandonRequest
+            // Or just return ctx.Err(), compiler generated processor code will
+            // handle it for you automatically:
+            // return nil, ctx.Err()
         }
         ...
     }
@@ -155,4 +161,4 @@ will wait for all the client connections to be closed 
gracefully with
 zero err time. Otherwise, the stop will wait for all the client 
 connections to be closed gracefully util thrift.ServerStopTimeout is 
 reached, and client connections that are not closed after 
thrift.ServerStopTimeout 
-will be closed abruptly which may cause some client errors.
\ No newline at end of file
+will be closed abruptly which may cause some client errors.
diff --git a/lib/go/test/tests/server_connectivity_check_test.go 
b/lib/go/test/tests/server_connectivity_check_test.go
new file mode 100644
index 000000000..51710eda2
--- /dev/null
+++ b/lib/go/test/tests/server_connectivity_check_test.go
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+       "context"
+       "runtime/debug"
+       "testing"
+       "time"
+
+       
"github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest"
+       "github.com/apache/thrift/lib/go/thrift"
+)
+
+func TestServerConnectivityCheck(t *testing.T) {
+       const (
+               // Server will sleep for longer than client is willing to wait
+               // so client will close the connection.
+               serverSleep         = 50 * time.Millisecond
+               clientSocketTimeout = time.Millisecond
+       )
+       serverSocket, err := thrift.NewTServerSocket(":0")
+       if err != nil {
+               t.Fatalf("failed to create server socket: %v", err)
+       }
+       processor := 
clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(fakeClientMiddlewareExceptionTestHandler(
+               func(ctx context.Context) 
(*clientmiddlewareexceptiontest.FooResponse, error) {
+                       time.Sleep(serverSleep)
+                       err := ctx.Err()
+                       if err == nil {
+                               t.Error("Expected server ctx to be cancelled, 
did not happen")
+                               return 
new(clientmiddlewareexceptiontest.FooResponse), nil
+                       }
+                       return nil, err
+               },
+       ))
+       server := thrift.NewTSimpleServer2(processor, serverSocket)
+       if err := server.Listen(); err != nil {
+               t.Fatalf("failed to listen server: %v", err)
+       }
+       server.SetLogger(func(msg string) {
+               t.Errorf("Server logger called with %q", msg)
+               t.Errorf("Server logger callstack:\n%s", debug.Stack())
+       })
+       addr := serverSocket.Addr().String()
+       go server.Serve()
+       t.Cleanup(func() {
+               server.Stop()
+       })
+
+       cfg := &thrift.TConfiguration{
+               SocketTimeout: clientSocketTimeout,
+       }
+       socket := thrift.NewTSocketConf(addr, cfg)
+       if err := socket.Open(); err != nil {
+               t.Fatalf("failed to create client connection: %v", err)
+       }
+       t.Cleanup(func() {
+               socket.Close()
+       })
+       inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+       outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+       client := thrift.NewTStandardClient(inProtocol, outProtocol)
+       ctx, cancel := context.WithTimeout(context.Background(), 
clientSocketTimeout)
+       defer cancel()
+       _, err = 
clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(ctx)
+       socket.Close()
+       if err == nil {
+               t.Error("Expected client to time out, did not happen")
+       }
+}
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index c5c14feed..d4f555ccd 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -24,6 +24,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "net"
        "sync"
        "sync/atomic"
        "time"
@@ -354,7 +355,13 @@ func (p *TSimpleServer) processRequests(client TTransport) 
(err error) {
 
                ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
                if errors.Is(err, ErrAbandonRequest) {
-                       return client.Close()
+                       err := client.Close()
+                       if errors.Is(err, net.ErrClosed) {
+                               // In this case, it's kinda expected to get
+                               // net.ErrClosed, treat that as no-error
+                               return nil
+                       }
+                       return err
                }
                if errors.As(err, new(TTransportException)) && err != nil {
                        return err

Reply via email to