Repository: spark Updated Branches: refs/heads/master 23256be0d -> 9e785079b
[SPARK-12235][SPARKR] Enhance mutate() to support replace existing columns. Make the behavior of mutate more consistent with that in dplyr, in addition to supporting replacement of existing columns. 1. Throw an error message when there are duplicated column names in the DataFrame being mutated. 2. When there are duplicated column names among the columns specified by arguments, the last column of the same name takes effect. Author: Sun Rui <rui....@intel.com> Closes #10220 from sun-rui/SPARK-12235. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9e785079 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9e785079 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9e785079 Branch: refs/heads/master Commit: 9e785079b6ed4ea691c3c14c762a7f73fb6254bf Parents: 23256be Author: Sun Rui <rui....@intel.com> Authored: Thu Apr 28 09:33:58 2016 -0700 Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu> Committed: Thu Apr 28 09:33:58 2016 -0700 ---------------------------------------------------------------------- R/pkg/R/DataFrame.R | 60 ++++++++++++++++++++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 ++++++++ 2 files changed, 69 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9e785079/R/pkg/R/DataFrame.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 48ac1b0..a741fdf 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1431,11 +1431,11 @@ setMethod("withColumn", #' Mutate #' -#' Return a new SparkDataFrame with the specified columns added. +#' Return a new SparkDataFrame with the specified columns added or replaced. #' #' @param .data A SparkDataFrame #' @param col a named argument of the form name = col -#' @return A new SparkDataFrame with the new columns added. 
+#' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions #' @rdname mutate #' @name mutate @@ -1450,23 +1450,65 @@ setMethod("withColumn", #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) +#' +#' df <- createDataFrame(sqlContext, +#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) +#' # Replace the "age" column +#' df1 <- mutate(df, age = df$age + 1L) #' } setMethod("mutate", signature(.data = "SparkDataFrame"), function(.data, ...) { x <- .data cols <- list(...) - stopifnot(length(cols) > 0) - stopifnot(class(cols[[1]]) == "Column") + if (length(cols) <= 0) { + return(x) + } + + lapply(cols, function(col) { + stopifnot(class(col) == "Column") + }) + + # Check if there is any duplicated column name in the DataFrame + dfCols <- columns(x) + if (length(unique(dfCols)) != length(dfCols)) { + stop("Error: found duplicated column name in the DataFrame") + } + + # TODO: simplify the implementation of this method after SPARK-12225 is resolved. 
+ + # For named arguments, use the names for arguments as the column names + # For unnamed arguments, use the argument symbols as the column names + args <- sapply(substitute(list(...))[-1], deparse) ns <- names(cols) if (!is.null(ns)) { - for (n in ns) { - if (n != "") { - cols[[n]] <- alias(cols[[n]], n) + lapply(seq_along(args), function(i) { + if (ns[[i]] != "") { + args[[i]] <<- ns[[i]] } - } + }) + } + ns <- args + + # The last column of the same name in the specific columns takes effect + deDupCols <- list() + for (i in 1:length(cols)) { + deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]]) } - do.call(select, c(x, x$"*", cols)) + + # Construct the column list for projection + colList <- lapply(dfCols, function(col) { + if (!is.null(deDupCols[[col]])) { + # Replace existing column + tmpCol <- deDupCols[[col]] + deDupCols[[col]] <<- NULL + tmpCol + } else { + col(col) + } + }) + + do.call(select, c(x, colList, deDupCols)) }) #' @export http://git-wip-us.apache.org/repos/asf/spark/blob/9e785079/R/pkg/inst/tests/testthat/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 95d6cb8..7058265 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(newDF)[3], "newAge") expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3, + age = df$age + 4, newAge = df$age + 5) + expect_equal(length(columns(newDF)), 3) + 
expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34) + + newDF <- mutate(df, df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[[3]], "df$age + 3") + expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33) + newDF2 <- rename(df, newerAge = df$age) expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org