[PYSPARK] Updates to pyspark broadcast
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/09dd34cb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/09dd34cb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/09dd34cb Branch: refs/heads/branch-2.3 Commit: 09dd34cb1706f2477a89174d6a1a0f17ed5b0a65 Parents: a2a54a5 Author: Imran Rashid <iras...@cloudera.com> Authored: Mon Aug 13 21:35:34 2018 -0500 Committer: Imran Rashid <iras...@cloudera.com> Committed: Thu Sep 13 09:19:56 2018 -0500 ---------------------------------------------------------------------- .../org/apache/spark/api/python/PythonRDD.scala | 297 ++++++++++++++++--- .../apache/spark/api/python/PythonRunner.scala | 52 +++- .../spark/api/python/PythonRDDSuite.scala | 23 +- dev/sparktestsupport/modules.py | 2 + python/pyspark/broadcast.py | 58 +++- python/pyspark/context.py | 49 ++- python/pyspark/serializers.py | 58 ++++ python/pyspark/test_broadcast.py | 126 ++++++++ python/pyspark/test_serializers.py | 90 ++++++ python/pyspark/worker.py | 24 +- 10 files changed, 695 insertions(+), 84 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8bc0ff7..5e6bd96 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.Promise +import scala.concurrent.duration.Duration import scala.language.existentials -import scala.util.control.NonFatal +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -168,27 +171,34 @@ private[spark] object PythonRDD extends Logging { def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) + readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism) + } + + def readRDDFromInputStream( + sc: SparkContext, + in: InputStream, + parallelism: Int): JavaRDD[Array[Byte]] = { + val din = new DataInputStream(in) try { val objs = new mutable.ArrayBuffer[Array[Byte]] try { while (true) { - val length = file.readInt() + val length = din.readInt() val obj = new Array[Byte](length) - file.readFully(obj) + din.readFully(obj) objs += obj } } catch { case eof: EOFException => // No-op } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + JavaRDD.fromRDD(sc.parallelize(objs, parallelism)) } finally { - file.close() + din.close() } } - def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = { - sc.broadcast(new PythonBroadcast(path)) + def 
setupBroadcast(path: String): PythonBroadcast = { + new PythonBroadcast(path) } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { @@ -398,34 +408,15 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - // Close the socket if no connection in 15 seconds - serverSocket.setSoTimeout(15000) - - new Thread(threadName) { - setDaemon(true) - override def run() { - try { - val sock = serverSocket.accept() - authHelper.authClient(sock) - - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) - } { - out.close() - sock.close() - } - } catch { - case NonFatal(e) => - logError(s"Error while sending iterator", e) - } finally { - serverSocket.close() - } + val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s => + val out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream())) + Utils.tryWithSafeFinally { + writeIteratorToStream(items, out) + } { + out.close() } - }.start() - - Array(serverSocket.getLocalPort, authHelper.secret) + } + Array(port, secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], @@ -643,13 +634,11 @@ private[spark] class PythonAccumulatorV2( } } -/** - * A Wrapper for Python Broadcast, which is written into disk by Python. It also will - * write the data into disk after deserialization, then Python can read it from disks. - */ // scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable - with Logging { + with Logging { + + private var encryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -692,5 +681,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } super.finalize() } + + def setupEncryptionServer(): Array[Any] = { + encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") { + override def handleConnection(sock: Socket): Unit = { + val env = SparkEnv.get + val in = sock.getInputStream() + val dir = new File(Utils.getLocalDir(env.conf)) + val file = File.createTempFile("broadcast", "", dir) + path = file.getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + DechunkedInputStream.dechunkAndCopyToOutput(in, out) + } + } + Array(encryptionServer.port, encryptionServer.secret) + } + + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize + +/** + * The inverse of pyspark's ChunkedStream for sending broadcast data. + * Tested from python tests. 
+ */ +private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging { + private val din = new DataInputStream(wrapped) + private var remainingInChunk = din.readInt() + + override def read(): Int = { + val into = new Array[Byte](1) + val n = read(into, 0, 1) + if (n == -1) { + -1 + } else { + // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted + // as an EOF + into(0) & 0xFF + } + } + + override def read(dest: Array[Byte], off: Int, len: Int): Int = { + if (remainingInChunk == -1) { + return -1 + } + var destSpace = len + var destPos = off + while (destSpace > 0 && remainingInChunk != -1) { + val toCopy = math.min(remainingInChunk, destSpace) + val read = din.read(dest, destPos, toCopy) + destPos += read + destSpace -= read + remainingInChunk -= read + if (remainingInChunk == 0) { + remainingInChunk = din.readInt() + } + } + assert(destSpace == 0 || remainingInChunk == -1) + return destPos - off + } + + override def close(): Unit = wrapped.close() +} + +/** + * The inverse of pyspark's ChunkedStream for sending data of unknown size. + * + * We might be serializing a really large object from python -- we don't want + * python to buffer the whole thing in memory, nor can it write to a file, + * so we don't know the length in advance. So python writes it in chunks, each chunk + * preceeded by a length, till we get a "length" of -1 which serves as EOF. + * + * Tested from python tests. + */ +private[spark] object DechunkedInputStream { + + /** + * Dechunks the input, copies to output, and closes both input and the output safely. + */ + def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = { + val dechunked = new DechunkedInputStream(chunked) + Utils.tryWithSafeFinally { + Utils.copyStream(dechunked, out) + } { + JavaUtils.closeQuietly(out) + JavaUtils.closeQuietly(dechunked) + } + } +} + +/** + * Creates a server in the jvm to communicate with python for handling one batch of data, with + * authentication and error handling. + */ +private[spark] abstract class PythonServer[T]( + authHelper: SocketAuthHelper, + threadName: String) { + + def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName) + def this(threadName: String) = this(SparkEnv.get, threadName) + + val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock => + promise.complete(Try(handleConnection(sock))) + } + + /** + * Handle a connection which has already been authenticated. Any error from this function + * will clean up this connection and the entire server, and get propogated to [[getResult]]. + */ + def handleConnection(sock: Socket): T + + val promise = Promise[T]() + + /** + * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If + * handleConnection throws an exception, this will throw an exception which includes the original + * exception as a cause. + */ + def getResult(): T = { + getResult(Duration.Inf) + } + + def getResult(wait: Duration): T = { + ThreadUtils.awaitResult(promise.future, wait) + } + +} + +private[spark] object PythonServer { + + /** + * Create a socket server and run user function on the socket in a background thread. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * The thread will terminate after the supplied user function, or if there are any exceptions. 
+ * + * If you need to get a result of the supplied function, create a subclass of [[PythonServer]] + * + * @return The port number of a local socket and the secret for authentication. + */ + def setupOneConnectionServer( + authHelper: SocketAuthHelper, + threadName: String) + (func: Socket => Unit): (Int, String) = { + val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) + // Close the socket if no connection in 15 seconds + serverSocket.setSoTimeout(15000) + + new Thread(threadName) { + setDaemon(true) + override def run(): Unit = { + var sock: Socket = null + try { + sock = serverSocket.accept() + authHelper.authClient(sock) + func(sock) + } finally { + JavaUtils.closeQuietly(serverSocket) + JavaUtils.closeQuietly(sock) + } + } + }.start() + (serverSocket.getLocalPort, authHelper.secret) + } +} + +/** + * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol. + */ +private[spark] class EncryptedPythonBroadcastServer( + val env: SparkEnv, + val idsAndFiles: Seq[(Long, String)]) + extends PythonServer[Unit]("broadcast-decrypt-server") with Logging { + + override def handleConnection(socket: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream())) + var socketIn: InputStream = null + // send the broadcast id, then the decrypted data. We don't need to send the length, the + // the python pickle module just needs a stream. + Utils.tryWithSafeFinally { + (idsAndFiles).foreach { case (id, path) => + out.writeLong(id) + val in = env.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + } + logTrace("waiting for python to accept broadcast data over socket") + out.flush() + socketIn = socket.getInputStream() + socketIn.read() + logTrace("done serving broadcast data") + } { + JavaUtils.closeQuietly(socketIn) + JavaUtils.closeQuietly(out) + } + } + + def waitTillBroadcastDataSent(): Unit = { + getResult() + } +} + +/** + * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python + * over a socket. This is used in preference to writing data to a file when encryption is enabled. 
+ */ +private[spark] abstract class PythonRDDServer + extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") { + + def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = { + val in = sock.getInputStream() + val dechunkedInput: InputStream = new DechunkedInputStream(in) + streamToRDD(dechunkedInput) + } + + protected def streamToRDD(input: InputStream): RDD[Array[Byte]] + +} + +private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int) + extends PythonRDDServer { + + override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = { + PythonRDD.readRDDFromInputStream(sc, input, parallelism) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 719ce5b..754a654 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -193,19 +193,51 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size + val addedBids = newBids.diff(oldBids) + val cnt = toRemove.size + addedBids.size + val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty + dataOut.writeBoolean(needsDecryptionServer) dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) + def sendBidsToRemove(): Unit = { + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(-bid - 1) // bid >= 0 + oldBids.remove(bid) + } } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { + if (needsDecryptionServer) { + // if there is encryption, we setup a server which reads the encrypted files, and sends + // the decrypted data to python + val idsAndFiles = broadcastVars.flatMap { broadcast => + if (oldBids.contains(broadcast.id)) { + None + } else { + Some((broadcast.id, broadcast.value.path)) + } + } + val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) + dataOut.writeInt(server.port) + logTrace(s"broadcast decryption server setup on ${server.port}") + PythonRDD.writeUTF(server.secret, dataOut) + sendBidsToRemove() + idsAndFiles.foreach { case (id, _) => // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) + dataOut.writeLong(id) + oldBids.add(id) + } + dataOut.flush() + logTrace("waiting for python to read decrypted broadcast data from server") + server.waitTillBroadcastDataSent() + logTrace("done sending decrypted data to python") + } else { + sendBidsToRemove() + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } } } dataOut.flush() http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala ---------------------------------------------------------------------- diff --git 
a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 05b4e67..6f9b583 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.net.{InetAddress, Socket} import java.nio.charset.StandardCharsets -import org.apache.spark.SparkFunSuite +import scala.concurrent.duration.Duration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.security.SocketAuthHelper class PythonRDDSuite extends SparkFunSuite { @@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite { ("a".getBytes(StandardCharsets.UTF_8), null), (null, "b".getBytes(StandardCharsets.UTF_8))), buffer) } + + test("python server error handling") { + val authHelper = new SocketAuthHelper(new SparkConf()) + val errorServer = new ExceptionPythonServer(authHelper) + val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port) + authHelper.authToServer(client) + val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) } + assert(ex.getCause().getMessage().contains("exception within handleConnection")) + } + + class ExceptionPythonServer(authHelper: SocketAuthHelper) + extends PythonServer[Unit](authHelper, "error-server") { + + override def handleConnection(sock: Socket): Unit = { + throw new Exception("exception within handleConnection") + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/dev/sparktestsupport/modules.py ---------------------------------------------------------------------- diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b900f0b..d0bff13 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -377,6 +377,8 @@ pyspark_core = Module( "pyspark.profiler", "pyspark.shuffle", "pyspark.tests", + "pyspark.test_broadcast", + "pyspark.test_serializers", "pyspark.util", ] ) http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/broadcast.py ---------------------------------------------------------------------- diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515..3f1298e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -15,13 +15,16 @@ # limitations under the License. # +import gc import os +import socket import sys -import gc from tempfile import NamedTemporaryFile import threading from pyspark.cloudpickle import print_exec +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ChunkedStream from pyspark.util import _exception_message if sys.version < '3': @@ -64,19 +67,43 @@ class Broadcast(object): >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, sc=None, value=None, pickle_registry=None, path=None): + def __init__(self, sc=None, value=None, pickle_registry=None, path=None, + sock_file=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ if sc is not None: + # we're on the driver. 
We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) - self._path = self.dump(value, f) - self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path) + self._path = f.name + python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + if sc._encryption_enabled: + # with encryption, we ask the jvm to do the encryption for us, we send it data + # over a socket + port, auth_secret = python_broadcast.setupEncryptionServer() + (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) + broadcast_out = ChunkedStream(encryption_sock_file, 8192) + else: + # no encryption, we can just write pickled data directly to the file from python + broadcast_out = f + self.dump(value, broadcast_out) + if sc._encryption_enabled: + python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(python_broadcast) self._pickle_registry = pickle_registry else: + # we're on an executor self._jbroadcast = None - self._path = path + if sock_file is not None: + # the jvm is doing decryption for us. Read the value + # immediately from the sock_file + self._value = self.load(sock_file) + else: + # the jvm just dumps the pickled data in path -- we'll unpickle lazily when + # the value is requested + assert(path is not None) + self._path = path def dump(self, value, f): try: @@ -89,24 +116,25 @@ class Broadcast(object): print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() - return f.name - def load(self, path): + def load_from_path(self, path): with open(path, 'rb', 1 << 20) as f: - # pickle.load() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return pickle.load(f) - finally: - gc.enable() + return self.load(f) + + def load(self, file): + # "file" could also be a socket + gc.disable() + try: + return pickle.load(file) + finally: + gc.enable() @property def value(self): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load(self._path) + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 667d0a3..2aac7ba 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast, BroadcastPickleRegistry from pyspark.conf import SparkConf from pyspark.files import SparkFiles -from pyspark.java_gateway import launch_gateway +from pyspark.java_gateway import launch_gateway, local_connect_and_auth from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, AutoBatchedSerializer, NoOpSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call @@ -189,6 +189,13 @@ class SparkContext(object): self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) + # If encryption is enabled, we need to setup a server in the jvm to read broadcast + # data via a socket. 
+ # scala's mangled names w/ $ in them require special treatment. + encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\ + .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED() + self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] @@ -499,19 +506,31 @@ class SparkContext(object): def _serialize_to_jvm(self, data, parallelism, serializer): """ - Calling the Java parallelize() method with an ArrayList is too slow, - because it sends O(n) Py4J commands. As an alternative, serialized - objects are written to a file and loaded through textFile(). - """ - tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - try: - serializer.dump_stream(data, tempFile) - tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) - finally: - # readRDDFromFile eagerily reads the file so we can delete right after. - os.unlink(tempFile.name) + Using py4j to send a large dataset to the jvm is really slow, so we use either a file + or a socket if we have encryption enabled. + """ + if self._encryption_enabled: + # with encryption, we open a server in java and send the data directly + server = self._jvm.PythonParallelizeServer(self._jsc.sc(), parallelism) + (sock_file, _) = local_connect_and_auth(server.port(), server.secret()) + chunked_out = ChunkedStream(sock_file, 8192) + serializer.dump_stream(data, chunked_out) + chunked_out.close() + # this call will block until the server has read all the data and processed it (or + # throws an exception) + return server.getResult() + else: + # without encryption, we serialize to a file, and we read the file in java and + # parallelize from there. + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + try: + serializer.dump_stream(data, tempFile) + tempFile.close() + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + return readRDDFromFile(self._jsc, tempFile.name, parallelism) + finally: + # we eagerly read the file so we can delete right after. + os.unlink(tempFile.name) def pickleFile(self, name, minPartitions=None): """ http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/serializers.py ---------------------------------------------------------------------- diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 52a7afe..3927d21 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -697,11 +697,69 @@ def write_int(value, stream): stream.write(struct.pack("!i", value)) +def read_bool(stream): + length = stream.read(1) + if not length: + raise EOFError + return struct.unpack("!?", length)[0] + + def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) +class ChunkedStream(object): + + """ + This file-like object takes a stream of data, of unknown length, and breaks it into fixed + length frames. The intended use case is serializing large data and sending it immediately over + a socket -- we do not want to buffer the entire data before sending it, but the receiving end + needs to know whether or not there is more data coming. + + It works by buffering the incoming data in some fixed-size chunks. If the buffer is full, it + first sends the buffer size, then the data. This repeats as long as there is more data to send. 
+ When this is closed, it sends the length of whatever data is in the buffer, then that data, and + finally a "length" of -1 to indicate the stream has completed. + """ + + def __init__(self, wrapped, buffer_size): + self.buffer_size = buffer_size + self.buffer = bytearray(buffer_size) + self.current_pos = 0 + self.wrapped = wrapped + + def write(self, bytes): + byte_pos = 0 + byte_remaining = len(bytes) + while byte_remaining > 0: + new_pos = byte_remaining + self.current_pos + if new_pos < self.buffer_size: + # just put it in our buffer + self.buffer[self.current_pos:new_pos] = bytes[byte_pos:] + self.current_pos = new_pos + byte_remaining = 0 + else: + # fill the buffer, send the length then the contents, and start filling again + space_left = self.buffer_size - self.current_pos + new_byte_pos = byte_pos + space_left + self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos] + write_int(self.buffer_size, self.wrapped) + self.wrapped.write(self.buffer) + byte_remaining -= space_left + byte_pos = new_byte_pos + self.current_pos = 0 + + def close(self): + # if there is anything left in the buffer, write it out first + if self.current_pos > 0: + write_int(self.current_pos, self.wrapped) + self.wrapped.write(self.buffer[:self.current_pos]) + # -1 length indicates to the receiving end that we're done. + write_int(-1, self.wrapped) + self.wrapped.close() + + if __name__ == '__main__': import doctest (failure_count, test_count) = doctest.testmod() http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/test_broadcast.py ---------------------------------------------------------------------- diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py new file mode 100644 index 0000000..ce7ca83 --- /dev/null +++ b/python/pyspark/test_broadcast.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import random +import tempfile +import unittest + +try: + import xmlrunner +except ImportError: + xmlrunner = None + +from pyspark.broadcast import Broadcast +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import ChunkedStream + + +class BroadcastTest(unittest.TestCase): + + def tearDown(self): + if getattr(self, "sc", None) is not None: + self.sc.stop() + self.sc = None + + def _test_encryption_helper(self, vs): + """ + Creates a broadcast variables for each value in vs, and runs a simple job to make sure the + value is the same when it's read in the executors. Also makes sure there are no task + failures. 
+ """ + bs = [self.sc.broadcast(value=v) for v in vs] + exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect() + for ev in exec_values: + self.assertEqual(ev, vs) + # make sure there are no task failures + status = self.sc.statusTracker() + for jid in status.getJobIdsForGroup(): + for sid in status.getJobInfo(jid).stageIds: + stage_info = status.getStageInfo(sid) + self.assertEqual(0, stage_info.numFailedTasks) + + def _test_multiple_broadcasts(self, *extra_confs): + """ + Test broadcast variables make it OK to the executors. Tests multiple broadcast variables, + and also multiple jobs. + """ + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + self._test_encryption_helper([5]) + self._test_encryption_helper([5, 10, 20]) + + def test_broadcast_with_encryption(self): + self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true")) + + def test_broadcast_no_encryption(self): + self._test_multiple_broadcasts() + + +class BroadcastFrameProtocolTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + gateway = launch_gateway(SparkConf()) + cls._jvm = gateway.jvm + cls.longMessage = True + random.seed(42) + + def _test_chunked_stream(self, data, py_buf_size): + # write data using the chunked protocol from python. + chunked_file = tempfile.NamedTemporaryFile(delete=False) + dechunked_file = tempfile.NamedTemporaryFile(delete=False) + dechunked_file.close() + try: + out = ChunkedStream(chunked_file, py_buf_size) + out.write(data) + out.close() + # now try to read it in java + jin = self._jvm.java.io.FileInputStream(chunked_file.name) + jout = self._jvm.java.io.FileOutputStream(dechunked_file.name) + self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout) + # java should have decoded it back to the original data + self.assertEqual(len(data), os.stat(dechunked_file.name).st_size) + with open(dechunked_file.name, "rb") as f: + byte = f.read(1) + idx = 0 + while byte: + self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx)) + byte = f.read(1) + idx += 1 + finally: + os.unlink(chunked_file.name) + os.unlink(dechunked_file.name) + + def test_chunked_stream(self): + def random_bytes(n): + return bytearray(random.getrandbits(8) for _ in range(n)) + for data_length in [1, 10, 100, 10000]: + for buffer_length in [1, 2, 5, 8192]: + self._test_chunked_stream(random_bytes(data_length), buffer_length) + +if __name__ == '__main__': + from pyspark.test_broadcast import * + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/test_serializers.py ---------------------------------------------------------------------- diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py new file mode 100644 index 0000000..5064e9f --- /dev/null +++ b/python/pyspark/test_serializers.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import io +import math +import struct +import sys +import unittest + +try: + import xmlrunner +except ImportError: + xmlrunner = None + +from pyspark import serializers + + +def read_int(b): + return struct.unpack("!i", b)[0] + + +def write_int(i): + return struct.pack("!i", i) + + +class SerializersTest(unittest.TestCase): + + def test_chunked_stream(self): + original_bytes = bytearray(range(100)) + for data_length in [1, 10, 100]: + for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: + dest = ByteArrayOutput() + stream_out = serializers.ChunkedStream(dest, buffer_length) + stream_out.write(original_bytes[:data_length]) + stream_out.close() + num_chunks = int(math.ceil(float(data_length) / buffer_length)) + # length for each chunk, and a final -1 at the very end + exp_size = (num_chunks + 1) * 4 + data_length + self.assertEqual(len(dest.buffer), exp_size) + dest_pos = 0 + data_pos = 0 + for chunk_idx in range(num_chunks): + chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) + if chunk_idx == num_chunks - 1: + exp_length = data_length % buffer_length + if exp_length == 0: + exp_length = buffer_length + else: + exp_length = buffer_length + self.assertEqual(chunk_length, exp_length) + dest_pos += 4 + dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] + orig_chunk = original_bytes[data_pos:data_pos + chunk_length] + self.assertEqual(dest_chunk, orig_chunk) + dest_pos += chunk_length + data_pos += chunk_length + # ends with a -1 + self.assertEqual(dest.buffer[-4:], write_int(-1)) + + +class ByteArrayOutput(object): + def __init__(self): + self.buffer = bytearray() + + def write(self, b): + self.buffer += b + + def close(self): + pass + +if __name__ == '__main__': + from pyspark.test_serializers import * + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 812f4b2..942a7f3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType -from pyspark.serializers import write_with_length, write_int, read_long, \ +from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type @@ -206,16 +206,34 @@ def main(infile, outfile): importlib.invalidate_caches() # fetch names and values of broadcast variables + needs_broadcast_decryption_server = read_bool(infile) num_broadcast_variables = read_int(infile) + if needs_broadcast_decryption_server: + # read the decrypted data from a server in the jvm + port = read_int(infile) + auth_secret = utf8_deserializer.loads(infile) + (broadcast_sock_file, _) = 
local_connect_and_auth(port, auth_secret) + for _ in range(num_broadcast_variables): bid = read_long(infile) if bid >= 0: + if needs_broadcast_decryption_server: + read_bid = read_long(broadcast_sock_file) + assert(read_bid == bid) + _broadcastRegistry[bid] = \ + Broadcast(sock_file=broadcast_sock_file) + else: + path = utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) + else: bid = - bid - 1 _broadcastRegistry.pop(bid) + if needs_broadcast_decryption_server: + broadcast_sock_file.write(b'1') + broadcast_sock_file.close() + _accumulatorRegistry.clear() eval_type = read_int(infile) if eval_type == PythonEvalType.NON_UDF:
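----------------------------------------------------------------------

The wire format shared by ChunkedStream (python/pyspark/serializers.py) and DechunkedInputStream (PythonRDD.scala) above is a sequence of frames, each a 4-byte big-endian length followed by that many bytes of payload, terminated by a "length" of -1. The sketch below round-trips that framing in plain Python using in-memory streams; write_chunked and read_dechunked are illustrative helpers for this note only, not functions added by the patch.

    import io
    import struct

    def write_chunked(data, out, buffer_size=8192):
        # Each frame: a big-endian int length, then that many bytes
        # (mirrors ChunkedStream's framing).
        for start in range(0, len(data), buffer_size):
            chunk = data[start:start + buffer_size]
            out.write(struct.pack("!i", len(chunk)))
            out.write(chunk)
        # A "length" of -1 tells the receiver the stream is complete.
        out.write(struct.pack("!i", -1))

    def read_dechunked(inp):
        # Read frames until the -1 terminator (mirrors DechunkedInputStream).
        result = bytearray()
        while True:
            (length,) = struct.unpack("!i", inp.read(4))
            if length == -1:
                return bytes(result)
            result.extend(inp.read(length))

    # Round trip: 25000 bytes become three full 8192-byte frames, one short
    # final frame, and the -1 terminator.
    payload = b"spark" * 5000
    buf = io.BytesIO()
    write_chunked(payload, buf)
    buf.seek(0)
    assert read_dechunked(buf) == payload

Length-prefixed frames let python stream pickled data of unknown total size without buffering the whole object in memory or spilling it to a file, while the JVM side can tell when the payload is complete without waiting for the socket to close.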