Repository: spark
Updated Branches:
  refs/heads/branch-2.2 bd12eb75d -> 4f10aff40


[SPARK-25253][PYSPARK] Refactor local connection & auth code

This eliminates some duplication in the code that connects to a server on 
localhost to talk directly to the JVM, and it gives consistent IPv6 and error 
handling (a usage sketch follows the list below).  Two other incidental 
changes, which shouldn't matter:
1) Python barrier tasks perform authentication immediately (rather than waiting 
for the BARRIER_FUNCTION indicator)
2) for `rdd._load_from_socket`, the timeout is only increased after 
authentication.
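
For context, a minimal sketch of worker.py-style usage of the shared helper 
(the two environment variable names are the ones the JVM already sets for 
worker.py, as seen in the diff; everything else is illustrative only):

    import os
    from pyspark.java_gateway import local_connect_and_auth

    # The JVM passes the connect-back port and shared secret through the
    # environment; local_connect_and_auth dials localhost, trying each
    # address family returned by getaddrinfo, and runs the auth handshake.
    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
    sock_file, sock = local_connect_and_auth(java_port, auth_secret)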

Closes #22247 from squito/py_connection_refactor.

Authored-by: Imran Rashid <iras...@cloudera.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>
(cherry picked from commit 38391c9aa8a88fcebb337934f30298a32d91596b)
(cherry picked from commit a2a54a5f49364a1825932c9f04eb0ff82dd7d465)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fc1c4e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fc1c4e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fc1c4e7d

Branch: refs/heads/branch-2.2
Commit: fc1c4e7d24f7d0afb3b79d66aa9812e7dddc2f38
Parents: bd12eb7
Author: Imran Rashid <iras...@cloudera.com>
Authored: Wed Aug 29 09:47:38 2018 +0800
Committer: Imran Rashid <iras...@cloudera.com>
Committed: Tue Sep 25 11:45:59 2018 -0500

----------------------------------------------------------------------
 python/pyspark/java_gateway.py | 32 +++++++++++++++++++++++++++++++-
 python/pyspark/rdd.py          | 24 ++----------------------
 python/pyspark/worker.py       |  7 ++-----
 3 files changed, 35 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 7abf2c1..191dfce 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -133,7 +133,7 @@ def launch_gateway(conf=None):
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -144,3 +144,33 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile, sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param port: the local port to connect to
+    :param auth_secret: the secret to authenticate with
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)
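
One detail worth noting in the new helper: the trailing `else:` belongs to the 
`for` loop, not an `if`. It runs only when the loop exhausts every address from 
getaddrinfo without the `return` firing, i.e. when every connection attempt 
failed. A minimal standalone sketch of the idiom (all values illustrative):

    # for/else: the else branch runs only if the loop finishes without
    # a break (here, when every attempt raised and was recorded)
    attempts = ["addr1", "addr2"]
    errors = []
    for addr in attempts:
        try:
            raise IOError("connection refused")  # stand-in for sock.connect
        except IOError as e:
            errors.append("%s: %s" % (addr, e))
    else:
        raise Exception("could not open socket: %s" % errors)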

http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 864cebb..7d84cbd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ if sys.version > '3':
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -122,30 +122,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredictable; if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
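
For orientation, a sketch of a typical caller of `_load_from_socket`, modeled 
on `RDD.collect` in this branch (the call-site wiring is not part of this diff, 
so treat the exact shape as an approximation):

    # sketch, assuming PythonRDD.collectAndServe returns the sock_info
    # (port, auth secret) pair that local_connect_and_auth consumes
    def collect(self):
        sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
        return list(_load_from_socket(sock_info, self._jrdd_deserializer))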
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0c8996e..f3cb6ae 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,7 @@ import traceback
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
@@ -212,8 +212,5 @@ if __name__ == '__main__':
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)

