Repository: spark
Updated Branches:
  refs/heads/branch-2.1 a3eb07db3 -> b2e0f68f6


[PYSPARK] Updates to Accumulators

(cherry picked from commit 15fc2372269159ea2556b028d4eb8860c4108650)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b2e0f68f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b2e0f68f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b2e0f68f

Branch: refs/heads/branch-2.1
Commit: b2e0f68f615cbe2cf74f9813ece76c311fe8e911
Parents: a3eb07d
Author: LucaCanali <luca.can...@cern.ch>
Authored: Wed Jul 18 23:19:02 2018 +0200
Committer: Imran Rashid <iras...@cloudera.com>
Committed: Fri Aug 3 16:30:40 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 12 +++--
 python/pyspark/accumulators.py                  | 53 +++++++++++++++-----
 python/pyspark/context.py                       |  5 +-
 3 files changed, 53 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
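
In short: the patch threads the Py4J gateway's shared secret through the accumulator update channel. PythonAccumulatorV2 on the JVM side now sends the secret once when it opens its socket to the Python AccumulatorServer, and the server's request handler refuses to apply updates until that token has been received and verified.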


http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1190b9..de548e8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -886,8 +886,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
  */
 private[spark] class PythonAccumulatorV2(
     @transient private val serverHost: String,
-    private val serverPort: Int)
-  extends CollectionAccumulator[Array[Byte]] {
+    private val serverPort: Int,
+    private val secretToken: String)
+  extends CollectionAccumulator[Array[Byte]] with Logging {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
@@ -902,12 +903,17 @@ private[spark] class PythonAccumulatorV2(
   private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
+      logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
+      // send the secret just for the initial authentication when opening a new connection
+      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
     }
     socket
   }
 
   // Need to override so the types match with PythonFunction
-  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
+  override def copyAndReset(): PythonAccumulatorV2 = {
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+  }
 
   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
     val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
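
The handshake on the JVM side is deliberately minimal: openSocket() writes the raw UTF-8 bytes of the token exactly once, right after the connection is established, and every later write on the same socket is an ordinary accumulator update. A rough Python sketch of that client-side behaviour (the helper name is illustrative, not part of the patch):

    import socket

    def open_authenticated_socket(host, port, secret_token):
        # Connect to the AccumulatorServer, then immediately send the
        # secret token as raw UTF-8 bytes -- mirroring what the JVM's
        # PythonAccumulatorV2.openSocket() does after new Socket(...).
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((host, port))
        sock.sendall(secret_token.encode("utf-8"))
        # The server only checks the token on the first read it performs
        # for this connection; subsequent traffic is plain updates.
        return sock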

http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 6ef8cf5..bc0be07 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        while not self.server.server_shutdown:
-            # Poll every 1 second for new data -- don't block in case of shutdown.
-            r, _, _ = select.select([self.rfile], [], [], 1)
-            if self.rfile in r:
-                num_updates = read_int(self.rfile)
-                for _ in range(num_updates):
-                    (aid, update) = pickleSer._read_with_length(self.rfile)
-                    _accumulatorRegistry[aid] += update
-                # Write a byte in acknowledgement
-                self.wfile.write(struct.pack("!b", 1))
+        auth_token = self.server.auth_token
+
+        def poll(func):
+            while not self.server.server_shutdown:
+                # Poll every 1 second for new data -- don't block in case of shutdown.
+                r, _, _ = select.select([self.rfile], [], [], 1)
+                if self.rfile in r:
+                    if func():
+                        break
+
+        def accum_updates():
+            num_updates = read_int(self.rfile)
+            for _ in range(num_updates):
+                (aid, update) = pickleSer._read_with_length(self.rfile)
+                _accumulatorRegistry[aid] += update
+            # Write a byte in acknowledgement
+            self.wfile.write(struct.pack("!b", 1))
+            return False
+
+        def authenticate_and_accum_updates():
+            received_token = self.rfile.read(len(auth_token))
+            if isinstance(received_token, bytes):
+                received_token = received_token.decode("utf-8")
+            if (received_token == auth_token):
+                accum_updates()
+                # we've authenticated, we can break out of the first loop now
+                return True
+            else:
+                raise Exception(
+                    "The value of the provided token to the AccumulatorServer is not correct.")
+
+        # first we keep polling till we've received the authentication token
+        poll(authenticate_and_accum_updates)
+        # now we've authenticated, don't need to check for the token anymore
+        poll(accum_updates)
 
 
 class AccumulatorServer(SocketServer.TCPServer):
 
+    def __init__(self, server_address, RequestHandlerClass, auth_token):
+        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        self.auth_token = auth_token
+
     """
     A simple TCP server that intercepts shutdown() in order to interrupt
     our continuous polling on the handler.
@@ -254,9 +283,9 @@ class AccumulatorServer(SocketServer.TCPServer):
         self.server_close()
 
 
-def _start_update_server():
+def _start_update_server(auth_token):
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
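
The handler now works in two phases: poll() wraps the existing 1-second select() loop so the shutdown flag stays responsive, authenticate_and_accum_updates() consumes the token exactly once (also applying any updates that arrived in the same read event) and returns True to leave the first loop, and accum_updates() returns False so the second loop keeps serving updates until shutdown. A stripped-down sketch of the same two-phase pattern, with illustrative names that are not in the patch:

    import select

    def serve_two_phase(conn, auth_token, handle_updates, is_shutdown):
        # Phase 1: poll until the peer's token arrives and matches.
        # Phase 2: poll for ordinary updates until shutdown.
        def poll(step):
            while not is_shutdown():
                # 1-second timeout instead of a blocking read, so a
                # server shutdown is noticed promptly.
                readable, _, _ = select.select([conn], [], [], 1)
                if conn in readable and step():
                    break

        def authenticate():
            received = conn.recv(len(auth_token))
            if received.decode("utf-8") != auth_token:
                raise Exception("wrong token sent to AccumulatorServer")
            handle_updates()   # updates may arrive along with the token
            return True        # authenticated: leave the first loop

        def updates_only():
            handle_updates()
            return False       # keep polling until shutdown

        poll(authenticate)
        poll(updates_only)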

http://git-wip-us.apache.org/repos/asf/spark/blob/b2e0f68f/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b6dced5..8e209f3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -185,9 +185,10 @@ class SparkContext(object):
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        self._accumulatorServer = accumulators._start_update_server()
+        auth_token = self._gateway.gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(auth_token)
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
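
Nothing changes for user code: the token is taken from self._gateway.gateway_parameters.auth_token when the SparkContext is constructed and wired into both ends of the channel, so accumulators are used exactly as before. For example:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "accumulator-demo")
    counter = sc.accumulator(0)
    # Updates from the executors flow back to the driver's
    # AccumulatorServer over the now-authenticated TCP channel.
    sc.parallelize(range(100)).foreach(lambda x: counter.add(x))
    print(counter.value)  # 4950
    sc.stop()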

