Repository: spark Updated Branches: refs/heads/master c303b1b67 -> 87e8a572b
[SPARK-24054][R] Add array_position function / element_at functions ## What changes were proposed in this pull request? This PR proposes to add array_position and element_at on the R side as well. array_position: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_position(mutated$v1, 1))) ``` ``` array_position(v1, 1.0) 1 2 2 2 3 2 4 3 5 0 6 3 ``` element_at: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, element_at(mutated$v1, 1))) ``` ``` element_at(v1, 1.0) 1 21.0 2 21.0 3 22.8 4 21.4 5 18.7 6 18.1 ``` ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_map(df$model, df$cyl)) head(select(mutated, element_at(mutated$v1, "Valiant"))) ``` ``` element_at(v1, Valiant) 1 NA 2 NA 3 NA 4 NA 5 NA 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon <gurwls...@apache.org> Closes #21130 from HyukjinKwon/sparkr_array_position_element_at. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/87e8a572 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/87e8a572 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/87e8a572 Branch: refs/heads/master Commit: 87e8a572be14381da9081365d9aa2cbf3253a32c Parents: c303b1b Author: hyukjinkwon <gurwls...@apache.org> Authored: Tue Apr 24 16:18:20 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Tue Apr 24 16:18:20 2018 +0800 ---------------------------------------------------------------------- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 42 ++++++++++++++++++++++++++++-- R/pkg/R/generics.R | 8 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++-- 4 files changed, 61 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/NAMESPACE ---------------------------------------------------------------------- diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50e..55dec17 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_position", "asc", "ascii", "asin", @@ -245,6 +246,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/R/functions.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426..7b3aa05 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. 
+#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. @@ -201,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +214,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -2975,7 +2982,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2987,6 +2993,22 @@ setMethod("array_contains", }) #' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. 
+#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + +#' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' #' @rdname column_collection_functions @@ -3013,6 +3035,22 @@ setMethod("map_values", }) #' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + +#' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' #' @rdname column_collection_functions http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/R/generics.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff..f30ac9e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) http://git-wip-us.apache.org/repos/asf/spark/blob/87e8a572/R/pkg/tests/fulltests/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469..a384997 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,17 +1479,23 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_position(), element_at() and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 
5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1497,6 +1503,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org