diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java
index 5970f85c1745..38cae631a975 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java
@@ -25,6 +25,7 @@
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.sdk.fn.channel.SocketAddressFactory;
@@ -42,6 +43,9 @@
 
 /** A {@link Server gRPC server} factory. */
 public abstract class ServerFactory {
+
+  private static final int KEEP_ALIVE_TIME_SEC = 20;
+
   /** Create a default {@link InetSocketAddressServerFactory}. */
   public static ServerFactory createDefault() {
     return new InetSocketAddressServerFactory(UrlFactory.createDefault());
@@ -144,7 +148,8 @@ private static Server createServer(List<BindableService> services, InetSocketAdd
           NettyServerBuilder.forPort(socket.getPort())
               // Set the message size to max value here. The actual size is 
governed by the
               // buffer size in the layers above.
-              .maxMessageSize(Integer.MAX_VALUE);
+              .maxMessageSize(Integer.MAX_VALUE)
+              .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS);
       services
           .stream()
           .forEach(
@@ -200,7 +205,8 @@ private static Server createServer(
               .channelType(EpollServerDomainSocketChannel.class)
               .workerEventLoopGroup(new EpollEventLoopGroup())
               .bossEventLoopGroup(new EpollEventLoopGroup())
-              .maxMessageSize(Integer.MAX_VALUE);
+              .maxMessageSize(Integer.MAX_VALUE)
+              .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS);
       for (BindableService service : services) {
         // Wrap the service to extract headers
         builder.addService(
@@ -249,7 +255,8 @@ private static Server createServer(List<BindableService> services, InetSocketAdd
               .channelType(EpollServerSocketChannel.class)
               .workerEventLoopGroup(new EpollEventLoopGroup())
               .bossEventLoopGroup(new EpollEventLoopGroup())
-              .maxMessageSize(Integer.MAX_VALUE);
+              .maxMessageSize(Integer.MAX_VALUE)
+              .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS);
       for (BindableService service : services) {
         // Wrap the service to extract headers
         builder.addService(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index a39a996478d6..1272b0e31b52 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -61,6 +61,7 @@
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import sdk_worker
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.transforms import trigger
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.utils import profiler
@@ -830,7 +831,8 @@ def __init__(self, external_payload, state):
 
   def start_worker(self):
     stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub(
-        grpc.insecure_channel(self._external_payload.endpoint.url))
+        GRPCChannelFactory.insecure_channel(
+            self._external_payload.endpoint.url))
     response = stub.NotifyRunnerAvailable(
         beam_fn_api_pb2.NotifyRunnerAvailableRequest(
             control_endpoint=endpoints_pb2.ApiServiceDescriptor(
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index f6108ed030d0..d2bf31b2c61c 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -44,6 +44,7 @@
 from apache_beam.runners.portability.job_server import DockerizedJobServer
 from apache_beam.runners.worker import sdk_worker
 from apache_beam.runners.worker import sdk_worker_main
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 
 __all__ = ['PortableRunner']
 
@@ -188,7 +189,7 @@ def run_pipeline(self, pipeline, options):
                  for k, v in options.get_all_options().items()
                  if v is not None}
 
-    channel = grpc.insecure_channel(job_endpoint)
+    channel = GRPCChannelFactory.insecure_channel(job_endpoint)
     grpc.channel_ready_future(channel).result()
     job_service = beam_job_api_pb2_grpc.JobServiceStub(channel)
 
@@ -212,7 +213,8 @@ def send_prepare_request(max_retries=5):
     prepare_response = send_prepare_request()
     if prepare_response.artifact_staging_endpoint.url:
       stager = portable_stager.PortableStager(
-          
grpc.insecure_channel(prepare_response.artifact_staging_endpoint.url),
+          GRPCChannelFactory.insecure_channel(
+              prepare_response.artifact_staging_endpoint.url),
           prepare_response.staging_session_token)
       retrieval_token, _ = stager.stage_job_resources(
           options,
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index a37149d633bc..80cd7bf8be2f 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -43,6 +43,7 @@
 from apache_beam.runners.portability import portable_runner
 from apache_beam.runners.portability.local_job_service import LocalJobServicer
 from apache_beam.runners.portability.portable_runner import PortableRunner
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 
 
 class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -93,7 +94,7 @@ def _start_local_runner_subprocess_job_service(cls):
     cls._subprocess = subprocess.Popen(cls._subprocess_command(port))
     address = 'localhost:%d' % port
     job_service = beam_job_api_pb2_grpc.JobServiceStub(
-        grpc.insecure_channel(address))
+        GRPCChannelFactory.insecure_channel(address))
     logging.info('Waiting for server to be ready...')
     start = time.time()
     timeout = 30
diff --git a/sdks/python/apache_beam/runners/worker/channel_factory.py b/sdks/python/apache_beam/runners/worker/channel_factory.py
new file mode 100644
index 000000000000..d0823fa54842
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/channel_factory.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Factory for creating gRPC channels with Beam's default channel options."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import grpc
+
+
+class GRPCChannelFactory(object):
+  """Creates gRPC channels configured with Beam's default channel options.
+
+  All methods are static; the class only groups them. Every channel receives
+  the entries of ``DEFAULT_OPTIONS``, which currently enable client keepalive
+  pings every 20 seconds (``grpc.keepalive_time_ms``). The Java runner's
+  ``ServerFactory`` permits keepalive pings no more often than every 20
+  seconds (``permitKeepAliveTime``). gRPC servers close connections of
+  clients that ping more often than permitted, so keep the two in sync.
+  """
+
+  # Channel arguments added to every channel unless the caller already sets
+  # the same key.
+  DEFAULT_OPTIONS = [("grpc.keepalive_time_ms", 20000)]
+
+  @staticmethod
+  def _merge_options(options):
+    """Returns a new list of ``options`` plus each default whose key is absent.
+
+    Caller-supplied values take precedence over ``DEFAULT_OPTIONS`` rather
+    than being passed alongside a conflicting duplicate key. ``options`` may
+    be ``None`` or any iterable of ``(key, value)`` pairs (e.g. a tuple); the
+    caller's sequence is never mutated.
+    """
+    merged = list(options) if options is not None else []
+    explicit_keys = set(key for key, _ in merged)
+    merged.extend(
+        (key, value) for key, value in GRPCChannelFactory.DEFAULT_OPTIONS
+        if key not in explicit_keys)
+    return merged
+
+  @staticmethod
+  def insecure_channel(target, options=None):
+    """Creates an insecure channel to ``target`` with the default options.
+
+    Args:
+      target: The server address to connect to.
+      options: Optional iterable of ``(key, value)`` channel arguments; these
+        override ``DEFAULT_OPTIONS`` entries with the same key.
+
+    Returns:
+      The channel created by ``grpc.insecure_channel``.
+    """
+    return grpc.insecure_channel(
+        target, options=GRPCChannelFactory._merge_options(options))
+
+  @staticmethod
+  def secure_channel(target, credentials, options=None):
+    """Creates a secure channel to ``target`` with the default options.
+
+    Args:
+      target: The server address to connect to.
+      credentials: The ``grpc.ChannelCredentials`` used to secure the channel.
+      options: Optional iterable of ``(key, value)`` channel arguments; these
+        override ``DEFAULT_OPTIONS`` entries with the same key.
+
+    Returns:
+      The channel created by ``grpc.secure_channel``.
+    """
+    return grpc.secure_channel(
+        target, credentials,
+        options=GRPCChannelFactory._merge_options(options))
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index 12bd3764d7c8..fc8f9cca887a 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -37,6 +37,7 @@
 from apache_beam.coders import coder_impl
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.worker_id_interceptor import 
WorkerIdInterceptor
 
 # This module is experimental. No backwards-compatibility guarantees.
@@ -343,9 +344,10 @@ def create_data_channel(self, remote_grpc_port):
                              ("grpc.max_send_message_length", -1)]
           grpc_channel = None
           if self._credentials is None:
-            grpc_channel = grpc.insecure_channel(url, options=channel_options)
+            grpc_channel = GRPCChannelFactory.insecure_channel(
+                url, options=channel_options)
           else:
-            grpc_channel = grpc.secure_channel(
+            grpc_channel = GRPCChannelFactory.secure_channel(
                 url, self._credentials, options=channel_options)
           # Add workerId to the grpc channel
           grpc_channel = grpc.intercept_channel(grpc_channel,
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
index f72c7b00779b..cbd68f5de99d 100644
--- a/sdks/python/apache_beam/runners/worker/log_handler.py
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -30,6 +30,7 @@
 
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.worker_id_interceptor import 
WorkerIdInterceptor
 
 # This module is experimental. No backwards-compatibility guarantees.
@@ -61,7 +62,7 @@ def __init__(self, log_service_descriptor):
     self._dropped_logs = 0
     self._log_entry_queue = queue.Queue(maxsize=self._QUEUE_SIZE)
 
-    ch = grpc.insecure_channel(log_service_descriptor.url)
+    ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url)
     # Make sure the channel is ready to avoid [BEAM-4649]
     grpc.channel_ready_future(ch).result(timeout=60)
     self._log_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 2d0b61d82263..b23cf68b1dff 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -41,6 +41,7 @@
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.worker_id_interceptor import 
WorkerIdInterceptor
 
 
@@ -57,10 +58,12 @@ def __init__(
     self._worker_id = worker_id
     if credentials is None:
       logging.info('Creating insecure control channel for %s.', 
control_address)
-      self._control_channel = grpc.insecure_channel(control_address)
+      self._control_channel = GRPCChannelFactory.insecure_channel(
+          control_address)
     else:
       logging.info('Creating secure control channel for %s.', control_address)
-      self._control_channel = grpc.secure_channel(control_address, credentials)
+      self._control_channel = GRPCChannelFactory.secure_channel(
+          control_address, credentials)
     grpc.channel_ready_future(self._control_channel).result(timeout=60)
     logging.info('Control channel established.')
 
@@ -355,7 +358,7 @@ def create_state_handler(self, api_service_descriptor):
       with self._lock:
         if url not in self._state_handler_cache:
           logging.info('Creating insecure state channel for %s', url)
-          grpc_channel = grpc.insecure_channel(
+          grpc_channel = GRPCChannelFactory.insecure_channel(
               url,
               # Options to have no limits (-1) on the size of the messages
               # received or sent over the data plane. The actual buffer size is


With regards,
Apache Git Services

Reply via email to