[PYSPARK][SQL] Updates to RowQueue. Tested with updates to RowQueueSuite.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d742d1b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d742d1b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d742d1b Branch: refs/heads/branch-2.3 Commit: 6d742d1bd71aa3803dce91a830b37284cb18cf70 Parents: 09dd34c Author: Imran Rashid <iras...@cloudera.com> Authored: Thu Sep 6 12:11:47 2018 -0500 Committer: Imran Rashid <iras...@cloudera.com> Committed: Thu Sep 13 09:19:56 2018 -0500 ---------------------------------------------------------------------- .../spark/sql/execution/python/RowQueue.scala | 27 ++++++++++++++----- .../sql/execution/python/RowQueueSuite.scala | 28 +++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index e2fa6e7..d2820ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -21,9 +21,10 @@ import java.io._ import com.google.common.io.Closeables -import org.apache.spark.SparkException +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryBlock @@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val 
page: MemoryBlock, numFields * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any * reader has begun reading from the queue. */ -private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue { - private var out = new DataOutputStream( - new BufferedOutputStream(new FileOutputStream(file.toString))) +private[python] case class DiskRowQueue( + file: File, + fields: Int, + serMgr: SerializerManager) extends RowQueue { + + private var out = new DataOutputStream(serMgr.wrapForEncryption( + new BufferedOutputStream(new FileOutputStream(file.toString)))) private var unreadBytes = 0L private var in: DataInputStream = _ @@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu if (out != null) { out.close() out = null - in = new DataInputStream(new NioBufferedFileInputStream(file)) + in = new DataInputStream(serMgr.wrapForEncryption( + new NioBufferedFileInputStream(file))) } if (unreadBytes > 0) { @@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu private[python] case class HybridRowQueue( memManager: TaskMemoryManager, tempDir: File, - numFields: Int) + numFields: Int, + serMgr: SerializerManager) extends MemoryConsumer(memManager) with RowQueue { // Each buffer should have at least one row @@ -212,7 +219,7 @@ private[python] case class HybridRowQueue( } private def createDiskQueue(): RowQueue = { - DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields) + DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr) } private def createNewQueue(required: Long): RowQueue = { @@ -279,3 +286,9 @@ private[python] case class HybridRowQueue( } } } + +private[python] object HybridRowQueue { + def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = { + HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager) + } +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33c..1ec9986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.internal.config._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.Utils -class RowQueueSuite extends SparkFunSuite { +class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite { test("in-memory queue") { val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) @@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite { queue.close() } - test("disk queue") { + private def createSerializerManager(conf: SparkConf): SerializerManager = { + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } + new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey) + } + + encryptionTest("disk queue") { conf => + val serManager = createSerializerManager(conf) val dir = Utils.createTempDir().getCanonicalFile dir.mkdirs() - val queue = DiskRowQueue(new 
File(dir, "buffer"), 1) + val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager) val row = new UnsafeRow(1) row.pointTo(new Array[Byte](16), 16) val n = 1000 @@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite { queue.close() } - test("hybrid queue") { - val mem = new TestMemoryManager(new SparkConf()) + encryptionTest("hybrid queue") { conf => + val serManager = createSerializerManager(conf) + val mem = new TestMemoryManager(conf) mem.limit(4<<10) val taskM = new TaskMemoryManager(mem, 0) - val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1) + val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager) val row = new UnsafeRow(1) row.pointTo(new Array[Byte](16), 16) val n = (4<<10) / 16 * 3 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org