Repository: spark Updated Branches: refs/heads/master 1da5822e6 -> a8d9ec8a6
[SPARK-21780][R] Simpler Dataset.sample API in R

## What changes were proposed in this pull request?

This PR makes `sample(...)` able to omit `withReplacement`, which defaults to `FALSE`.

In short, the following examples are allowed:

```r
> df <- createDataFrame(as.list(seq(10)))
> count(sample(df, fraction=0.5, seed=3))
[1] 4
> count(sample(df, fraction=1.0))
[1] 10
```

In addition, this PR adds some type-checking logic, as shown below:

```r
> sample(df, fraction = "a")
Error in sample(df, fraction = "a") : fraction must be numeric; however, got character
> sample(df, fraction = 1, seed = NULL)
Error in sample(df, fraction = 1, seed = NULL) : seed must not be NULL or NA; however, got NULL
> sample(df, list(1), 1.0)
Error in sample(df, list(1), 1) : withReplacement must be logical; however, got list
> sample(df, fraction = -1.0)
...
Error in sample : illegal argument - requirement failed: Sampling fraction (-1.0) must be on interval [0, 1] without replacement
```

## How was this patch tested?

Manually tested; unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R`.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #19243 from HyukjinKwon/SPARK-21780.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8d9ec8a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8d9ec8a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8d9ec8a Branch: refs/heads/master Commit: a8d9ec8a60f21abb520b9109b238f914d2449022 Parents: 1da5822 Author: hyukjinkwon <gurwls...@gmail.com> Authored: Thu Sep 21 20:16:25 2017 +0900 Committer: hyukjinkwon <gurwls...@gmail.com> Committed: Thu Sep 21 20:16:25 2017 +0900 ---------------------------------------------------------------------- R/pkg/R/DataFrame.R | 40 ++++++++++++++++++++---------- R/pkg/R/generics.R | 4 +-- R/pkg/tests/fulltests/test_sparkSQL.R | 14 +++++++++++ 3 files changed, 43 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/DataFrame.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1b46c1e..0728141 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -986,10 +986,10 @@ setMethod("unique", #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @param seed Randomness seed value +#' @param seed Randomness seed value. Default is a random seed. 
#' #' @family SparkDataFrame functions -#' @aliases sample,SparkDataFrame,logical,numeric-method +#' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample #' @export @@ -998,33 +998,47 @@ setMethod("unique", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' collect(sample(df, fraction = 0.5)) #' collect(sample(df, FALSE, 0.5)) -#' collect(sample(df, TRUE, 0.5)) +#' collect(sample(df, TRUE, 0.5, seed = 3)) #'} #' @note sample since 1.4.0 setMethod("sample", - signature(x = "SparkDataFrame", withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { - if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { + if (!is.numeric(fraction)) { + stop(paste("fraction must be numeric; however, got", class(fraction))) + } + if (!is.logical(withReplacement)) { + stop(paste("withReplacement must be logical; however, got", class(withReplacement))) + } + if (!missing(seed)) { + if (is.null(seed)) { + stop("seed must not be NULL or NA; however, got NULL") + } + if (is.na(seed)) { + stop("seed must not be NULL or NA; however, got NA") + } + # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + sdf <- handledCallJMethod(x@sdf, "sample", as.logical(withReplacement), + as.numeric(fraction), as.integer(seed)) } else { - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + sdf <- handledCallJMethod(x@sdf, "sample", + as.logical(withReplacement), as.numeric(fraction)) } dataFrame(sdf) }) #' @rdname sample -#' @aliases sample_frac,SparkDataFrame,logical,numeric-method +#' @aliases sample_frac,SparkDataFrame-method #' @name sample_frac #' @note sample_frac since 1.4.0 setMethod("sample_frac", - signature(x = "SparkDataFrame", 
withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { sample(x, withReplacement, fraction, seed) }) http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/generics.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 603ff4e..0fe8f04 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -645,7 +645,7 @@ setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample #' @export setGeneric("sample", - function(x, withReplacement, fraction, seed) { + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) @@ -656,7 +656,7 @@ setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy #' @export http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/tests/fulltests/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 85a7e08..4d1010e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1116,6 +1116,20 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # Different arguments + df <- createDataFrame(as.list(seq(10))) + expect_equal(count(sample(df, fraction = 0.5, seed = 3)), 4) + expect_equal(count(sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)), 2) + expect_equal(count(sample(df, fraction = 1.0)), 10) + 
expect_equal(count(sample(df, fraction = 1L)), 10) + expect_equal(count(sample(df, FALSE, fraction = 1.0)), 10) + + expect_error(sample(df, fraction = "a"), "fraction must be numeric") + expect_error(sample(df, "a", fraction = 0.1), "however, got character") + expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA; however, got NA") + expect_error(sample(df, fraction = -1.0), + "illegal argument - requirement failed: Sampling fraction \\(-1.0\\)") + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org