Repository: spark
Updated Branches:
  refs/heads/master c303b1b67 -> 87e8a572b


[SPARK-24054][R] Add array_position function / element_at functions

## What changes were proposed in this pull request?

This PR proposes to add array_position and element_at on the R side too.

array_position:

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb))
head(select(mutated, array_position(mutated$v1, 1)))
```

```
  array_position(v1, 1.0)
1                       2
2                       2
3                       2
4                       3
5                       0
6                       3
```

element_at:

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
head(select(mutated, element_at(mutated$v1, 1)))
```

```
  element_at(v1, 1.0)
1                21.0
2                21.0
3                22.8
4                21.4
5                18.7
6                18.1
```

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_map(df$model, df$cyl))
head(select(mutated, element_at(mutated$v1, "Valiant")))
```

```
  element_at(v1, Valiant)
1                      NA
2                      NA
3                      NA
4                      NA
5                      NA
6                       6
```

## How was this patch tested?

Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually 
tested. Documentation was manually built and verified.

Author: hyukjinkwon <gurwls...@apache.org>

Closes #21130 from HyukjinKwon/sparkr_array_position_element_at.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/87e8a572
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/87e8a572
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/87e8a572

Branch: refs/heads/master
Commit: 87e8a572be14381da9081365d9aa2cbf3253a32c
Parents: c303b1b
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Tue Apr 24 16:18:20 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Tue Apr 24 16:18:20 2018 +0800

----------------------------------------------------------------------
 R/pkg/NAMESPACE                       |  2 ++
 R/pkg/R/functions.R                   | 42 ++++++++++++++++++++++++++++--
 R/pkg/R/generics.R                    |  8 ++++++
 R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++--
 4 files changed, 61 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 190c50e..55dec17 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -201,6 +201,7 @@ exportMethods("%<=>%",
               "approxCountDistinct",
               "approxQuantile",
               "array_contains",
+              "array_position",
               "asc",
               "ascii",
               "asin",
@@ -245,6 +246,7 @@ exportMethods("%<=>%",
               "decode",
               "dense_rank",
               "desc",
+              "element_at",
               "encode",
               "endsWith",
               "exp",

http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index a527426..7b3aa05 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -189,6 +189,11 @@ NULL
 #'              the map or array of maps.
 #'          \item \code{from_json}: it is the column containing the JSON 
string.
 #'          }
+#' @param value A value to compute on.
+#'          \itemize{
+#'          \item \code{array_contains}: a value to be checked if contained in 
the column.
+#'          \item \code{array_position}: a value to locate in the given array.
+#'          }
 #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, 
this contains
 #'            additional named properties to control how it is converted, 
accepts the same
 #'            options as the JSON data source.
@@ -201,6 +206,7 @@ NULL
 #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
 #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
 #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
+#' head(select(tmp, array_position(tmp$v1, 21)))
 #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
 #' head(tmp2)
 #' head(select(tmp, posexplode(tmp$v1)))
@@ -208,7 +214,8 @@ NULL
 #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
 #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
 #' head(select(tmp3, map_keys(tmp3$v3)))
-#' head(select(tmp3, map_values(tmp3$v3)))}
+#' head(select(tmp3, map_values(tmp3$v3)))
+#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))}
 NULL
 
 #' Window functions for Column operations
@@ -2975,7 +2982,6 @@ setMethod("row_number",
 #' \code{array_contains}: Returns null if the array is null, true if the array 
contains
 #' the value, and false otherwise.
 #'
-#' @param value a value to be checked if contained in the column
 #' @rdname column_collection_functions
 #' @aliases array_contains array_contains,Column-method
 #' @note array_contains since 1.6.0
@@ -2987,6 +2993,22 @@ setMethod("array_contains",
           })
 
 #' @details
+#' \code{array_position}: Locates the position of the first occurrence of the 
given value
+#' in the given array. Returns NA if either of the arguments are NA.
+#' Note: The position is not zero based, but 1 based index. Returns 0 if the 
given
+#' value could not be found in the array.
+#'
+#' @rdname column_collection_functions
+#' @aliases array_position array_position,Column-method
+#' @note array_position since 2.4.0
+setMethod("array_position",
+          signature(x = "Column", value = "ANY"),
+          function(x, value) {
+            jc <- callJStatic("org.apache.spark.sql.functions", 
"array_position", x@jc, value)
+            column(jc)
+          })
+
+#' @details
 #' \code{map_keys}: Returns an unordered array containing the keys of the map.
 #'
 #' @rdname column_collection_functions
@@ -3013,6 +3035,22 @@ setMethod("map_values",
           })
 
 #' @details
+#' \code{element_at}: Returns element of array at given index in 
\code{extraction} if
+#' \code{x} is array. Returns value for the given key in \code{extraction} if 
\code{x} is map.
+#' Note: The position is not zero based, but 1 based index.
+#'
+#' @param extraction index to check for in array or key to check for in map
+#' @rdname column_collection_functions
+#' @aliases element_at element_at,Column-method
+#' @note element_at since 2.4.0
+setMethod("element_at",
+          signature(x = "Column", extraction = "ANY"),
+          function(x, extraction) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "element_at", 
x@jc, extraction)
+            column(jc)
+          })
+
+#' @details
 #' \code{explode}: Creates a new row for each element in the given array or 
map column.
 #'
 #' @rdname column_collection_functions

http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 974beff..f30ac9e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { 
standardGeneric("approxCoun
 #' @name NULL
 setGeneric("array_contains", function(x, value) { 
standardGeneric("array_contains") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("array_position", function(x, value) { 
standardGeneric("array_position") })
+
 #' @rdname column_string_functions
 #' @name NULL
 setGeneric("ascii", function(x) { standardGeneric("ascii") })
@@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { 
standardGeneric("decode") })
 #' @name NULL
 setGeneric("dense_rank", function(x = "missing") { 
standardGeneric("dense_rank") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("element_at", function(x, extraction) { 
standardGeneric("element_at") })
+
 #' @rdname column_string_functions
 #' @name NULL
 setGeneric("encode", function(x, charset) { standardGeneric("encode") })

http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R 
b/R/pkg/tests/fulltests/test_sparkSQL.R
index 7105469..a384997 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1479,17 +1479,23 @@ test_that("column functions", {
   df5 <- createDataFrame(list(list(a = "010101")))
   expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")
 
-  # Test array_contains() and sort_array()
+  # Test array_contains(), array_position(), element_at() and sort_array()
   df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
   result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
   expect_equal(result, c(TRUE, FALSE))
 
+  result <- collect(select(df, array_position(df[[1]], 1L)))[[1]]
+  expect_equal(result, c(1, 0))
+
+  result <- collect(select(df, element_at(df[[1]], 1L)))[[1]]
+  expect_equal(result, c(1, 6))
+
   result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
   expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L)))
   result <- collect(select(df, sort_array(df[[1]])))[[1]]
   expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L)))
 
-  # Test map_keys() and map_values()
+  # Test map_keys(), map_values() and element_at()
   df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
   result <- collect(select(df, map_keys(df$map)))[[1]]
   expect_equal(result, list(list("x", "y")))
@@ -1497,6 +1503,9 @@ test_that("column functions", {
   result <- collect(select(df, map_values(df$map)))[[1]]
   expect_equal(result, list(list(1, 2)))
 
+  result <- collect(select(df, element_at(df$map, "y")))[[1]]
+  expect_equal(result, 2)
+
   # Test that stats::lag is working
   expect_equal(length(lag(ldeaths, 12)), 72)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to