Repository: spark
Updated Branches:
  refs/heads/master ee694cdff -> 70f1bcd7b


[SPARK-20493][R] De-duplicate parse logics for DDL-like type strings in R

## What changes were proposed in this pull request?

It seems we are using `SQLUtils.getSQLDataType` to parse the type string in 
`structField`. It looks like we can replace this with 
`CatalystSqlParser.parseDataType`.

Both accept similar DDL-like type definitions, as shown below:

```scala
scala> Seq(Tuple1(Tuple1("a"))).toDF.show()
```
```
+---+
| _1|
+---+
|[a]|
+---+
```

```scala
scala> 
Seq(Tuple1(Tuple1("a"))).toDF.select($"_1".cast("struct<_1:string>")).show()
```
```
+---+
| _1|
+---+
|[a]|
+---+
```

Such type strings look identical to the ones used on the R side, as shown below:

```R
> write.df(sql("SELECT named_struct('_1', 'a') as struct"), "/tmp/aa", 
> "parquet")
> collect(read.df("/tmp/aa", "parquet", structType(structField("struct", 
> "struct<_1:string>"))))
  struct
1      a
```

R’s handling is stricter because the types are checked ahead of time via regular 
expressions on the R side.

The actual logic there looks a bit different, but since we check the type strings 
ahead of time on the R side, replacing it should (I think) introduce no behaviour 
changes. To make sure of this, dedicated tests were added in SPARK-20105. (It 
looks like `structField` is the only place that calls this method.)

## How was this patch tested?

Existing tests - 
https://github.com/apache/spark/blob/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L143-L194
 should cover this.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #17785 from HyukjinKwon/SPARK-20493.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70f1bcd7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70f1bcd7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70f1bcd7

Branch: refs/heads/master
Commit: 70f1bcd7bcd42b30eabcf06a9639363f1ca4b449
Parents: ee694cd
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Sat Apr 29 11:02:17 2017 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Sat Apr 29 11:02:17 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/utils.R                                 |  8 ++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R       | 13 +++++-
 R/pkg/inst/tests/testthat/test_utils.R          |  6 +--
 .../org/apache/spark/sql/api/r/SQLUtils.scala   | 43 +-------------------
 4 files changed, 24 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70f1bcd7/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index fbc89e9..d29af00 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -864,6 +864,14 @@ captureJVMException <- function(e, method) {
     # Extract the first message of JVM exception.
     first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
     stop(paste0(rmsg, "no such table - ", first), call. = FALSE)
+  } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", 
stacktrace))) {
+    msg <- strsplit(stacktrace, 
"org.apache.spark.sql.catalyst.parser.ParseException: ",
+                    fixed = TRUE)[[1]]
+    # Extract "Error in ..." message.
+    rmsg <- msg[1]
+    # Extract the first message of JVM exception.
+    first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+    stop(paste0(rmsg, "parse error - ", first), call. = FALSE)
   } else {
     stop(stacktrace, call. = FALSE)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/70f1bcd7/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R 
b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 2cef719..1a3d6df 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -150,7 +150,12 @@ test_that("structField type strings", {
                          binary = "BinaryType",
                          boolean = "BooleanType",
                          timestamp = "TimestampType",
-                         date = "DateType")
+                         date = "DateType",
+                         tinyint = "ByteType",
+                         smallint = "ShortType",
+                         int = "IntegerType",
+                         bigint = "LongType",
+                         decimal = "DecimalType(10,0)")
 
   complexTypes <- list("map<string,integer>" = 
"MapType(StringType,IntegerType,true)",
                        "array<string>" = "ArrayType(StringType,true)",
@@ -174,7 +179,11 @@ test_that("structField type strings", {
                           numeric = "numeric",
                           character = "character",
                           raw = "raw",
-                          logical = "logical")
+                          logical = "logical",
+                          short = "short",
+                          varchar = "varchar",
+                          long = "long",
+                          char = "char")
 
   complexErrors <- list("map<string, integer>" = " integer",
                         "array<String>" = "String",

http://git-wip-us.apache.org/repos/asf/spark/blob/70f1bcd7/R/pkg/inst/tests/testthat/test_utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_utils.R 
b/R/pkg/inst/tests/testthat/test_utils.R
index 6d006ec..1ca383d 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -167,13 +167,13 @@ test_that("convertToJSaveMode", {
 })
 
 test_that("captureJVMException", {
-  method <- "getSQLDataType"
+  method <- "createStructField"
   expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
method,
-                                    "unknown"),
+                                    "col", "unknown", TRUE),
                         error = function(e) {
                           captureJVMException(e, method)
                         }),
-               "Error in getSQLDataType : illegal argument - Invalid type 
unknown")
+               "parse error - .*DataType unknown.*not supported.")
 })
 
 test_that("hashCode", {

http://git-wip-us.apache.org/repos/asf/spark/blob/70f1bcd7/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index a26d004..d94e528 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
@@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging {
     def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): 
_*)
   }
 
-  def getSQLDataType(dataType: String): DataType = {
-    dataType match {
-      case "byte" => org.apache.spark.sql.types.ByteType
-      case "integer" => org.apache.spark.sql.types.IntegerType
-      case "float" => org.apache.spark.sql.types.FloatType
-      case "double" => org.apache.spark.sql.types.DoubleType
-      case "numeric" => org.apache.spark.sql.types.DoubleType
-      case "character" => org.apache.spark.sql.types.StringType
-      case "string" => org.apache.spark.sql.types.StringType
-      case "binary" => org.apache.spark.sql.types.BinaryType
-      case "raw" => org.apache.spark.sql.types.BinaryType
-      case "logical" => org.apache.spark.sql.types.BooleanType
-      case "boolean" => org.apache.spark.sql.types.BooleanType
-      case "timestamp" => org.apache.spark.sql.types.TimestampType
-      case "date" => org.apache.spark.sql.types.DateType
-      case r"\Aarray<(.+)${elemType}>\Z" =>
-        org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
-      case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
-        if (keyType != "string" && keyType != "character") {
-          throw new IllegalArgumentException("Key type of a map must be string 
or character")
-        }
-        org.apache.spark.sql.types.MapType(getSQLDataType(keyType), 
getSQLDataType(valueType))
-      case r"\Astruct<(.+)${fieldsStr}>\Z" =>
-        if (fieldsStr(fieldsStr.length - 1) == ',') {
-          throw new IllegalArgumentException(s"Invalid type $dataType")
-        }
-        val fields = fieldsStr.split(",")
-        val structFields = fields.map { field =>
-          field match {
-            case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
-              createStructField(fieldName, fieldType, true)
-
-            case _ => throw new IllegalArgumentException(s"Invalid type 
$dataType")
-          }
-        }
-        createStructType(structFields)
-      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
-    }
-  }
-
   def createStructField(name: String, dataType: String, nullable: Boolean): 
StructField = {
-    val dtObj = getSQLDataType(dataType)
+    val dtObj = CatalystSqlParser.parseDataType(dataType)
     StructField(name, dtObj, nullable)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to