Repository: spark
Updated Branches:
  refs/heads/master 23256be0d -> 9e785079b


[SPARK-12235][SPARKR] Enhance mutate() to support replacing existing columns.

Besides supporting the replacement of existing columns, make the behavior of
mutate more consistent with that in dplyr:
1. Throw an error when there are duplicated column names in the DataFrame
being mutated.
2. When the same column name is specified by more than one argument, the last
column with that name takes effect.

Author: Sun Rui <rui....@intel.com>

Closes #10220 from sun-rui/SPARK-12235.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9e785079
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9e785079
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9e785079

Branch: refs/heads/master
Commit: 9e785079b6ed4ea691c3c14c762a7f73fb6254bf
Parents: 23256be
Author: Sun Rui <rui....@intel.com>
Authored: Thu Apr 28 09:33:58 2016 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Thu Apr 28 09:33:58 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/DataFrame.R                       | 60 ++++++++++++++++++++++----
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 ++++++++
 2 files changed, 69 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9e785079/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 48ac1b0..a741fdf 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1431,11 +1431,11 @@ setMethod("withColumn",
 
 #' Mutate
 #'
-#' Return a new SparkDataFrame with the specified columns added.
+#' Return a new SparkDataFrame with the specified columns added or replaced.
 #'
 #' @param .data A SparkDataFrame
 #' @param col a named argument of the form name = col
-#' @return A new SparkDataFrame with the new columns added.
+#' @return A new SparkDataFrame with the new columns added or replaced.
 #' @family SparkDataFrame functions
 #' @rdname mutate
 #' @name mutate
@@ -1450,23 +1450,65 @@ setMethod("withColumn",
 #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
 #' names(newDF) # Will contain newCol, newCol2
 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2)
+#'
+#' df <- createDataFrame(sqlContext, 
+#'                       list(list("Andy", 30L), list("Justin", 19L)), c("name", "age"))
+#' # Replace the "age" column
+#' df1 <- mutate(df, age = df$age + 1L)
 #' }
 setMethod("mutate",
           signature(.data = "SparkDataFrame"),
           function(.data, ...) {
             x <- .data
             cols <- list(...)
-            stopifnot(length(cols) > 0)
-            stopifnot(class(cols[[1]]) == "Column")
+            if (length(cols) <= 0) {
+              return(x)
+            }
+
+            lapply(cols, function(col) {
+              stopifnot(class(col) == "Column")
+            })
+
+            # Check if there is any duplicated column name in the DataFrame
+            dfCols <- columns(x)
+            if (length(unique(dfCols)) != length(dfCols)) {
+              stop("Error: found duplicated column name in the DataFrame")
+            }
+
+            # TODO: simplify the implementation of this method after SPARK-12225 is resolved.
+
+            # For named arguments, use the names for arguments as the column names
+            # For unnamed arguments, use the argument symbols as the column names
+            args <- sapply(substitute(list(...))[-1], deparse)
             ns <- names(cols)
             if (!is.null(ns)) {
-              for (n in ns) {
-                if (n != "") {
-                  cols[[n]] <- alias(cols[[n]], n)
+              lapply(seq_along(args), function(i) {
+                if (ns[[i]] != "") {
+                  args[[i]] <<- ns[[i]]
                 }
-              }
+              })
+            }
+            ns <- args
+
+            # The last column of the same name in the specific columns takes effect
+            deDupCols <- list()
+            for (i in 1:length(cols)) {
+              deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]])
             }
-            do.call(select, c(x, x$"*", cols))
+
+            # Construct the column list for projection
+            colList <- lapply(dfCols, function(col) {
+              if (!is.null(deDupCols[[col]])) {
+                # Replace existing column
+                tmpCol <- deDupCols[[col]]
+                deDupCols[[col]] <<- NULL
+                tmpCol
+              } else {
+                col(col)
+              }
+            })
+
+            do.call(select, c(x, colList, deDupCols))
           })
 
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/9e785079/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 95d6cb8..7058265 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", {
   expect_equal(columns(newDF)[3], "newAge")
   expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32)
 
+  newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3)
+  expect_equal(length(columns(newDF)), 3)
+  expect_equal(columns(newDF)[3], "newAge")
+  expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33)
+  expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)
+
+  newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3,
+                  age = df$age + 4, newAge = df$age + 5)
+  expect_equal(length(columns(newDF)), 3)
+  expect_equal(columns(newDF)[3], "newAge")
+  expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35)
+  expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34)
+
+  newDF <- mutate(df, df$age + 3)
+  expect_equal(length(columns(newDF)), 3)
+  expect_equal(columns(newDF)[[3]], "df$age + 3")
+  expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33)
+
   newDF2 <- rename(df, newerAge = df$age)
   expect_equal(length(columns(newDF2)), 2)
   expect_equal(columns(newDF2)[1], "newerAge")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to