Repository: spark
Updated Branches:
  refs/heads/master 16a2be1a8 -> 71a138cd0


[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.

This PR:
1. supports transferring arbitrarily nested arrays from the JVM to the R side in SerDe;
2. based on 1, improves the collect() implementation so that it can collect data of
   complex types from a DataFrame.

Author: Sun Rui <rui....@intel.com>

Closes #8276 from sun-rui/SPARK-10048.
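
For illustration, a minimal SparkR sketch of what the improved collect() enables.
The DataFrame "df" and its column names are assumptions for this example (say, a
table with a string column "name" and an array<int> column "scores" read from
JSON); they are not part of the patch:

    local <- collect(df)
    class(local$name)      # "character": columns of primitive types stay atomic vectors
    class(local$scores)    # expected "list": per the new collect(), complex columns become list columns
    local$scores[[1]]      # the first row's array value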


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/71a138cd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/71a138cd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/71a138cd

Branch: refs/heads/master
Commit: 71a138cd0e0a14e8426f97877e3b52a562bbd02c
Parents: 16a2be1
Author: Sun Rui <rui....@intel.com>
Authored: Tue Aug 25 13:14:10 2015 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Tue Aug 25 13:14:10 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/DataFrame.R                             | 55 ++++++++++---
 R/pkg/R/deserialize.R                           | 72 +++++++---------
 R/pkg/R/serialize.R                             | 10 +--
 R/pkg/inst/tests/test_Serde.R                   | 77 ++++++++++++++++++
 R/pkg/inst/worker/worker.R                      |  4 +-
 .../apache/spark/api/r/RBackendHandler.scala    |  7 ++
 .../scala/org/apache/spark/api/r/SerDe.scala    | 86 ++++++++++++--------
 .../org/apache/spark/sql/api/r/SQLUtils.scala   | 32 +-------
 8 files changed, 216 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 10f3c4e..ae1d912 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -652,18 +652,49 @@ setMethod("dim",
 setMethod("collect",
           signature(x = "DataFrame"),
           function(x, stringsAsFactors = FALSE) {
-            # listCols is a list of raw vectors, one per column
-            listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
-            cols <- lapply(listCols, function(col) {
-              objRaw <- rawConnection(col)
-              numRows <- readInt(objRaw)
-              col <- readCol(objRaw, numRows)
-              close(objRaw)
-              col
-            })
-            names(cols) <- columns(x)
-            do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
-          })
+            names <- columns(x)
+            ncol <- length(names)
+            if (ncol <= 0) {
+              # empty data.frame with 0 columns and 0 rows
+              data.frame()
+            } else {
+              # listCols is a list of columns
+              listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
+              stopifnot(length(listCols) == ncol)
+              
+              # A data.frame with 0 columns and as many rows as were collected
+              nrow <- length(listCols[[1]])
+              if (nrow <= 0) {
+                df <- data.frame()
+              } else {
+                df <- data.frame(row.names = 1 : nrow)                
+              }
+              
+              # Append columns one by one
+              for (colIndex in 1 : ncol) {
+                # Note: a column of list type is appended to the data.frame so that
+                # data of complex types can be held. However, getting a cell from a
+                # column of list type returns a list instead of a vector, so columns
+                # of non-complex types are appended as vectors.
+                col <- listCols[[colIndex]]
+                if (length(col) <= 0) {
+                  df[[names[colIndex]]] <- col
+                } else {
+                  # TODO: more robust check on column of primitive types
+                  vec <- do.call(c, col)
+                  if (class(vec) != "list") {
+                    df[[names[colIndex]]] <- vec                  
+                  } else {
+                    # Columns of complex types need careful access:
+                    # getting such a column returns a list, and
+                    # getting a cell from it returns a list instead of a vector.
+                    df[[names[colIndex]]] <- col
+                  }
+                }
+              }
+              df
+            }
+          })
 
 #' Limit
 #'
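
The vector-vs-list decision in the loop above can be reproduced with plain R values
(the literals below are illustrative, not produced by Spark):

    col1 <- list(1L, 2L, 3L)               # a column of a primitive type, one cell per row
    col2 <- list(list(1L, 2L), list(3L))   # a column of array type
    do.call(c, col1)                       # c(1L, 2L, 3L): an atomic vector, appended as-is
    class(do.call(c, col2))                # "list": concatenation stays a list, so the
                                           # original list column is appended instead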

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 33bf13e..6cf628e 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
     "r" = readRaw(con),
     "D" = readDate(con),
     "t" = readTime(con),
+    "a" = readArray(con),
     "l" = readList(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
@@ -85,8 +86,7 @@ readTime <- function(con) {
   as.POSIXct(t, origin = "1970-01-01")
 }
 
-# We only support lists where all elements are of same type
-readList <- function(con) {
+readArray <- function(con) {
   type <- readType(con)
   len <- readInt(con)
   if (len > 0) {
@@ -100,6 +100,25 @@ readList <- function(con) {
   }
 }
 
+# Read a list. Types of each element may be different.
+# Null objects are read as NA.
+readList <- function(con) {
+  len <- readInt(con)
+  if (len > 0) {
+    l <- vector("list", len)
+    for (i in 1:len) {
+      elem <- readObject(con)
+      if (is.null(elem)) {
+        elem <- NA
+      }
+      l[[i]] <- elem
+    }
+    l
+  } else {
+    list()
+  }
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")
@@ -132,18 +151,19 @@ readDeserialize <- function(con) {
   }
 }
 
-readDeserializeRows <- function(inputCon) {
-  # readDeserializeRows will deserialize a DataOutputStream composed of
-  # a list of lists. Since the DOS is one continuous stream and
-  # the number of rows varies, we put the readRow function in a while loop
-  # that termintates when the next row is empty.
+readMultipleObjects <- function(inputCon) {
+  # readMultipleObjects reads multiple consecutive objects from
+  # a DataOutputStream. There is no preceding field giving the count
+  # of objects, so the number of objects varies; we read objects in a
+  # loop until the end of the stream is reached.
   data <- list()
   while(TRUE) {
-    row <- readRow(inputCon)
-    if (length(row) == 0) {
+    # At the end of the stream, the type read back should be "".
+    type <- readType(inputCon)
+    if (type == "") {
       break
     }
-    data[[length(data) + 1L]] <- row
+    data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
   }
   data # this is a list of named lists now
 }
@@ -155,35 +175,5 @@ readRowList <- function(obj) {
   # deserialize the row.
   rawObj <- rawConnection(obj, "r+")
   on.exit(close(rawObj))
-  readRow(rawObj)
-}
-
-readRow <- function(inputCon) {
-  numCols <- readInt(inputCon)
-  if (length(numCols) > 0 && numCols > 0) {
-    lapply(1:numCols, function(x) {
-      obj <- readObject(inputCon)
-      if (is.null(obj)) {
-        NA
-      } else {
-        obj
-      }
-    }) # each row is a list now
-  } else {
-    list()
-  }
-}
-
-# Take a single column as Array[Byte] and deserialize it into an atomic vector
-readCol <- function(inputCon, numRows) {
-  if (numRows > 0) {
-    # sapply can not work with POSIXlt
-    do.call(c, lapply(1:numRows, function(x) {
-      value <- readObject(inputCon)
-      # Replace NULL with NA so we can coerce to vectors
-      if (is.null(value)) NA else value
-    }))
-  } else {
-    vector()
-  }
+  readObject(rawObj)
 }
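
In the new scheme, readArray keeps handling homogeneous arrays (a single element-type
tag followed by the values), while readList reads a length and then one fully typed
object per element, so element types may differ and nesting recurses naturally. Its
NULL-to-NA convention can be pictured with plain R (a local sketch, not SparkR internals):

    elems <- list(1L, NULL, "a")
    lapply(elems, function(e) if (is.null(e)) NA else e)
    # list(1L, NA, "a"): nulls coming from the JVM surface as NA in collected lists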

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/R/pkg/R/serialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 311021e..e3676f5 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
 serializeRow <- function(row) {
   rawObj <- rawConnection(raw(0), "wb")
   on.exit(close(rawObj))
-  writeRow(rawObj, row)
+  writeGenericList(rawObj, row)
   rawConnectionValue(rawObj)
 }
 
-writeRow <- function(con, row) {
-  numCols <- length(row)
-  writeInt(con, numCols)
-  for (i in 1:numCols) {
-    writeObject(con, row[[i]])
-  }
-}
-
 writeRaw <- function(con, batch) {
   writeInt(con, length(batch))
   writeBin(batch, con, endian = "big")

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/R/pkg/inst/tests/test_Serde.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R
new file mode 100644
index 0000000..009db85
--- /dev/null
+++ b/R/pkg/inst/tests/test_Serde.R
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("SerDe functionality")
+
+sc <- sparkR.init()
+
+test_that("SerDe of primitive types", {
+  x <- callJStatic("SparkRHandler", "echo", 1L)
+  expect_equal(x, 1L)
+  expect_equal(class(x), "integer")
+  
+  x <- callJStatic("SparkRHandler", "echo", 1)
+  expect_equal(x, 1)
+  expect_equal(class(x), "numeric")
+
+  x <- callJStatic("SparkRHandler", "echo", TRUE)
+  expect_true(x)
+  expect_equal(class(x), "logical")
+  
+  x <- callJStatic("SparkRHandler", "echo", "abc")
+  expect_equal(x, "abc")
+  expect_equal(class(x), "character")  
+})
+
+test_that("SerDe of list of primitive types", {
+  x <- list(1L, 2L, 3L)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "integer")
+
+  x <- list(1, 2, 3)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "numeric")
+  
+  x <- list(TRUE, FALSE)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "logical")
+  
+  x <- list("a", "b", "c")
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "character")
+  
+  # Empty list
+  x <- list()
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+})
+
+test_that("SerDe of list of lists", {
+  x <- list(list(1L, 2L, 3L), list(1, 2, 3),
+            list(TRUE, FALSE), list("a", "b", "c"))
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+
+  # List of empty lists
+  x <- list(list(), list())
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/R/pkg/inst/worker/worker.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 7e3b5fc..0c3b0d1 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -94,7 +94,7 @@ if (isEmpty != 0) {
     } else if (deserializer == "string") {
       data <- as.list(readLines(inputCon))
     } else if (deserializer == "row") {
-      data <- SparkR:::readDeserializeRows(inputCon)
+      data <- SparkR:::readMultipleObjects(inputCon)
     }
     # Timing reading input data for execution
     inputElap <- elapsedSecs()
@@ -120,7 +120,7 @@ if (isEmpty != 0) {
     } else if (deserializer == "string") {
       data <- readLines(inputCon)
     } else if (deserializer == "row") {
-      data <- SparkR:::readDeserializeRows(inputCon)
+      data <- SparkR:::readMultipleObjects(inputCon)
     }
     # Timing reading input data for execution
     inputElap <- elapsedSecs()

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 6ce02e2..bb82f32 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
 
     if (objId == "SparkRHandler") {
       methodName match {
+        // This function is for test purposes only
+        case "echo" =>
+          val args = readArgs(numArgs, dis)
+          assert(numArgs == 1)
+
+          writeInt(dos, 0)
+          writeObject(dos, args(0))
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dbbbcf4..26ad4f1 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -149,6 +149,10 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
+      case 'l' => {
+        val len = readInt(dis)
+        (0 until len).map(_ => readList(dis)).toArray
+      }
      case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
     }
   }
@@ -200,6 +204,9 @@ private[spark] object SerDe {
       case "date" => dos.writeByte('D')
       case "time" => dos.writeByte('t')
       case "raw" => dos.writeByte('r')
+      // Array of primitive types
+      case "array" => dos.writeByte('a')
+      // Array of objects
       case "list" => dos.writeByte('l')
       case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
@@ -211,26 +218,35 @@ private[spark] object SerDe {
       writeType(dos, "void")
     } else {
       value.getClass.getName match {
+        case "java.lang.Character" =>
+          writeType(dos, "character")
+          writeString(dos, value.asInstanceOf[Character].toString)
         case "java.lang.String" =>
           writeType(dos, "character")
           writeString(dos, value.asInstanceOf[String])
-        case "long" | "java.lang.Long" =>
+        case "java.lang.Long" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Long].toDouble)
-        case "float" | "java.lang.Float" =>
+        case "java.lang.Float" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Float].toDouble)
-        case "decimal" | "java.math.BigDecimal" =>
+        case "java.math.BigDecimal" =>
           writeType(dos, "double")
           val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
           writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
-        case "double" | "java.lang.Double" =>
+        case "java.lang.Double" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Double])
-        case "int" | "java.lang.Integer" =>
+        case "java.lang.Byte" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Byte].toInt)
+        case "java.lang.Short" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Short].toInt)
+        case "java.lang.Integer" =>
           writeType(dos, "integer")
           writeInt(dos, value.asInstanceOf[Int])
-        case "boolean" | "java.lang.Boolean" =>
+        case "java.lang.Boolean" =>
           writeType(dos, "logical")
           writeBoolean(dos, value.asInstanceOf[Boolean])
         case "java.sql.Date" =>
@@ -242,43 +258,48 @@ private[spark] object SerDe {
         case "java.sql.Timestamp" =>
           writeType(dos, "time")
           writeTime(dos, value.asInstanceOf[Timestamp])
+
+        // Handle arrays
+
+        // Array of primitive types
+
+        // Special handling for byte array
         case "[B" =>
           writeType(dos, "raw")
           writeBytes(dos, value.asInstanceOf[Array[Byte]])
-        // TODO: Types not handled right now include
-        // byte, char, short, float
 
-        // Handle arrays
-        case "[Ljava.lang.String;" =>
-          writeType(dos, "list")
-          writeStringArr(dos, value.asInstanceOf[Array[String]])
+        case "[C" =>
+          writeType(dos, "array")
+          writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
+        case "[S" =>
+          writeType(dos, "array")
+          writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
         case "[I" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeIntArr(dos, value.asInstanceOf[Array[Int]])
         case "[J" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+        case "[F" =>
+          writeType(dos, "array")
+          writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
         case "[D" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
         case "[Z" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
-        case "[[B" =>
+
+        // Array of objects, null objects use "void" type
+        case c if c.startsWith("[") =>
           writeType(dos, "list")
-          writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
-        case otherName =>
-          // Handle array of objects
-          if (otherName.startsWith("[L")) {
-            val objArr = value.asInstanceOf[Array[Object]]
-            writeType(dos, "list")
-            writeType(dos, "jobj")
-            dos.writeInt(objArr.length)
-            objArr.foreach(o => writeJObj(dos, o))
-          } else {
-            writeType(dos, "jobj")
-            writeJObj(dos, value)
-          }
+          val array = value.asInstanceOf[Array[Object]]
+          writeInt(dos, array.length)
+          array.foreach(elem => writeObject(dos, elem))
+
+        case _ =>
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
       }
     }
   }
@@ -350,11 +371,6 @@ private[spark] object SerDe {
     value.foreach(v => writeString(out, v))
   }
 
-  def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
-    writeType(out, "raw")
-    out.writeInt(value.length)
-    value.foreach(v => writeBytes(out, v))
-  }
 }
 
 private[r] object SerializationFormats {
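
After this change, writeObject keeps the compact "array" encoding for arrays of
primitive JVM types and falls back to a generic "list" encoding (a length followed by
one fully typed element, written recursively through writeObject) for any other array.
A sketch that pushes the added tests one level deeper, per the PR's claim of arbitrary
nesting; it assumes a running backend from sparkR.init() and uses the internal,
test-only "echo" handler:

    x <- list(list(list(1L, 2L), list(3L)), list(list("a")))
    y <- SparkR:::callJStatic("SparkRHandler", "echo", x)
    identical(x, y)   # expected TRUE: each level round-trips element by element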

http://git-wip-us.apache.org/repos/asf/spark/blob/71a138cd/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 92861ab..7f3defe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -98,27 +98,17 @@ private[r] object SQLUtils {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
-    SerDe.writeInt(dos, row.length)
-    (0 until row.length).map { idx =>
-      val obj: Object = row(idx).asInstanceOf[Object]
-      SerDe.writeObject(dos, obj)
-    }
+    val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
+    SerDe.writeObject(dos, cols)
     bos.toByteArray()
   }
 
-  def dfToCols(df: DataFrame): Array[Array[Byte]] = {
+  def dfToCols(df: DataFrame): Array[Array[Any]] = {
     // localDF is Array[Row]
     val localDF = df.collect()
     val numCols = df.columns.length
-    // dfCols is Array[Array[Any]]
-    val dfCols = convertRowsToColumns(localDF, numCols)
-
-    dfCols.map { col =>
-      colToRBytes(col)
-    }
-  }
 
-  def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
+    // result is Array[Array[Any]]
     (0 until numCols).map { colIdx =>
       localDF.map { row =>
         row(colIdx)
@@ -126,20 +116,6 @@ private[r] object SQLUtils {
     }.toArray
   }
 
-  def colToRBytes(col: Array[Any]): Array[Byte] = {
-    val numRows = col.length
-    val bos = new ByteArrayOutputStream()
-    val dos = new DataOutputStream(bos)
-
-    SerDe.writeInt(dos, numRows)
-
-    col.map { item =>
-      val obj: Object = item.asInstanceOf[Object]
-      SerDe.writeObject(dos, obj)
-    }
-    bos.toByteArray()
-  }
-
   def saveMode(mode: String): SaveMode = {
     mode match {
       case "append" => SaveMode.Append

