This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bc6f191  [SPARK-24779][R] Add map_concat / map_from_entries / an option in months_between UDF to disable rounding-off
bc6f191 is described below

commit bc6f19145192835cdfa4fc263b1c35b294c1e0ac
Author: Huaxin Gao <huax...@us.ibm.com>
AuthorDate: Thu Jan 31 19:38:32 2019 +0800

    [SPARK-24779][R] Add map_concat / map_from_entries / an option in months_between UDF to disable rounding-off
    
    ## What changes were proposed in this pull request?
    
    Add the R version of map_concat / map_from_entries / an option in months_between UDF to disable rounding-off
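    
    A minimal sketch of the new `months_between` option (adapted from the new test in test_sparkSQL.R; an active SparkR session is assumed):
    
    ```r
    df <- createDataFrame(list(list(a = as.Date("1997-02-28"),
                                    b = as.Date("1996-10-30"))))
    # By default the result is rounded off to 8 digits.
    head(select(df, months_between(df$a, df$b)))        # 3.93548387
    # Passing FALSE for the new optional argument disables the rounding-off.
    head(select(df, months_between(df$a, df$b, FALSE))) # 3.935483870967742
    ```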
    
    ## How was this patch tested?
    
    Add test in test_sparkSQL.R
    
    Closes #21835 from huaxingao/spark-24779.
    
    Authored-by: Huaxin Gao <huax...@us.ibm.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 R/pkg/NAMESPACE                       |  2 ++
 R/pkg/R/functions.R                   | 60 +++++++++++++++++++++++++++++++----
 R/pkg/R/generics.R                    | 10 +++++-
 R/pkg/tests/fulltests/test_sparkSQL.R | 22 +++++++++++++
 4 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index cfad20d..1dcad16 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -312,8 +312,10 @@ exportMethods("%<=>%",
               "lower",
               "lpad",
               "ltrim",
+              "map_concat",
               "map_entries",
               "map_from_arrays",
+              "map_from_entries",
               "map_keys",
               "map_values",
               "max",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 58fc410..8f425b1 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -80,6 +80,11 @@ NULL
 #'          \item \code{from_utc_timestamp}, \code{to_utc_timestamp}: time zone to use.
 #'          \item \code{next_day}: day of the week string.
 #'          }
+#' @param ... additional argument(s).
+#'          \itemize{
+#'          \item \code{months_between}, this contains an optional parameter to specify
+#'              whether the result is rounded off to 8 digits.
+#'          }
 #'
 #' @name column_datetime_diff_functions
 #' @rdname column_datetime_diff_functions
@@ -217,6 +222,7 @@ NULL
 #'              additional named properties to control how it is converted and accepts the
 #'              same options as the CSV data source.
 #'          \item \code{arrays_zip}, this contains additional Columns of arrays to be merged.
+#'          \item \code{map_concat}, this contains additional Columns of maps to be unioned.
 #'          }
 #' @name column_collection_functions
 #' @rdname column_collection_functions
@@ -229,7 +235,7 @@ NULL
 #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
 #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
 #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
-#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21)))
+#' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21)))
 #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
 #' head(tmp2)
 #' head(select(tmp, posexplode(tmp$v1)))
@@ -238,15 +244,21 @@ NULL
 #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
 #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
 #' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
-#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
+#' head(select(tmp3, element_at(tmp3$v3, "Valiant"), map_concat(tmp3$v3, tmp3$v3)))
 #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
 #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
 #' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5)))
 #' head(select(tmp4, array_union(tmp4$v4, tmp4$v5)))
-#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5)))
+#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5)))
 #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
 #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
-#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))}
+#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))
+#' tmp6 <- mutate(df, v7 = create_array(create_array(df$model, df$model)))
+#' head(select(tmp6, flatten(tmp6$v7)))
+#' tmp7 <- mutate(df, v8 = create_array(df$model, df$cyl), v9 = create_array(df$model, df$hp))
+#' head(select(tmp7, map_from_arrays(tmp7$v8, tmp7$v9)))
+#' tmp8 <- mutate(df, v10 = create_array(struct(df$model, df$cyl)))
+#' head(select(tmp8, map_from_entries(tmp8$v10)))}
 NULL
 
 #' Window functions for Column operations
@@ -2074,15 +2086,21 @@ setMethod("levenshtein", signature(y = "Column"),
 #' are on the same day of month, or both are the last day of month, time of day will be ignored.
 #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits.
 #'
+#' @param roundOff an optional parameter to specify whether the result is rounded off to 8 digits
 #' @rdname column_datetime_diff_functions
 #' @aliases months_between months_between,Column-method
 #' @note months_between since 1.5.0
 setMethod("months_between", signature(y = "Column"),
-          function(y, x) {
+          function(y, x, roundOff = NULL) {
             if (class(x) == "Column") {
               x <- x@jc
             }
-            jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x)
+            jc <- if (is.null(roundOff)) {
+              callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x)
+            } else {
+              callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x,
+                           as.logical(roundOff))
+            }
             column(jc)
           })
 
@@ -3449,6 +3467,23 @@ setMethod("flatten",
           })
 
 #' @details
+#' \code{map_concat}: Returns the union of all the given maps.
+#'
+#' @rdname column_collection_functions
+#' @aliases map_concat map_concat,Column-method
+#' @note map_concat since 3.0.0
+setMethod("map_concat",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(x, ...), function(arg) {
+              stopifnot(class(arg) == "Column")
+              arg@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_concat", jcols)
+            column(jc)
+          })
+
+#' @details
 #' \code{map_entries}: Returns an unordered array of all entries in the given map.
 #'
 #' @rdname column_collection_functions
@@ -3477,6 +3512,19 @@ setMethod("map_from_arrays",
          })
 
 #' @details
+#' \code{map_from_entries}: Returns a map created from the given array of entries.
+#'
+#' @rdname column_collection_functions
+#' @aliases map_from_entries map_from_entries,Column-method
+#' @note map_from_entries since 3.0.0
+setMethod("map_from_entries",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_from_entries", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{map_keys}: Returns an unordered array containing the keys of the map.
 #'
 #' @rdname column_collection_functions
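
For a quick look at the two new collection functions, here is a sketch adapted from the roxygen examples above (assuming the usual example preamble that adds a `model` column to `mtcars`):

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
tmp <- mutate(df, v1 = create_map(df$model, df$cyl))
# map_concat unions the given map columns into a single map.
head(select(tmp, map_concat(tmp$v1, tmp$v1)))
tmp2 <- mutate(df, v2 = create_array(struct(df$model, df$cyl)))
# map_from_entries builds a map from an array of (key, value) structs.
head(select(tmp2, map_from_entries(tmp2$v2)))
```
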
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 09d8171..fcb511e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1080,6 +1080,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })
 
 #' @rdname column_collection_functions
 #' @name NULL
+setGeneric("map_concat", function(x, ...) { standardGeneric("map_concat") })
+
+#' @rdname column_collection_functions
+#' @name NULL
 setGeneric("map_entries", function(x) { standardGeneric("map_entries") })
 
 #' @rdname column_collection_functions
@@ -1088,6 +1092,10 @@ setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays"
 
 #' @rdname column_collection_functions
 #' @name NULL
+setGeneric("map_from_entries", function(x) { 
standardGeneric("map_from_entries") })
+
+#' @rdname column_collection_functions
+#' @name NULL
 setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
 
 #' @rdname column_collection_functions
@@ -1113,7 +1121,7 @@ setGeneric("month", function(x) { standardGeneric("month") })
 
 #' @rdname column_datetime_diff_functions
 #' @name NULL
-setGeneric("months_between", function(y, x) { 
standardGeneric("months_between") })
+setGeneric("months_between", function(y, x, ...) { 
standardGeneric("months_between") })
 
 #' @rdname count
 setGeneric("n", function(x) { standardGeneric("n") })
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 93cb890..a5dde20 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1497,6 +1497,14 @@ test_that("column functions", {
   df5 <- createDataFrame(list(list(a = "010101")))
   expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")
 
+  # Test months_between()
+  df <- createDataFrame(list(list(a = as.Date("1997-02-28"),
+                                  b = as.Date("1996-10-30"))))
+  result1 <- collect(select(df, alias(months_between(df[[1]], df[[2]]), "month")))[[1]]
+  expect_equal(result1, 3.93548387)
+  result2 <- collect(select(df, alias(months_between(df[[1]], df[[2]], FALSE), "month")))[[1]]
+  expect_equal(result2, 3.935483870967742)
+
   # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse()
   df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
   result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
@@ -1542,6 +1550,13 @@ test_that("column functions", {
   expected_entries <- list(as.environment(list(x = 1, y = 2)))
   expect_equal(result, expected_entries)
 
+  # Test map_from_entries()
+  df <- createDataFrame(list(list(list(listToStruct(list(c1 = "x", c2 = 1L)),
+                                       listToStruct(list(c1 = "y", c2 = 2L))))))
+  result <- collect(select(df, map_from_entries(df[[1]])))[[1]]
+  expected_entries <- list(as.environment(list(x = 1L, y = 2L)))
+  expect_equal(result, expected_entries)
+
   # Test array_repeat()
   df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
   result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
@@ -1600,6 +1615,13 @@ test_that("column functions", {
   result <- collect(select(df, flatten(df[[1]])))[[1]]
   expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))
 
+  # Test map_concat()
+  df <- createDataFrame(list(list(map1 = as.environment(list(x = 1, y = 2)),
+                                  map2 = as.environment(list(a = 3, b = 4)))))
+  result <- collect(select(df, map_concat(df[[1]], df[[2]])))[[1]]
+  expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4)))
+  expect_equal(result, expected_entries)
+
   # Test map_entries(), map_keys(), map_values() and element_at()
   df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
   result <- collect(select(df, map_entries(df$map)))[[1]]
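
For reference, the expected values in the months_between() test follow from the documented 31-days-per-month convention: 1996-10-30 to 1997-02-28 is 4 whole months minus 2 days (28 - 30 over a 31-day month), i.e. 4 - 2/31 = 3.935483870967742, and the default rounding-off to 8 digits gives 3.93548387. A quick hand-check in plain R:

```r
# Hand-check of the months_between() test expectations
# (31-days-per-month convention; base R arithmetic, not SparkR).
print(4 - 2 / 31, digits = 16)            # 3.935483870967742 (roundOff = FALSE)
print(round(4 - 2 / 31, 8), digits = 16)  # 3.93548387        (default rounding-off)
```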

