Repository: spark
Updated Branches:
  refs/heads/master d88abb7e2 -> 45e3be5c1


[SPARK-10049] [SPARKR] Support collecting data of ArrayType in DataFrame.

This PR:
1.  Enhances reflection in RBackend: a Java array argument is automatically
matched to a Scala Seq parameter when finding methods. Utility functions such as
seq() and listToSeq() on the R side can be removed, as they would conflict with
the SerDe logic that transfers a Scala Seq to the R side.

2.  Enhances the SerDe to support transferring a Scala Seq to the R side. Data
of ArrayType in a DataFrame is observed to be of Scala Seq type after
collection.

3.  Supports ArrayType in createDataFrame() (see the usage sketch below).
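
A rough usage sketch of what changes 2 and 3 enable on the R side (column names
and data below are illustrative, not taken from this patch):

  # createDataFrame() now accepts list columns and maps them to ArrayType
  df <- createDataFrame(sqlContext, list(list(as.list(1:3), list("a", "b"))), c("x", "y"))
  dtypes(df)   # list(c("x", "array<int>"), c("y", "array<string>"))

  # collect() brings ArrayType columns back as R lists
  ldf <- collect(df)
  ldf$x[[1]]   # list(1L, 2L, 3L)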

Author: Sun Rui <rui....@intel.com>

Closes #8458 from sun-rui/SPARK-10049.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45e3be5c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45e3be5c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45e3be5c

Branch: refs/heads/master
Commit: 45e3be5c138d983f40f619735d60bf7eb78c9bf0
Parents: d88abb7
Author: Sun Rui <rui....@intel.com>
Authored: Thu Sep 10 12:21:13 2015 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Thu Sep 10 12:21:13 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/DataFrame.R                             |  26 ++--
 R/pkg/R/SQLContext.R                            |   4 +-
 R/pkg/R/column.R                                |   3 +-
 R/pkg/R/functions.R                             |  12 +-
 R/pkg/R/group.R                                 |   4 +-
 R/pkg/R/schema.R                                |  54 ++++++---
 R/pkg/R/utils.R                                 |  10 --
 R/pkg/inst/tests/test_sparkSQL.R                |  44 +++++--
 .../apache/spark/api/r/RBackendHandler.scala    | 121 +++++++++++++------
 .../scala/org/apache/spark/api/r/SerDe.scala    | 109 +++++++++--------
 .../org/apache/spark/sql/api/r/SQLUtils.scala   |  14 ++-
 11 files changed, 250 insertions(+), 151 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 8a00238..c3c1893 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -271,7 +271,7 @@ setMethod("names<-",
           signature(x = "DataFrame"),
           function(x, value) {
             if (!is.null(value)) {
-              sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
+              sdf <- callJMethod(x@sdf, "toDF", as.list(value))
               dataFrame(sdf)
             }
           })
@@ -843,10 +843,10 @@ setMethod("groupBy",
            function(x, ...) {
              cols <- list(...)
              if (length(cols) >= 1 && class(cols[[1]]) == "character") {
-               sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
+               sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1])
              } else {
                jcol <- lapply(cols, function(c) { c@jc })
-               sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
+               sgd <- callJMethod(x@sdf, "groupBy", jcol)
              }
              groupedData(sgd)
            })
@@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"),
 #' }
 setMethod("select", signature(x = "DataFrame", col = "character"),
           function(x, col, ...) {
-            sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
+            sdf <- callJMethod(x@sdf, "select", col, list(...))
             dataFrame(sdf)
           })
 
@@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"),
             jcols <- lapply(list(col, ...), function(c) {
               c@jc
             })
-            sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
+            sdf <- callJMethod(x@sdf, "select", jcols)
             dataFrame(sdf)
           })
 
@@ -1106,7 +1106,7 @@ setMethod("select",
                 col(c)@jc
               }
             })
-            sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
+            sdf <- callJMethod(x@sdf, "select", cols)
             dataFrame(sdf)
           })
 
@@ -1133,7 +1133,7 @@ setMethod("selectExpr",
           signature(x = "DataFrame", expr = "character"),
           function(x, expr, ...) {
             exprList <- list(expr, ...)
-            sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
+            sdf <- callJMethod(x@sdf, "selectExpr", exprList)
             dataFrame(sdf)
           })
 
@@ -1311,12 +1311,12 @@ setMethod("arrange",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col, ...) {
             if (class(col) == "character") {
-              sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
+              sdf <- callJMethod(x@sdf, "sort", col, list(...))
             } else if (class(col) == "Column") {
               jcols <- lapply(list(col, ...), function(c) {
                 c@jc
               })
-              sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
+              sdf <- callJMethod(x@sdf, "sort", jcols)
             }
             dataFrame(sdf)
           })
@@ -1664,7 +1664,7 @@ setMethod("describe",
           signature(x = "DataFrame", col = "character"),
           function(x, col, ...) {
             colList <- list(col, ...)
-            sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+            sdf <- callJMethod(x@sdf, "describe", colList)
             dataFrame(sdf)
           })
 
@@ -1674,7 +1674,7 @@ setMethod("describe",
           signature(x = "DataFrame"),
           function(x) {
             colList <- as.list(c(columns(x)))
-            sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+            sdf <- callJMethod(x@sdf, "describe", colList)
             dataFrame(sdf)
           })
 
@@ -1731,7 +1731,7 @@ setMethod("dropna",
 
             naFunctions <- callJMethod(x@sdf, "na")
             sdf <- callJMethod(naFunctions, "drop",
-                               as.integer(minNonNulls), listToSeq(as.list(cols)))
+                               as.integer(minNonNulls), as.list(cols))
             dataFrame(sdf)
           })
 
@@ -1815,7 +1815,7 @@ setMethod("fillna",
             sdf <- if (length(cols) == 0) {
               callJMethod(naFunctions, "fill", value)
             } else {
-              callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+              callJMethod(naFunctions, "fill", value, as.list(cols))
             }
             dataFrame(sdf)
           })

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/SQLContext.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1bc6445..4ac057d 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -49,7 +49,7 @@ infer_type <- function(x) {
     stopifnot(length(x) > 0)
     names <- names(x)
     if (is.null(names)) {
-      list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
+      paste0("array<", infer_type(x[[1]]), ">")
     } else {
       # StructType
       types <- lapply(x, infer_type)
@@ -59,7 +59,7 @@ infer_type <- function(x) {
       do.call(structType, fields)
     }
   } else if (length(x) > 1) {
-    list(type = "array", elementType = type, containsNull = TRUE)
+    paste0("array<", infer_type(x[[1]]), ">")
   } else {
     type
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/column.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 4805096..42e9d12 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -211,8 +211,7 @@ setMethod("cast",
 setMethod("%in%",
           signature(x = "Column"),
           function(x, table) {
-            table <- listToSeq(as.list(table))
-            jc <- callJMethod(x@jc, "in", table)
+            jc <- callJMethod(x@jc, "in", as.list(table))
             return(column(jc))
           })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d848730..94687ed 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1331,7 +1331,7 @@ setMethod("countDistinct",
               x@jc
             })
             jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
-                              listToSeq(jcol))
+                              jcol)
             column(jc)
           })
 
@@ -1348,7 +1348,7 @@ setMethod("concat",
           signature(x = "Column"),
           function(x, ...) {
             jcols <- lapply(list(x, ...), function(x) { x@jc })
-            jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols))
+            jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols)
             column(jc)
           })
 
@@ -1366,7 +1366,7 @@ setMethod("greatest",
           function(x, ...) {
             stopifnot(length(list(...)) > 0)
             jcols <- lapply(list(x, ...), function(x) { x@jc })
-            jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols))
+            jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols)
             column(jc)
           })
 
@@ -1384,7 +1384,7 @@ setMethod("least",
           function(x, ...) {
             stopifnot(length(list(...)) > 0)
             jcols <- lapply(list(x, ...), function(x) { x@jc })
-            jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols))
+            jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols)
             column(jc)
           })
 
@@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
 #' @export
 setMethod("concat_ws", signature(sep = "character", x = "Column"),
           function(sep, x, ...) {
-            jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc }))
+            jcols <- lapply(list(x, ...), function(x) { x@jc })
             jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
             column(jc)
           })
@@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"),
 #' @export
 setMethod("format_string", signature(format = "character", x = "Column"),
           function(format, x, ...) {
-            jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc }))
+            jcols <- lapply(list(x, ...), function(arg) { arg@jc })
             jc <- callJStatic("org.apache.spark.sql.functions",
                               "format_string",
                               format, jcols)

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 576ac72..4cab1a6 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -102,7 +102,7 @@ setMethod("agg",
                 }
               }
               jcols <- lapply(cols, function(c) { c@jc })
-              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
+              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
             } else {
               stop("agg can only support Column or character")
             }
@@ -124,7 +124,7 @@ createMethod <- function(name) {
   setMethod(name,
             signature(x = "GroupedData"),
             function(x, ...) {
-              sdf <- callJMethod(x@sgd, name, toSeq(...))
+              sdf <- callJMethod(x@sgd, name, list(...))
               dataFrame(sdf)
             })
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/schema.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 79c744e..62d4f73 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -56,7 +56,7 @@ structType.structField <- function(x, ...) {
   })
   stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                        "createStructType",
-                       listToSeq(sfObjList))
+                       sfObjList)
   structType(stObj)
 }
 
@@ -114,6 +114,35 @@ structField.jobj <- function(x) {
   obj
 }
 
+checkType <- function(type) {
+  primtiveTypes <- c("byte",
+                     "integer",
+                     "float",
+                     "double",
+                     "numeric",
+                     "character",
+                     "string",
+                     "binary",
+                     "raw",
+                     "logical",
+                     "boolean",
+                     "timestamp",
+                     "date")
+  if (type %in% primtiveTypes) {
+    return()
+  } else {
+    m <- regexec("^array<(.*)>$", type)
+    matchedStrings <- regmatches(type, m)
+    if (length(matchedStrings[[1]]) >= 2) {
+      elemType <- matchedStrings[[1]][2]
+      checkType(elemType)
+      return()
+    }
+  }
+
+  stop(paste("Unsupported type for Dataframe:", type))
+}
+
 structField.character <- function(x, type, nullable = TRUE) {
   if (class(x) != "character") {
     stop("Field name must be a string.")
@@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = 
TRUE) {
   if (class(nullable) != "logical") {
     stop("nullable must be either TRUE or FALSE")
   }
-  options <- c("byte",
-               "integer",
-               "float",
-               "double",
-               "numeric",
-               "character",
-               "string",
-               "binary",
-               "raw",
-               "logical",
-               "boolean",
-               "timestamp",
-               "date")
-  dataType <- if (type %in% options) {
-    type
-  } else {
-    stop(paste("Unsupported type for Dataframe:", type))
-  }
+
+  checkType(type)
+
   sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                        "createStructField",
                        x,
-                       dataType,
+                       type,
                        nullable)
   structField(sfObj)
 }
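
Not part of the diff, but a small illustrative sketch of what the new checkType()
permits: schema fields may now declare array element types, including nested
arrays (field names here are made up):

  schema <- structType(structField("name", "string"),
                       structField("scores", "array<double>"),
                       structField("grid", "array<array<integer>>"))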

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 3babcb5..69a2bc7 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -361,16 +361,6 @@ numToInt <- function(num) {
   as.integer(num)
 }
 
-# create a Seq in JVM
-toSeq <- function(...) {
-  callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
-}
-
-# create a Seq in JVM from a list
-listToSeq <- function(l) {
-  callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
-}
-
 # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
 # user defined function (UDF), and to examine variables in the UDF to decide
 # if their values should be included in the new function environment.

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 6d331f9..1ccfde5 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
 jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
 writeLines(mockLinesNa, jsonPathNa)
 
+# For test complex types in DataFrame
+mockLinesComplexType <-
+  c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}",
+    "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}",
+    "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}")
+complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesComplexType, complexTypeJsonPath)
+
 test_that("infer types", {
   expect_equal(infer_type(1L), "integer")
   expect_equal(infer_type(1.0), "double")
@@ -56,10 +64,8 @@ test_that("infer types", {
   expect_equal(infer_type(TRUE), "boolean")
   expect_equal(infer_type(as.Date("2015-03-11")), "date")
   expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
-  expect_equal(infer_type(c(1L, 2L)),
-               list(type = "array", elementType = "integer", containsNull = TRUE))
-  expect_equal(infer_type(list(1L, 2L)),
-               list(type = "array", elementType = "integer", containsNull = TRUE))
+  expect_equal(infer_type(c(1L, 2L)), "array<integer>")
+  expect_equal(infer_type(list(1L, 2L)), "array<integer>")
   testStruct <- infer_type(list(a = 1L, b = "2"))
   expect_equal(class(testStruct), "structType")
   checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
@@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", {
   expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
 })
 
-# TODO: enable this test after fix serialization for nested object
-#test_that("create DataFrame with nested array and struct", {
+test_that("create DataFrame with nested array and struct", {
 #  e <- new.env()
 #  assign("n", 3L, envir = e)
 #  l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
@@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", {
 #  expect_equal(count(df), 1)
 #  ldf <- collect(df)
 #  expect_equal(ldf[1,], l[[1]])
-#})
+
+
+  #  ArrayType only for now
+  l <- list(as.list(1:10), list("a", "b"))
+  df <- createDataFrame(sqlContext, list(l), c("a", "b"))
+  expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
+  expect_equal(count(df), 1)
+  ldf <- collect(df)
+  expect_equal(names(ldf), c("a", "b"))
+  expect_equal(ldf[1, 1][[1]], l[[1]])
+  expect_equal(ldf[1, 2][[1]], l[[2]])
+})
+
+test_that("Collect DataFrame with complex types", {
+  # only ArrayType now
+  # TODO: tests for StructType and MapType after they are supported
+  df <- jsonFile(sqlContext, complexTypeJsonPath)
+
+  ldf <- collect(df)
+  expect_equal(nrow(ldf), 3)
+  expect_equal(ncol(ldf), 3)
+  expect_equal(names(ldf), c("c1", "c2", "c3"))
+  expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9)))
+  expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i")))
+  expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0)))
+})
 
 test_that("jsonFile() on a local file returns a DataFrame", {
   df <- jsonFile(sqlContext, jsonPath)

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index bb82f32..2a792d8 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend)
       val methods = cls.getMethods
       val selectedMethods = methods.filter(m => m.getName == methodName)
       if (selectedMethods.length > 0) {
-        val methods = selectedMethods.filter { x =>
-          matchMethod(numArgs, args, x.getParameterTypes)
-        }
-        if (methods.isEmpty) {
+        val index = findMatchedSignature(
+          selectedMethods.map(_.getParameterTypes),
+          args)
+
+        if (index.isEmpty) {
           logWarning(s"cannot find matching method ${cls}.$methodName. "
             + s"Candidates are:")
           selectedMethods.foreach { method =>
@@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend)
           }
           throw new Exception(s"No matched method found for $cls.$methodName")
         }
-        val ret = methods.head.invoke(obj, args : _*)
+
+        val ret = selectedMethods(index.get).invoke(obj, args : _*)
 
         // Write status bit
         writeInt(dos, 0)
         writeObject(dos, ret.asInstanceOf[AnyRef])
       } else if (methodName == "<init>") {
         // methodName should be "<init>" for constructor
-        val ctor = cls.getConstructors.filter { x =>
-          matchMethod(numArgs, args, x.getParameterTypes)
-        }.head
+        val ctors = cls.getConstructors
+        val index = findMatchedSignature(
+          ctors.map(_.getParameterTypes),
+          args)
 
-        val obj = ctor.newInstance(args : _*)
+        if (index.isEmpty) {
+          logWarning(s"cannot find matching constructor for ${cls}. "
+            + s"Candidates are:")
+          ctors.foreach { ctor =>
+            logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})")
+          }
+          throw new Exception(s"No matched constructor found for $cls")
+        }
+
+        val obj = ctors(index.get).newInstance(args : _*)
 
         writeInt(dos, 0)
         writeObject(dos, obj.asInstanceOf[AnyRef])
@@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend)
 
   // Read a number of arguments from the data input stream
   def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
-    (0 until numArgs).map { arg =>
+    (0 until numArgs).map { _ =>
       readObject(dis)
     }.toArray
   }
 
-  // Checks if the arguments passed in args matches the parameter types.
-  // NOTE: Currently we do exact match. We may add type conversions later.
-  def matchMethod(
-      numArgs: Int,
-      args: Array[java.lang.Object],
-      parameterTypes: Array[Class[_]]): Boolean = {
-    if (parameterTypes.length != numArgs) {
-      return false
-    }
+  // Find a matching method signature in an array of signatures of constructors
+  // or methods of the same name according to the passed arguments. Arguments
+  // may be converted in order to match a signature.
+  //
+  // Note that in Java reflection, constructors and normal methods are of different
+  // classes, and share no parent class that provides methods for reflection uses.
+  // There is no unified way to handle them in this function. So an array of signatures
+  // is passed in instead of an array of candidate constructors or methods.
+  //
+  // Returns an Option[Int] which is the index of the matched signature in the array.
+  def findMatchedSignature(
+      parameterTypesOfMethods: Array[Array[Class[_]]],
+      args: Array[Object]): Option[Int] = {
+    val numArgs = args.length
+
+    for (index <- 0 until parameterTypesOfMethods.length) {
+      val parameterTypes = parameterTypesOfMethods(index)
+
+      if (parameterTypes.length == numArgs) {
+        var argMatched = true
+        var i = 0
+        while (i < numArgs && argMatched) {
+          val parameterType = parameterTypes(i)
+
+          if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
+            // The case that the parameter type is a Scala Seq and the argument
+            // is a Java array is considered matching. The array will be converted
+            // to a Seq later if this method is matched.
+          } else {
+            var parameterWrapperType = parameterType
+
+            // Convert native parameters to Object types as args is Array[Object] here
+            if (parameterType.isPrimitive) {
+              parameterWrapperType = parameterType match {
+                case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+                case java.lang.Long.TYPE => classOf[java.lang.Integer]
+                case java.lang.Double.TYPE => classOf[java.lang.Double]
+                case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+                case _ => parameterType
+              }
+            }
+            if (!parameterWrapperType.isInstance(args(i))) {
+              argMatched = false
+            }
+          }
 
-    for (i <- 0 to numArgs - 1) {
-      val parameterType = parameterTypes(i)
-      var parameterWrapperType = parameterType
-
-      // Convert native parameters to Object types as args is Array[Object] here
-      if (parameterType.isPrimitive) {
-        parameterWrapperType = parameterType match {
-          case java.lang.Integer.TYPE => classOf[java.lang.Integer]
-          case java.lang.Long.TYPE => classOf[java.lang.Integer]
-          case java.lang.Double.TYPE => classOf[java.lang.Double]
-          case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
-          case _ => parameterType
+          i = i + 1
+        }
+
+        if (argMatched) {
+          // For now, we return the first matching method.
+          // TODO: find best method in matching methods.
+
+          // Convert args if needed
+          val parameterTypes = parameterTypesOfMethods(index)
+
+          (0 until numArgs).map { i =>
+            if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
+              // Convert a Java array to scala Seq
+              args(i) = args(i).asInstanceOf[Array[_]].toSeq
+            }
+          }
+
+          return Some(index)
         }
-      }
-      if (!parameterWrapperType.isInstance(args(i))) {
-        return false
       }
     }
-    true
+    None
   }
 }
 
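
As an aside, the R-visible effect of the matching rule added in findMatchedSignature()
above is that an R list can be passed directly to a JVM method whose parameter is a
Scala Seq (such as the String* varargs of selectExpr), with no listToSeq() wrapper.
A hypothetical sketch, where df stands for any existing DataFrame:

  # The R list arrives in the JVM as a Java Object[]; the backend matches it against
  # the Seq parameter and converts it with toSeq before invoking the method.
  sdf <- callJMethod(df@sdf, "selectExpr", list("name", "age * 2"))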

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 190e193..3c92bb7 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream}
 import java.sql.{Timestamp, Date, Time}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.WrappedArray
 
 /**
  * Utility functions to serialize, deserialize objects to / from R
@@ -213,89 +214,97 @@ private[spark] object SerDe {
     }
   }
 
-  def writeObject(dos: DataOutputStream, value: Object): Unit = {
-    if (value == null) {
+  def writeObject(dos: DataOutputStream, obj: Object): Unit = {
+    if (obj == null) {
       writeType(dos, "void")
     } else {
-      value.getClass.getName match {
-        case "java.lang.Character" =>
+      // Convert ArrayType collected from DataFrame to Java array
+      // Collected data of ArrayType from a DataFrame is observed to be of
+      // type "scala.collection.mutable.WrappedArray"
+      val value =
+        if (obj.isInstanceOf[WrappedArray[_]]) {
+          obj.asInstanceOf[WrappedArray[_]].toArray
+        } else {
+          obj
+        }
+
+      value match {
+        case v: java.lang.Character =>
           writeType(dos, "character")
-          writeString(dos, value.asInstanceOf[Character].toString)
-        case "java.lang.String" =>
+          writeString(dos, v.toString)
+        case v: java.lang.String =>
           writeType(dos, "character")
-          writeString(dos, value.asInstanceOf[String])
-        case "java.lang.Long" =>
+          writeString(dos, v)
+        case v: java.lang.Long =>
           writeType(dos, "double")
-          writeDouble(dos, value.asInstanceOf[Long].toDouble)
-        case "java.lang.Float" =>
+          writeDouble(dos, v.toDouble)
+        case v: java.lang.Float =>
           writeType(dos, "double")
-          writeDouble(dos, value.asInstanceOf[Float].toDouble)
-        case "java.math.BigDecimal" =>
+          writeDouble(dos, v.toDouble)
+        case v: java.math.BigDecimal =>
           writeType(dos, "double")
-          val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
-          writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
-        case "java.lang.Double" =>
+          writeDouble(dos, scala.math.BigDecimal(v).toDouble)
+        case v: java.lang.Double =>
           writeType(dos, "double")
-          writeDouble(dos, value.asInstanceOf[Double])
-        case "java.lang.Byte" =>
+          writeDouble(dos, v)
+        case v: java.lang.Byte =>
           writeType(dos, "integer")
-          writeInt(dos, value.asInstanceOf[Byte].toInt)
-        case "java.lang.Short" =>
+          writeInt(dos, v.toInt)
+        case v: java.lang.Short =>
           writeType(dos, "integer")
-          writeInt(dos, value.asInstanceOf[Short].toInt)
-        case "java.lang.Integer" =>
+          writeInt(dos, v.toInt)
+        case v: java.lang.Integer =>
           writeType(dos, "integer")
-          writeInt(dos, value.asInstanceOf[Int])
-        case "java.lang.Boolean" =>
+          writeInt(dos, v)
+        case v: java.lang.Boolean =>
           writeType(dos, "logical")
-          writeBoolean(dos, value.asInstanceOf[Boolean])
-        case "java.sql.Date" =>
+          writeBoolean(dos, v)
+        case v: java.sql.Date =>
           writeType(dos, "date")
-          writeDate(dos, value.asInstanceOf[Date])
-        case "java.sql.Time" =>
+          writeDate(dos, v)
+        case v: java.sql.Time =>
           writeType(dos, "time")
-          writeTime(dos, value.asInstanceOf[Time])
-        case "java.sql.Timestamp" =>
+          writeTime(dos, v)
+        case v: java.sql.Timestamp =>
           writeType(dos, "time")
-          writeTime(dos, value.asInstanceOf[Timestamp])
+          writeTime(dos, v)
 
         // Handle arrays
 
         // Array of primitive types
 
         // Special handling for byte array
-        case "[B" =>
+        case v: Array[Byte] =>
           writeType(dos, "raw")
-          writeBytes(dos, value.asInstanceOf[Array[Byte]])
+          writeBytes(dos, v)
 
-        case "[C" =>
+        case v: Array[Char] =>
           writeType(dos, "array")
-          writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
-        case "[S" =>
+          writeStringArr(dos, v.map(_.toString))
+        case v: Array[Short] =>
           writeType(dos, "array")
-          writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
-        case "[I" =>
+          writeIntArr(dos, v.map(_.toInt))
+        case v: Array[Int] =>
           writeType(dos, "array")
-          writeIntArr(dos, value.asInstanceOf[Array[Int]])
-        case "[J" =>
+          writeIntArr(dos, v)
+        case v: Array[Long] =>
           writeType(dos, "array")
-          writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
-        case "[F" =>
+          writeDoubleArr(dos, v.map(_.toDouble))
+        case v: Array[Float] =>
           writeType(dos, "array")
-          writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
-        case "[D" =>
+          writeDoubleArr(dos, v.map(_.toDouble))
+        case v: Array[Double] =>
           writeType(dos, "array")
-          writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
-        case "[Z" =>
+          writeDoubleArr(dos, v)
+        case v: Array[Boolean] =>
           writeType(dos, "array")
-          writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+          writeBooleanArr(dos, v)
 
         // Array of objects, null objects use "void" type
-        case c if c.startsWith("[") =>
+        case v: Array[Object] =>
           writeType(dos, "list")
-          val array = value.asInstanceOf[Array[Object]]
-          writeInt(dos, array.length)
-          array.foreach(elem => writeObject(dos, elem))
+          writeInt(dos, v.length)
+          v.foreach(elem => writeObject(dos, elem))
 
         case _ =>
           writeType(dos, "jobj")

http://git-wip-us.apache.org/repos/asf/spark/blob/45e3be5c/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 7f3defe..d4b834a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpres
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
 
+import scala.util.matching.Regex
+
 private[r] object SQLUtils {
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
     new SQLContext(jsc)
@@ -35,14 +37,15 @@ private[r] object SQLUtils {
     new JavaSparkContext(sqlCtx.sparkContext)
   }
 
-  def toSeq[T](arr: Array[T]): Seq[T] = {
-    arr.toSeq
-  }
-
   def createStructType(fields : Seq[StructField]): StructType = {
     StructType(fields)
   }
 
+  // Support using regex in string interpolation
+  private[this] implicit class RegexContext(sc: StringContext) {
+    def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
+  }
+
   def getSQLDataType(dataType: String): DataType = {
     dataType match {
       case "byte" => org.apache.spark.sql.types.ByteType
@@ -58,6 +61,9 @@ private[r] object SQLUtils {
       case "boolean" => org.apache.spark.sql.types.BooleanType
       case "timestamp" => org.apache.spark.sql.types.TimestampType
       case "date" => org.apache.spark.sql.types.DateType
+      case r"\Aarray<(.*)${elemType}>\Z" => {
+        org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
+      }
       case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
     }
   }
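
As an aside, an end-to-end sketch of the new "array<...>" handling in getSQLDataType
from the R side. This is hypothetical usage (names are made up) and assumes an explicit
structType schema follows the same path as the inferred schema exercised in the tests:

  schema <- structType(structField("tags", "array<string>"))
  df <- createDataFrame(sqlContext, list(list(list("a", "b"))), schema)
  dtypes(df)   # list(c("tags", "array<string>"))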

