Repository: spark Updated Branches: refs/heads/branch-2.3 e96ba8430 -> 4ee463ac9
[SPARK-26201] Fix python broadcast with encryption ## What changes were proposed in this pull request? With RPC and disk encryption enabled, creating a Python broadcast variable and then reading its value back on the driver side failed with: Traceback (most recent call last): File "broadcast.py", line 37, in <module> words_new.value File "/pyspark.zip/pyspark/broadcast.py", line 137, in value File "pyspark.zip/pyspark/broadcast.py", line 122, in load_from_path File "pyspark.zip/pyspark/broadcast.py", line 128, in load EOFError: Ran out of input To reproduce, use the configs: --conf spark.network.crypto.enabled=true --conf spark.io.encryption.enabled=true Code: words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"]) words_new.value print(words_new.value) ## How was this patch tested? words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"]) textFile = sc.textFile("README.md") wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word + words_new.value[1], 1)).reduceByKey(lambda a, b: a+b) count = wordCounts.count() print(count) words_new.value print(words_new.value) Closes #23166 from redsanket/SPARK-26201. 
Authored-by: schintap <schin...@oath.com> Signed-off-by: Thomas Graves <tgra...@apache.org> (cherry picked from commit 9b23be2e95fec756066ca0ed3188c3db2602b757) Signed-off-by: Thomas Graves <tgra...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ee463ac Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ee463ac Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ee463ac Branch: refs/heads/branch-2.3 Commit: 4ee463ac96f343afb6f61072a427cc0d30dedbfd Parents: e96ba84 Author: schintap <schin...@oath.com> Authored: Fri Nov 30 12:48:56 2018 -0600 Committer: Thomas Graves <tgra...@apache.org> Committed: Fri Nov 30 12:49:30 2018 -0600 ---------------------------------------------------------------------- .../org/apache/spark/api/python/PythonRDD.scala | 29 +++++++++++++++++--- python/pyspark/broadcast.py | 21 ++++++++++---- python/pyspark/test_broadcast.py | 15 ++++++++++ 3 files changed, 56 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4ee463ac/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 5e6bd96..edea25c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -639,6 +639,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial with Logging { private var encryptionServer: PythonServer[Unit] = null + private var decryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -687,16 +688,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial override def 
handleConnection(sock: Socket): Unit = { val env = SparkEnv.get val in = sock.getInputStream() - val dir = new File(Utils.getLocalDir(env.conf)) - val file = File.createTempFile("broadcast", "", dir) - path = file.getAbsolutePath - val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + val abspath = new File(path).getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath)) DechunkedInputStream.dechunkAndCopyToOutput(in, out) } } Array(encryptionServer.port, encryptionServer.secret) } + def setupDecryptionServer(): Array[Any] = { + decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") { + override def handleConnection(sock: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream())) + Utils.tryWithSafeFinally { + val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + out.flush() + } { + JavaUtils.closeQuietly(out) + } + } + } + Array(decryptionServer.port, decryptionServer.secret) + } + + def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult() + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize http://git-wip-us.apache.org/repos/asf/spark/blob/4ee463ac/python/pyspark/broadcast.py ---------------------------------------------------------------------- diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 3f1298e..b526674 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -77,11 +77,12 @@ class Broadcast(object): # we're on the driver. 
We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + self._sc = sc + self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket - port, auth_secret = python_broadcast.setupEncryptionServer() + port, auth_secret = self._python_broadcast.setupEncryptionServer() (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) broadcast_out = ChunkedStream(encryption_sock_file, 8192) else: @@ -89,12 +90,14 @@ class Broadcast(object): broadcast_out = f self.dump(value, broadcast_out) if sc._encryption_enabled: - python_broadcast.waitTillDataReceived() - self._jbroadcast = sc._jsc.broadcast(python_broadcast) + self._python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) self._pickle_registry = pickle_registry else: # we're on an executor self._jbroadcast = None + self._sc = None + self._python_broadcast = None if sock_file is not None: # the jvm is doing decryption for us. 
Read the value # immediately from the sock_file @@ -134,7 +137,15 @@ class Broadcast(object): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load_from_path(self._path) + # we only need to decrypt it here when encryption is enabled and + # if its on the driver, since executor decryption is handled already + if self._sc is not None and self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): http://git-wip-us.apache.org/repos/asf/spark/blob/4ee463ac/python/pyspark/test_broadcast.py ---------------------------------------------------------------------- diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py index ce7ca83..1630d40 100644 --- a/python/pyspark/test_broadcast.py +++ b/python/pyspark/test_broadcast.py @@ -75,6 +75,21 @@ class BroadcastTest(unittest.TestCase): def test_broadcast_no_encryption(self): self._test_multiple_broadcasts() + def _test_broadcast_on_driver(self, *extra_confs): + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + bs = self.sc.broadcast(value=5) + self.assertEqual(5, bs.value) + + def test_broadcast_value_driver_no_encryption(self): + self._test_broadcast_on_driver() + + def test_broadcast_value_driver_encryption(self): + self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + class BroadcastFrameProtocolTest(unittest.TestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org