Repository: spark
Updated Branches:
  refs/heads/branch-2.1 a3eb07db3 -> b2e0f68f6
[PYSPARK] Updates to Accumulators

(cherry picked from commit 15fc2372269159ea2556b028d4eb8860c4108650)

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b2e0f68f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b2e0f68f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b2e0f68f

Branch: refs/heads/branch-2.1
Commit: b2e0f68f615cbe2cf74f9813ece76c311fe8e911
Parents: a3eb07d
Author: LucaCanali <luca.can...@cern.ch>
Authored: Wed Jul 18 23:19:02 2018 +0200
Committer: Imran Rashid <iras...@cloudera.com>
Committed: Fri Aug 3 16:30:40 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 12 +++--
 python/pyspark/accumulators.py                  | 53 +++++++++++++++-----
 python/pyspark/context.py                       |  5 +-
 3 files changed, 53 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1190b9..de548e8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -886,8 +886,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
  */
 private[spark] class PythonAccumulatorV2(
     @transient private val serverHost: String,
-    private val serverPort: Int)
-  extends CollectionAccumulator[Array[Byte]] {
+    private val serverPort: Int,
+    private val secretToken: String)
+  extends CollectionAccumulator[Array[Byte]] with Logging {

   Utils.checkHost(serverHost, "Expected hostname")

@@ -902,12 +903,17 @@ private[spark] class PythonAccumulatorV2(
   private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
+      logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
+      // send the secret just for the initial authentication when opening a new connection
+      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
     }
     socket
   }

   // Need to override so the types match with PythonFunction
-  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
+  override def copyAndReset(): PythonAccumulatorV2 = {
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+  }

   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
     val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
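For readers following along: after this change, the first bytes a
PythonAccumulatorV2 writes on a fresh connection are the raw UTF-8
secret, and only then the usual update traffic (a big-endian int count,
then length-prefixed pickled (aid, update) pairs, acknowledged by a
single byte). Below is a minimal Python sketch of a client speaking
that protocol; the name send_accumulator_updates is hypothetical and
the framing is inferred from the handler code in the next diff:

    import pickle
    import socket
    import struct

    def send_accumulator_updates(host, port, secret_token, updates):
        # Hypothetical client mirroring PythonAccumulatorV2.openSocket():
        # authenticate once per connection, then send one batch of updates.
        sock = socket.create_connection((host, port))
        try:
            # The raw UTF-8 secret is the first thing on the wire.
            sock.sendall(secret_token.encode("utf-8"))
            # Batch header: the number of updates, as a big-endian int.
            sock.sendall(struct.pack("!i", len(updates)))
            for aid, update in updates:
                payload = pickle.dumps((aid, update))
                # Each update is length-prefixed, matching _read_with_length.
                sock.sendall(struct.pack("!i", len(payload)))
                sock.sendall(payload)
            # The server acknowledges the batch with a single byte.
            ack = sock.recv(1)
            assert ack == struct.pack("!b", 1), "no acknowledgement received"
        finally:
            sock.close()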
http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 6ef8cf5..bc0be07 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        while not self.server.server_shutdown:
-            # Poll every 1 second for new data -- don't block in case of shutdown.
-            r, _, _ = select.select([self.rfile], [], [], 1)
-            if self.rfile in r:
-                num_updates = read_int(self.rfile)
-                for _ in range(num_updates):
-                    (aid, update) = pickleSer._read_with_length(self.rfile)
-                    _accumulatorRegistry[aid] += update
-                # Write a byte in acknowledgement
-                self.wfile.write(struct.pack("!b", 1))
+        auth_token = self.server.auth_token
+
+        def poll(func):
+            while not self.server.server_shutdown:
+                # Poll every 1 second for new data -- don't block in case of shutdown.
+                r, _, _ = select.select([self.rfile], [], [], 1)
+                if self.rfile in r:
+                    if func():
+                        break
+
+        def accum_updates():
+            num_updates = read_int(self.rfile)
+            for _ in range(num_updates):
+                (aid, update) = pickleSer._read_with_length(self.rfile)
+                _accumulatorRegistry[aid] += update
+            # Write a byte in acknowledgement
+            self.wfile.write(struct.pack("!b", 1))
+            return False
+
+        def authenticate_and_accum_updates():
+            received_token = self.rfile.read(len(auth_token))
+            if isinstance(received_token, bytes):
+                received_token = received_token.decode("utf-8")
+            if (received_token == auth_token):
+                accum_updates()
+                # we've authenticated, we can break out of the first loop now
+                return True
+            else:
+                raise Exception(
+                    "The value of the provided token to the AccumulatorServer is not correct.")
+
+        # first we keep polling till we've received the authentication token
+        poll(authenticate_and_accum_updates)
+        # now we've authenticated, don't need to check for the token anymore
+        poll(accum_updates)


 class AccumulatorServer(SocketServer.TCPServer):
+
+    def __init__(self, server_address, RequestHandlerClass, auth_token):
+        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        self.auth_token = auth_token
+
     """
     A simple TCP server that intercepts shutdown() in order to interrupt our
     continuous polling on the handler.
@@ -254,9 +283,9 @@ class AccumulatorServer(SocketServer.TCPServer):
         self.server_close()


-def _start_update_server():
+def _start_update_server(auth_token):
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()


http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b6dced5..8e209f3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -185,9 +185,10 @@ class SparkContext(object):

         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        self._accumulatorServer = accumulators._start_update_server()
+        auth_token = self._gateway.gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(auth_token)
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
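End to end, the Python pieces fit together as follows: SparkContext
reuses the Py4J gateway secret as the shared token, _start_update_server()
stores it on the AccumulatorServer, and the request handler refuses to
process updates until that token has been presented. A rough standalone
sketch of that wiring against the patched modules (the token value here
is arbitrary; a real SparkContext reads it from
self._gateway.gateway_parameters.auth_token):

    import socket

    from pyspark import accumulators

    # Any shared string works for illustration; SparkContext uses the
    # Py4J gateway's auth_token so both sides hold the same secret.
    token = "not-a-real-secret"
    server = accumulators._start_update_server(token)
    host, port = server.server_address

    sock = socket.create_connection((host, port))
    # The handler raises unless the token is the first thing it reads.
    sock.sendall(token.encode("utf-8"))
    # ... length-prefixed accumulator updates would follow, as in the
    # client sketch after the Scala diff above ...
    sock.close()
    server.shutdown()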