Repository: spark
Updated Branches:
  refs/heads/master c3f27b243 -> 9b23be2e9


[SPARK-26201] Fix python broadcast with encryption

## What changes were proposed in this pull request?
With RPC and disk encryption enabled (spark.network.crypto.enabled and
spark.io.encryption.enabled), creating a Python broadcast variable and then
reading its value back on the driver side caused the job to fail with:

Traceback (most recent call last):
  File "broadcast.py", line 37, in <module>
    words_new.value
  File "/pyspark.zip/pyspark/broadcast.py", line 137, in value
  File "pyspark.zip/pyspark/broadcast.py", line 122, in load_from_path
  File "pyspark.zip/pyspark/broadcast.py", line 128, in load
EOFError: Ran out of input

To reproduce use configs: --conf spark.network.crypto.enabled=true --conf 
spark.io.encryption.enabled=true

Code:

words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"])
words_new.value
print(words_new.value)

## How was this patch tested?
words_new = sc.broadcast(["scala", "java", "hadoop", "spark",
"akka"])
textFile = sc.textFile("README.md")
wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word 
+ words_new.value[1], 1)).reduceByKey(lambda a, b: a+b)
count = wordCounts.count()
print(count)
words_new.value
print(words_new.value)

Closes #23166 from redsanket/SPARK-26201.

Authored-by: schintap <schin...@oath.com>
Signed-off-by: Thomas Graves <tgra...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b23be2e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b23be2e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b23be2e

Branch: refs/heads/master
Commit: 9b23be2e95fec756066ca0ed3188c3db2602b757
Parents: c3f27b2
Author: schintap <schin...@oath.com>
Authored: Fri Nov 30 12:48:56 2018 -0600
Committer: Thomas Graves <tgra...@apache.org>
Committed: Fri Nov 30 12:48:56 2018 -0600

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 29 +++++++++++++++++---
 python/pyspark/broadcast.py                     | 21 ++++++++++----
 python/pyspark/tests/test_broadcast.py          | 15 ++++++++++
 3 files changed, 56 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b23be2e/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 8b5a7a9..5ed5070 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: 
String) extends Serial
     with Logging {
 
   private var encryptionServer: PythonServer[Unit] = null
+  private var decryptionServer: PythonServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: 
String) extends Serial
       override def handleConnection(sock: Socket): Unit = {
         val env = SparkEnv.get
         val in = sock.getInputStream()
-        val dir = new File(Utils.getLocalDir(env.conf))
-        val file = File.createTempFile("broadcast", "", dir)
-        path = file.getAbsolutePath
-        val out = env.serializerManager.wrapForEncryption(new 
FileOutputStream(path))
+        val abspath = new File(path).getAbsolutePath
+        val out = env.serializerManager.wrapForEncryption(new 
FileOutputStream(abspath))
         DechunkedInputStream.dechunkAndCopyToOutput(in, out)
       }
     }
     Array(encryptionServer.port, encryptionServer.secret)
   }
 
+  def setupDecryptionServer(): Array[Any] = {
+    decryptionServer = new 
PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
+      override def handleConnection(sock: Socket): Unit = {
+        val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream()))
+        Utils.tryWithSafeFinally {
+          val in = SparkEnv.get.serializerManager.wrapForEncryption(new 
FileInputStream(path))
+          Utils.tryWithSafeFinally {
+            Utils.copyStream(in, out, false)
+          } {
+            in.close()
+          }
+          out.flush()
+        } {
+          JavaUtils.closeQuietly(out)
+        }
+      }
+    }
+    Array(decryptionServer.port, decryptionServer.secret)
+  }
+
+  def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
+
   def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
 // scalastyle:on no.finalize

http://git-wip-us.apache.org/repos/asf/spark/blob/9b23be2e/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 1c7f2a7..29358b5 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -77,11 +77,12 @@ class Broadcast(object):
             # we're on the driver.  We want the pickled data to end up in a 
file (maybe encrypted)
             f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
             self._path = f.name
-            python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+            self._sc = sc
+            self._python_broadcast = 
sc._jvm.PythonRDD.setupBroadcast(self._path)
             if sc._encryption_enabled:
                 # with encryption, we ask the jvm to do the encryption for us, 
we send it data
                 # over a socket
-                port, auth_secret = python_broadcast.setupEncryptionServer()
+                port, auth_secret = 
self._python_broadcast.setupEncryptionServer()
                 (encryption_sock_file, _) = local_connect_and_auth(port, 
auth_secret)
                 broadcast_out = ChunkedStream(encryption_sock_file, 8192)
             else:
@@ -89,12 +90,14 @@ class Broadcast(object):
                 broadcast_out = f
             self.dump(value, broadcast_out)
             if sc._encryption_enabled:
-                python_broadcast.waitTillDataReceived()
-            self._jbroadcast = sc._jsc.broadcast(python_broadcast)
+                self._python_broadcast.waitTillDataReceived()
+            self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
             self._pickle_registry = pickle_registry
         else:
             # we're on an executor
             self._jbroadcast = None
+            self._sc = None
+            self._python_broadcast = None
             if sock_file is not None:
                 # the jvm is doing decryption for us.  Read the value
                 # immediately from the sock_file
@@ -134,7 +137,15 @@ class Broadcast(object):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self._path is not None:
-            self._value = self.load_from_path(self._path)
+            # we only need to decrypt it here when encryption is enabled and
+            # if its on the driver, since executor decryption is handled 
already
+            if self._sc is not None and self._sc._encryption_enabled:
+                port, auth_secret = 
self._python_broadcast.setupDecryptionServer()
+                (decrypted_sock_file, _) = local_connect_and_auth(port, 
auth_secret)
+                self._python_broadcast.waitTillBroadcastDataSent()
+                return self.load(decrypted_sock_file)
+            else:
+                self._value = self.load_from_path(self._path)
         return self._value
 
     def unpersist(self, blocking=False):

http://git-wip-us.apache.org/repos/asf/spark/blob/9b23be2e/python/pyspark/tests/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_broadcast.py 
b/python/pyspark/tests/test_broadcast.py
index a98626e..11d31d2 100644
--- a/python/pyspark/tests/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -67,6 +67,21 @@ class BroadcastTest(unittest.TestCase):
     def test_broadcast_no_encryption(self):
         self._test_multiple_broadcasts()
 
+    def _test_broadcast_on_driver(self, *extra_confs):
+        conf = SparkConf()
+        for key, value in extra_confs:
+            conf.set(key, value)
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        bs = self.sc.broadcast(value=5)
+        self.assertEqual(5, bs.value)
+
+    def test_broadcast_value_driver_no_encryption(self):
+        self._test_broadcast_on_driver()
+
+    def test_broadcast_value_driver_encryption(self):
+        self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))
+
 
 class BroadcastFrameProtocolTest(unittest.TestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to