[SPARKR] Match pyspark features in SparkR communication protocol
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ef361682
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ef361682
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ef361682

Branch: refs/heads/branch-2.2
Commit: ef36168258b8ad15362312e0562794f4f07322d0
Parents: 8ad6693
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Mon Sep 24 19:25:02 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Wed Sep 26 10:50:46 2018 +0800

----------------------------------------------------------------------
 R/pkg/R/context.R                              | 43 ++++++++++++++------
 R/pkg/tests/fulltests/test_Serde.R             | 32 +++++++++++++++
 R/pkg/tests/fulltests/test_sparkSQL.R          | 12 ------
 .../scala/org/apache/spark/api/r/RRDD.scala    | 33 ++++++++++++++-
 .../scala/org/apache/spark/api/r/RUtils.scala  |  4 ++
 5 files changed, 98 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/R/context.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 50856e3..c1a12f5 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -168,18 +168,30 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # 2-tuples of raws
   serializedSlices <- lapply(slices, serialize, connection = NULL)
 
-  # The PRC backend cannot handle arguments larger than 2GB (INT_MAX)
+  # The RPC backend cannot handle arguments larger than 2GB (INT_MAX)
   # If serialized data is safely less than that threshold we send it over the PRC channel.
   # Otherwise, we write it to a file and send the file name
   if (objectSize < sizeLimit) {
     jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
   } else {
-    fileName <- writeToTempFile(serializedSlices)
-    jrdd <- tryCatch(callJStatic(
-        "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
-      finally = {
-        file.remove(fileName)
-    })
+    if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) {
+      # the length of slices here is the parallelism to use in the jvm's sc.parallelize()
+      parallelism <- as.integer(numSlices)
+      jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism)
+      authSecret <- callJMethod(jserver, "secret")
+      port <- callJMethod(jserver, "port")
+      conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500)
+      doServerAuth(conn, authSecret)
+      writeToConnection(serializedSlices, conn)
+      jrdd <- callJMethod(jserver, "getResult")
+    } else {
+      fileName <- writeToTempFile(serializedSlices)
+      jrdd <- tryCatch(callJStatic(
+          "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
+        finally = {
+          file.remove(fileName)
+      })
+    }
   }
 
   RDD(jrdd, "byte")
@@ -195,14 +207,21 @@ getMaxAllocationLimit <- function(sc) {
   ))
 }
 
+writeToConnection <- function(serializedSlices, conn) {
+  tryCatch({
+    for (slice in serializedSlices) {
+      writeBin(as.integer(length(slice)), conn, endian = "big")
+      writeBin(slice, conn, endian = "big")
+    }
+  }, finally = {
+    close(conn)
+  })
+}
+
 writeToTempFile <- function(serializedSlices) {
   fileName <- tempfile()
   conn <- file(fileName, "wb")
-  for (slice in serializedSlices) {
-    writeBin(as.integer(length(slice)), conn, endian = "big")
-    writeBin(slice, conn, endian = "big")
-  }
-  close(conn)
+  writeToConnection(serializedSlices, conn)
   fileName
 }
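
For context on the wire format: writeToConnection frames each serialized slice as a 4-byte big-endian length followed by the slice's raw bytes, and the JVM side reads the stream back with PythonRDD.readRDDFromInputStream. The snippet below is a minimal stand-alone Scala sketch of that length-prefixed framing, illustrative only (the object and method names are hypothetical, and it does not touch Spark's actual classes):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-alone sketch of the length-prefixed slice framing.
object SliceFramingSketch {

  // Frame each slice the way R's writeToConnection does: 4-byte big-endian
  // length, then the raw bytes (DataOutputStream writes ints big-endian).
  def writeSlices(slices: Seq[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    slices.foreach { slice =>
      out.writeInt(slice.length)
      out.write(slice)
    }
    out.flush()
    bos.toByteArray
  }

  // Read frames back until end of stream, as the JVM-side consumer would.
  def readSlices(bytes: Array[Byte]): Seq[Array[Byte]] = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val slices = ArrayBuffer.empty[Array[Byte]]
    try {
      while (true) {
        val len = in.readInt()
        val buf = new Array[Byte](len)
        in.readFully(buf)
        slices += buf
      }
    } catch {
      case _: EOFException => // end of stream, stop reading
    }
    slices.toSeq
  }

  def main(args: Array[String]): Unit = {
    val original = Seq(Array[Byte](1, 2, 3), Array[Byte](4, 5))
    val roundTripped = readSlices(writeSlices(original))
    assert(roundTripped.map(_.toSeq) == original.map(_.toSeq))
  }
}

Streaming the slices over the authenticated socket, rather than through a temp file, is what allows this path to be used when spark.io.encryption.enabled is set.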
http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/tests/fulltests/test_Serde.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
index 6bbd201..092f9b8 100644
--- a/R/pkg/tests/fulltests/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -77,3 +77,35 @@ test_that("SerDe of list of lists", {
 })
 
 sparkR.session.stop()
+
+# Note that this test should be at the end of tests since the configruations used here are not
+# specific to sessions, and the Spark context is restarted.
+test_that("createDataFrame large objects", {
+  for (encryptionEnabled in list("true", "false")) {
+    # To simulate a large object scenario, we set spark.r.maxAllocationLimit to a smaller value
+    conf <- list(spark.r.maxAllocationLimit = "100",
+                 spark.io.encryption.enabled = encryptionEnabled)
+
+    suppressWarnings(sparkR.session(master = sparkRTestMaster,
+                                    sparkConfig = conf,
+                                    enableHiveSupport = FALSE))
+
+    sc <- getSparkContext()
+    actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)
+    expected <- as.logical(encryptionEnabled)
+    expect_equal(actual, expected)
+
+    tryCatch({
+      # suppress warnings from dot in the field names. See also SPARK-21536.
+      df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
+      expect_equal(getNumPartitions(df), 3)
+      expect_equal(dim(df), dim(iris))
+
+      df <- createDataFrame(cars, numPartitions = 3)
+      expect_equal(collect(df), cars)
+    },
+    finally = {
+      sparkR.stop()
+    })
+  }
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index f774554..f2b1c1d 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -298,18 +298,6 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
-test_that("createDataFrame uses files for large objects", {
-  # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
-  conf <- callJMethod(sparkSession, "conf")
-  callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
-  df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
-  expect_equal(getNumPartitions(df), 3)
-
-  # Resetting the conf back to default value
-  callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))
-  expect_equal(dim(df), dim(iris))
-})
-
 test_that("read/write csv as DataFrame", {
   if (windows_with_hadoop()) {
     csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 295355c..1dc61c7 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.api.r
 
-import java.io.File
+import java.io.{DataInputStream, File}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
@@ -25,10 +27,11 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonRDD, PythonServer}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -163,3 +166,29 @@ private[r] object RRDD {
     PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
 }
+
+/**
+ * Helper for making RDD[Array[Byte]] from some R data, by reading the data from R
+ * over a socket. This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
+    extends PythonServer[JavaRDD[Array[Byte]]](
+      new RSocketAuthHelper(), "sparkr-parallelize-server") {
+
+  override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+    val in = sock.getInputStream()
+    PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism)
+  }
+}
+
+private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) {
+  override protected def readUtf8(s: Socket): String = {
+    val din = new DataInputStream(s.getInputStream())
+    val len = din.readInt()
+    val bytes = new Array[Byte](len)
+    din.readFully(bytes)
+    // The R code adds a null terminator to serialized strings, so ignore it here.
+    assert(bytes(bytes.length - 1) == 0)  // sanity check.
+    new String(bytes, 0, bytes.length - 1, UTF_8)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index fdd8cf6..9bf35af 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -21,6 +21,8 @@ import java.io.File
 import java.util.Arrays
 
 import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.api.python.PythonUtils
 
 private[spark] object RUtils {
   // Local path where R binary packages built from R source code contained in the spark
@@ -104,4 +106,6 @@ private[spark] object RUtils {
       case e: Exception => false
     }
   }
+
+  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc)
 }
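
A note on the string framing the new RSocketAuthHelper relies on: the overridden readUtf8 expects a 4-byte big-endian length followed by UTF-8 bytes whose last byte is the null terminator that the R side appends when it serializes strings (the auth secret, in this case). The sketch below is a self-contained, illustrative round trip of that encoding in plain Scala; it is not the Spark or SparkR implementation, and the object and method names are hypothetical:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

// Hypothetical sketch of the NUL-terminated, length-prefixed string frame.
object RStringFrameSketch {

  // Encode the way the R side is assumed to: the 4-byte big-endian length
  // counts the UTF-8 bytes plus the trailing NUL terminator.
  def writeRString(s: String): Array[Byte] = {
    val payload = s.getBytes(UTF_8) :+ 0.toByte
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    out.writeInt(payload.length)
    out.write(payload)
    out.flush()
    bos.toByteArray
  }

  // Decode the way the overridden readUtf8 in this commit does: read the
  // length, read that many bytes, then drop the trailing NUL.
  def readRString(bytes: Array[Byte]): String = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val len = in.readInt()
    val buf = new Array[Byte](len)
    in.readFully(buf)
    assert(buf(buf.length - 1) == 0) // sanity check, as in RSocketAuthHelper
    new String(buf, 0, buf.length - 1, UTF_8)
  }

  def main(args: Array[String]): Unit = {
    assert(readRString(writeRString("sparkr-secret")) == "sparkr-secret")
  }
}

Handling the terminator on the JVM side keeps the rest of the SocketAuthHelper handshake unchanged while letting the R client write the secret with its usual string serialization.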