Repository: spark
Updated Branches:
  refs/heads/master 5dbaf3d39 -> 896edb51a


[SPARK-10050] [SPARKR] Support collecting data of MapType in DataFrame.

1. Support collecting data of MapType from DataFrame.
2. Support data of MapType in createDataFrame.

Author: Sun Rui <rui....@intel.com>

Closes #8711 from sun-rui/SPARK-10050.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/896edb51
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/896edb51
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/896edb51

Branch: refs/heads/master
Commit: 896edb51ab7a88bbb31259e565311a9be6f2ca6d
Parents: 5dbaf3d
Author: Sun Rui <rui....@intel.com>
Authored: Wed Sep 16 13:20:39 2015 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Wed Sep 16 13:20:39 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/SQLContext.R                            |  5 +-
 R/pkg/R/deserialize.R                           | 14 +++++
 R/pkg/R/schema.R                                | 34 +++++++++---
 R/pkg/inst/tests/test_sparkSQL.R                | 56 +++++++++++++++-----
 .../scala/org/apache/spark/api/r/SerDe.scala    | 31 +++++++++++
 .../org/apache/spark/sql/api/r/SQLUtils.scala   |  6 +++
 6 files changed, 123 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/R/pkg/R/SQLContext.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 4ac057d..1c58fd9 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -41,10 +41,7 @@ infer_type <- function(x) {
   if (type == "map") {
     stopifnot(length(x) > 0)
     key <- ls(x)[[1]]
-    list(type = "map",
-         keyType = "string",
-         valueType = infer_type(get(key, x)),
-         valueContainsNull = TRUE)
+    paste0("map<string,", infer_type(get(key, x)), ">")
   } else if (type == "array") {
     stopifnot(length(x) > 0)
     names <- names(x)
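
With this change, infer_type() returns the compact schema-string form for an R environment instead of a nested list. A quick illustration (a sketch assuming a running SparkR session; infer_type is an internal helper, reached here via SparkR:::):

  e <- new.env()
  assign("a", 1L, envir = e)
  SparkR:::infer_type(e)   # "map<string,integer>"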

http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d1858ec..ce88d0b 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -50,6 +50,7 @@ readTypedObject <- function(con, type) {
     "t" = readTime(con),
     "a" = readArray(con),
     "l" = readList(con),
+    "e" = readEnv(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
     stop(paste("Unsupported type for deserialization", type)))
@@ -121,6 +122,19 @@ readList <- function(con) {
   }
 }
 
+readEnv <- function(con) {
+  env <- new.env()
+  len <- readInt(con)
+  if (len > 0) {
+    for (i in 1:len) {
+      key <- readString(con)
+      value <- readObject(con)
+      env[[key]] <- value
+    }
+  }
+  env
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")
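
readEnv() materializes a serialized map as an R environment keyed by strings. Below is a standalone sketch of the same shape in plain base R (makeMapEnv is a hypothetical helper for illustration, not a SparkR internal):

  # Build an environment from parallel key/value collections,
  # the way readEnv does from the byte stream.
  makeMapEnv <- function(keys, values) {
    env <- new.env()
    for (i in seq_along(keys)) {
      env[[keys[[i]]]] <- values[[i]]
    }
    env
  }
  m <- makeMapEnv(c("age", "height"), list(16L, 176.5))
  m$age    # 16
  ls(m)    # "age" "height"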

http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/R/pkg/R/schema.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 62d4f73..8df1563 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -131,13 +131,33 @@ checkType <- function(type) {
   if (type %in% primtiveTypes) {
     return()
   } else {
-    m <- regexec("^array<(.*)>$", type)
-    matchedStrings <- regmatches(type, m)
-    if (length(matchedStrings[[1]]) >= 2) {
-      elemType <- matchedStrings[[1]][2]
-      checkType(elemType)
-      return()
-    }
+    # Check complex types
+    firstChar <- substr(type, 1, 1)
+    switch (firstChar,
+            a = {
+              # Array type
+              m <- regexec("^array<(.*)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 2) {
+                elemType <- matchedStrings[[1]][2]
+                checkType(elemType)
+                return()
+              }
+            },
+            m = {
+              # Map type
+              m <- regexec("^map<(.*),(.*)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 3) {
+                keyType <- matchedStrings[[1]][2]
+                if (keyType != "string" && keyType != "character") {
+                  stop("Key type in a map must be string or character")
+                }
+                valueType <- matchedStrings[[1]][3]
+                checkType(valueType)
+                return()
+              }
+            })
   }
 
   stop(paste("Unsupported type for Dataframe:", type))

http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 98d4402..e159a69 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -57,7 +57,7 @@ mockLinesComplexType <-
 complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
 writeLines(mockLinesComplexType, complexTypeJsonPath)
 
-test_that("infer types", {
+test_that("infer types and check types", {
   expect_equal(infer_type(1L), "integer")
   expect_equal(infer_type(1.0), "double")
   expect_equal(infer_type("abc"), "string")
@@ -72,9 +72,9 @@ test_that("infer types", {
   checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
   e <- new.env()
   assign("a", 1L, envir = e)
-  expect_equal(infer_type(e),
-               list(type = "map", keyType = "string", valueType = "integer",
-                    valueContainsNull = TRUE))
+  expect_equal(infer_type(e), "map<string,integer>")
+
+  expect_error(checkType("map<integer,integer>"), "Key type in a map must be string or character")
 })
 
 test_that("structType and structField", {
@@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", {
   expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
 })
 
-test_that("create DataFrame with nested array and struct", {
+test_that("create DataFrame with nested array and map", {
 #  e <- new.env()
 #  assign("n", 3L, envir = e)
 #  l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
@@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", {
 #  ldf <- collect(df)
 #  expect_equal(ldf[1,], l[[1]])
 
+  #  ArrayType and MapType
+  e <- new.env()
+  assign("n", 3L, envir = e)
 
-  #  ArrayType only for now
-  l <- list(as.list(1:10), list("a", "b"))
-  df <- createDataFrame(sqlContext, list(l), c("a", "b"))
-  expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
+  l <- list(as.list(1:10), list("a", "b"), e)
+  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c"))
+  expect_equal(dtypes(df), list(c("a", "array<int>"),
+                                c("b", "array<string>"),
+                                c("c", "map<string,int>")))
   expect_equal(count(df), 1)
   ldf <- collect(df)
-  expect_equal(names(ldf), c("a", "b"))
+  expect_equal(names(ldf), c("a", "b", "c"))
   expect_equal(ldf[1, 1][[1]], l[[1]])
   expect_equal(ldf[1, 2][[1]], l[[2]])
+  e <- ldf$c[[1]]
+  expect_equal(class(e), "environment")
+  expect_equal(ls(e), "n")
+  expect_equal(e$n, 3L)
 })
 
+# For test map type in DataFrame
+mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
+                      "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
+                      "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}")
+mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesMapType, mapTypeJsonPath)
+
 test_that("Collect DataFrame with complex types", {
-  # only ArrayType now
-  # TODO: tests for StructType and MapType after they are supported
+  # ArrayType
   df <- jsonFile(sqlContext, complexTypeJsonPath)
 
   ldf <- collect(df)
@@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", {
   expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9)))
   expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i")))
   expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0)))
+
+  # MapType
+  schema <- structType(structField("name", "string"),
+                       structField("info", "map<string,double>"))
+  df <- read.df(sqlContext, mapTypeJsonPath, "json", schema)
+  expect_equal(dtypes(df), list(c("name", "string"),
+                                c("info", "map<string,double>")))
+  ldf <- collect(df)
+  expect_equal(nrow(ldf), 3)
+  expect_equal(ncol(ldf), 2)
+  expect_equal(names(ldf), c("name", "info"))
+  expect_equal(ldf$name, c("Bob", "Alice", "David"))
+  bob <- ldf$info[[1]]
+  expect_equal(class(bob), "environment")
+  expect_equal(bob$age, 16)
+  expect_equal(bob$height, 176.5)
+
+  # TODO: tests for StructType after it is supported
 })
 
 test_that("jsonFile() on a local file returns a DataFrame", {

http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 3c92bb7..0c78613 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -209,11 +209,23 @@ private[spark] object SerDe {
       case "array" => dos.writeByte('a')
       // Array of objects
       case "list" => dos.writeByte('l')
+      case "map" => dos.writeByte('e')
       case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
     }
   }
 
+  private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = {
+    if (key == null) {
+      throw new IllegalArgumentException("Key in map can't be null.")
+    } else if (!key.isInstanceOf[String]) {
+      throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
+    }
+
+    writeString(dos, key.asInstanceOf[String])
+    writeObject(dos, value)
+  }
+
   def writeObject(dos: DataOutputStream, obj: Object): Unit = {
     if (obj == null) {
       writeType(dos, "void")
@@ -306,6 +318,25 @@ private[spark] object SerDe {
           writeInt(dos, v.length)
           v.foreach(elem => writeObject(dos, elem))
 
+        // Handle map
+        case v: java.util.Map[_, _] =>
+          writeType(dos, "map")
+          writeInt(dos, v.size)
+          val iter = v.entrySet.iterator
+          while(iter.hasNext) {
+            val entry = iter.next
+            val key = entry.getKey
+            val value = entry.getValue
+
+            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+          }
+        case v: scala.collection.Map[_, _] =>
+          writeType(dos, "map")
+          writeInt(dos, v.size)
+          v.foreach { case (key, value) =>
+            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+          }
+
         case _ =>
           writeType(dos, "jobj")
           writeJObj(dos, value)
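
On the wire, a map is tagged with the type byte 'e', followed by the entry count and then string-keyed entries, which is exactly what readEnv() consumes on the R side. A rough sketch of that layout in base R (illustrative only; the real stream also length-prefixes strings and type-tags each value per the SerDe protocol):

  con <- rawConnection(raw(0), "wb")
  writeBin(charToRaw("e"), con)       # type tag for map
  writeBin(1L, con, endian = "big")   # number of entries
  # ... then, per entry: the serialized key string followed by
  # ... the type-tagged serialized value ...
  close(con)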

http://git-wip-us.apache.org/repos/asf/spark/blob/896edb51/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d4b834a..f45d119 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -64,6 +64,12 @@ private[r] object SQLUtils {
       case r"\Aarray<(.*)${elemType}>\Z" => {
         org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
       }
+      case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => {
+        if (keyType != "string" && keyType != "character") {
+          throw new IllegalArgumentException("Key type of a map must be string or character")
+        }
+        org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
+      }
       case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
     }
   }
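
With getSQLDataType() extended, the user-facing schema builders accept the map form directly (a sketch assuming a SparkR session):

  f <- structField("info", "map<string,double>")
  schema <- structType(structField("name", "string"), f)
  # structField("bad", "map<integer,double>") would be rejected:
  # the key type of a map must be string or character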

