Repository: spark Updated Branches: refs/heads/master 5c5396cb4 -> 90d575421
[SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2 ## What changes were proposed in this pull request? Move the internals of the PySpark accumulator API from the old deprecated API on top of the new accumulator API. ## How was this patch tested? The existing PySpark accumulator tests (both unit tests and doc tests at the start of accumulator.py). Author: Holden Karau <hol...@us.ibm.com> Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/90d57542 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/90d57542 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/90d57542 Branch: refs/heads/master Commit: 90d5754212425d55f992c939a2bc7d9ac6ef92b8 Parents: 5c5396c Author: Holden Karau <hol...@us.ibm.com> Authored: Fri Sep 23 09:44:30 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Fri Sep 23 09:44:30 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/api/python/PythonRDD.scala | 42 +++++++++++--------- python/pyspark/context.py | 5 +-- 2 files changed, 25 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/90d57542/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d841091..0ca91b9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets -import java.util.{ArrayList => JArrayList, 
Collections, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util._ private[spark] class PythonRDD( @@ -75,7 +75,7 @@ private[spark] case class PythonFunction( pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]]) + accumulator: PythonAccumulatorV2) /** * A wrapper for chained Python functions (from bottom to top). @@ -200,7 +200,7 @@ private[spark] class PythonRunner( val updateLen = stream.readInt() val update = new Array[Byte](updateLen) stream.readFully(update) - accumulator += Collections.singletonList(update) + accumulator.add(update) } // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { @@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging { JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) try { - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + val objs = new mutable.ArrayBuffer[Array[Byte]] try { while (true) { val length = file.readInt() @@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By } /** - * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. 
*/ -private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) - extends AccumulatorParam[JList[Array[Byte]]] { +private[spark] class PythonAccumulatorV2( + @transient private val serverHost: String, + private val serverPort: Int) + extends CollectionAccumulator[Array[Byte]] { Utils.checkHost(serverHost, "Expected hostname") @@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String, * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded RpcEndpoint anyway. */ - @transient var socket: Socket = _ + @transient private var socket: Socket = _ - def openSocket(): Socket = synchronized { + private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) } socket } - override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + // Need to override so the types match with PythonFunction + override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) - override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = synchronized { + override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { + val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] + // This conditional isn't strictly speaking needed - merging only currently happens on the + // driver program - but that isn't guaranteed, so we keep it in case this changes. 
if (serverHost == null) { - // This happens on the worker node, where we just want to remember all the updates - val1.addAll(val2) - val1 + // We are on the worker + super.merge(otherPythonAccumulator) } else { // This happens on the master, where we pass the updates to Python through a socket val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) - out.writeInt(val2.size) - for (array <- val2.asScala) { + val values = other.value + out.writeInt(values.size) + for (array <- values.asScala) { out.writeInt(array.length) out.write(array) } @@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String, if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - null } } } http://git-wip-us.apache.org/repos/asf/spark/blob/90d57542/python/pyspark/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7a7f59c..a3dd195 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,9 +173,8 @@ class SparkContext(object): # they will be passed back to us through a TCP server self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jsc.accumulator( - self._jvm.java.util.ArrayList(), - self._jvm.PythonAccumulatorParam(host, port)) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org