Repository: spark Updated Branches: refs/heads/branch-2.2 bd12eb75d -> 4f10aff40
[SPARK-25253][PYSPARK] Refactor local connection & auth code This eliminates some duplication in the code to connect to a server on localhost to talk directly to the JVM. It also gives consistent IPv6 and error handling. Two other incidental changes that shouldn't matter: 1) Python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator) 2) for `rdd._load_from_socket`, the timeout is only increased after authentication. Closes #22247 from squito/py_connection_refactor. Authored-by: Imran Rashid <iras...@cloudera.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> (cherry picked from commit 38391c9aa8a88fcebb337934f30298a32d91596b) (cherry picked from commit a2a54a5f49364a1825932c9f04eb0ff82dd7d465) Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fc1c4e7d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fc1c4e7d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fc1c4e7d Branch: refs/heads/branch-2.2 Commit: fc1c4e7d24f7d0afb3b79d66aa9812e7dddc2f38 Parents: bd12eb7 Author: Imran Rashid <iras...@cloudera.com> Authored: Wed Aug 29 09:47:38 2018 +0800 Committer: Imran Rashid <iras...@cloudera.com> Committed: Tue Sep 25 11:45:59 2018 -0500 ---------------------------------------------------------------------- python/pyspark/java_gateway.py | 32 +++++++++++++++++++++++++++++++- python/pyspark/rdd.py | 24 ++---------------------- python/pyspark/worker.py | 7 ++----- 3 files changed, 35 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/java_gateway.py ---------------------------------------------------------------------- diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 7abf2c1..191dfce 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -133,7 +133,7 
@@ def launch_gateway(conf=None): return gateway -def do_server_auth(conn, auth_secret): +def _do_server_auth(conn, auth_secret): """ Performs the authentication protocol defined by the SocketAuthHelper class on the given file-like object 'conn'. @@ -144,3 +144,33 @@ def do_server_auth(conn, auth_secret): if reply != "ok": conn.close() raise Exception("Unexpected reply from iterator server.") + + +def local_connect_and_auth(port, auth_secret): + """ + Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. + Handles IPV4 & IPV6, does some error handling. + :param port + :param auth_secret + :return: a tuple with (sockfile, sock) + """ + sock = None + errors = [] + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(15) + sock.connect(sa) + sockfile = sock.makefile("rwb", 65536) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) + except socket.error as e: + emsg = _exception_message(e) + errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) + sock.close() + sock = None + else: + raise Exception("could not open socket: %s" % errors) http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/rdd.py ---------------------------------------------------------------------- diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 864cebb..7d84cbd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,7 +39,7 @@ if sys.version > '3': else: from itertools import imap as map, ifilter as filter -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, 
PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ @@ -122,30 +122,10 @@ def _parse_memory(s): def _load_from_socket(sock_info, serializer): - port, auth_secret = sock_info - sock = None - # Support for both IPv4 and IPv6. - # On most of IPv6-ready systems, IPv6 will take precedence. - for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) - try: - sock.settimeout(15) - sock.connect(sa) - except socket.error: - sock.close() - sock = None - continue - break - if not sock: - raise Exception("could not open socket") + (sockfile, sock) = local_connect_and_auth(*sock_info) # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) - - sockfile = sock.makefile("rwb", 65536) - do_server_auth(sockfile, auth_secret) - # The socket will be automatically closed when garbage-collected. return serializer.load_stream(sockfile) http://git-wip-us.apache.org/repos/asf/spark/blob/fc1c4e7d/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0c8996e..f3cb6ae 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,7 +27,7 @@ import traceback from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ @@ -212,8 +212,5 @@ if __name__ == '__main__': # Read information about how to connect back to the JVM from the environment. 
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("rwb", 65536) - do_server_auth(sock_file, auth_secret) + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org