Repository: spark
Updated Branches:
  refs/heads/master 5c5396cb4 -> 90d575421


[SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2

## What changes were proposed in this pull request?

Move the internals of the PySpark accumulator API off of the old, deprecated Accumulator API and onto the new AccumulatorV2 API.
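
For context, the user-facing PySpark accumulator API is unchanged by this refactor; only the JVM-side plumbing moves onto AccumulatorV2. A minimal sketch of that unchanged surface (standard PySpark API, nothing here is specific to this patch):

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "accumulator-example")

# Driver-defined accumulator; updates made inside tasks are pickled on the
# workers and shipped back to the driver after each task completes.
acc = sc.accumulator(0)
sc.parallelize(range(10)).foreach(lambda x: acc.add(x))
print(acc.value)  # 45 on the driver; tasks can only add, not read

sc.stop()
```

Those pickled updates end up at the driver's accumulator update server, which is the path the new `PythonAccumulatorV2` now feeds.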

## How was this patch tested?

This change is covered by the existing PySpark accumulator tests (both the unit tests and the doc tests at the start of accumulator.py).
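
For reference, those doc tests can also be exercised directly with the standard `doctest` module; this is only a sketch and assumes a local build with `pyspark` and `py4j` importable (the usual route is the project's `python/run-tests` script):

```python
# Sketch: run the doctests embedded in the accumulator module.
# Assumes pyspark and py4j are on PYTHONPATH from a built Spark checkout.
import doctest
import pyspark.accumulators

results = doctest.testmod(pyspark.accumulators, verbose=True)
print("failed=%d attempted=%d" % (results.failed, results.attempted))
```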

Author: Holden Karau <hol...@us.ibm.com>

Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/90d57542
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/90d57542
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/90d57542

Branch: refs/heads/master
Commit: 90d5754212425d55f992c939a2bc7d9ac6ef92b8
Parents: 5c5396c
Author: Holden Karau <hol...@us.ibm.com>
Authored: Fri Sep 23 09:44:30 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Sep 23 09:44:30 2016 +0100

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 42 +++++++++++---------
 python/pyspark/context.py                       |  5 +--
 2 files changed, 25 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/90d57542/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d841091..0ca91b9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
 
 
 private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    accumulator: PythonAccumulatorV2)
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
                 val updateLen = stream.readInt()
                 val update = new Array[Byte](updateLen)
                 stream.readFully(update)
-                accumulator += Collections.singletonList(update)
+                accumulator.add(update)
               }
               // Check whether the worker is ready to be re-used.
               if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     try {
-      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
           val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
 }
 
 /**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
  * collects a list of pickled strings that we pass to Python through a socket.
  */
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
-  extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+    @transient private val serverHost: String,
+    private val serverPort: Int)
+  extends CollectionAccumulator[Array[Byte]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
   * We try to reuse a single Socket to transfer accumulator updates, as they are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient var socket: Socket = _
+  @transient private var socket: Socket = _
 
-  def openSocket(): Socket = synchronized {
+  private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
     }
     socket
   }
 
-  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+  // Need to override so the types match with PythonFunction
+  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
 
-  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = synchronized {
+  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+    val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+    // This conditional isn't strictly speaking needed - merging currently only happens on the
+    // driver program - but that isn't guaranteed, so handle it in case that changes.
     if (serverHost == null) {
-      // This happens on the worker node, where we just want to remember all the updates
-      val1.addAll(val2)
-      val1
+      // We are on the worker
+      super.merge(otherPythonAccumulator)
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
-      out.writeInt(val2.size)
-      for (array <- val2.asScala) {
+      val values = other.value
+      out.writeInt(values.size)
+      for (array <- values.asScala) {
         out.writeInt(array.length)
         out.write(array)
       }
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
       if (byteRead == -1) {
        throw new SparkException("EOF reached before Python server acknowledged")
       }
-      null
     }
   }
 }
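
The `merge` path above writes a simple length-prefixed format to the driver-side accumulator server: a 4-byte count, then a 4-byte length plus the pickled bytes for each update, and it then blocks on a single acknowledgement byte. A conceptual Python sketch of the receiving end of that exchange (an illustration of the protocol only, not the actual `pyspark.accumulators` server code; the helper names are made up):

```python
import struct


def read_exact(sock, n):
    """Read exactly n bytes from the socket; the protocol is length-prefixed."""
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("connection closed mid-update")
        buf += chunk
    return buf


def handle_update_connection(sock):
    """Conceptual receiver matching the write side of PythonAccumulatorV2.merge:
    a 4-byte big-endian count, then (length, bytes) per pickled update, then a
    single acknowledgement byte written back so the JVM side unblocks.
    Illustration only, not the pyspark.accumulators implementation.
    """
    (num_updates,) = struct.unpack(">i", read_exact(sock, 4))
    updates = []
    for _ in range(num_updates):
        (length,) = struct.unpack(">i", read_exact(sock, 4))
        updates.append(read_exact(sock, length))  # pickled accumulator updates
    sock.sendall(struct.pack(">b", 1))  # one-byte ack; the value is arbitrary
    return updates
```

Reusing one socket per driver-side accumulator matches the comment above: updates are applied by the DAGScheduler's single-threaded RpcEndpoint, so the connection is never contended.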

http://git-wip-us.apache.org/repos/asf/spark/blob/90d57542/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 7a7f59c..a3dd195 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,9 +173,8 @@ class SparkContext(object):
         # they will be passed back to us through a TCP server
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jsc.accumulator(
-            self._jvm.java.util.ArrayList(),
-            self._jvm.PythonAccumulatorParam(host, port))
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._jsc.sc().register(self._javaAccumulator)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]

