Repository: spark
Updated Branches:
  refs/heads/master 929404498 -> 68d1faa3c


[SPARK-6562][SQL] DataFrame.replace

Supports replacing values with other values in DataFrames.

Python support should be in a separate pull request.

Author: Reynold Xin <r...@databricks.com>

Closes #5282 from rxin/df-na-replace and squashes the following commits:

4b72434 [Reynold Xin] Removed println.
c8d9946 [Reynold Xin] col -> cols
fbb3c21 [Reynold Xin] [SPARK-6562][SQL] DataFrame.replace


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/68d1faa3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/68d1faa3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/68d1faa3

Branch: refs/heads/master
Commit: 68d1faa3c04e9412bbc2b60421dc12bd19c396b2
Parents: 9294044
Author: Reynold Xin <r...@databricks.com>
Authored: Sun Apr 12 22:56:12 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sun Apr 12 22:56:12 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/DataFrameNaFunctions.scala | 144 +++++++++++++++++++
 .../spark/sql/DataFrameNaFunctionsSuite.scala   |  34 +++++
 2 files changed, 178 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/68d1faa3/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index bf3c3fe..481ed49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    */
   def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
 
+  /**
+   * Replaces values matching keys in `replacement` map with the corresponding 
values.
+   * Key and value of `replacement` map must have the same type, and can only 
be doubles or strings.
+   * If `col` is "*", then the replacement is applied on all string columns or 
numeric columns.
+   *
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
+   *   df.na.replace("height", ImmutableMap.of(1.0, 2.0));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column 
"name".
+   *   df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string 
columns.
+   *   df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
+   * }}}
+   *
+   * @param col name of the column to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
+    replace[T](col, replacement.toMap : Map[T, T])  // forwards to the Scala Map overload
+  }
+
+  /**
+   * Replaces values matching keys in `replacement` map with the corresponding 
values.
+   * Key and value of `replacement` map must have the same type, and can only 
be doubles or strings.
+   *
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and 
"weight".
+   *   df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 
2.0));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column 
"firstname" and "lastname".
+   *   df.na.replace(new String[] {"firstname", "lastname"}, 
ImmutableMap.of("UNKNOWN", "unnamed"));
+   * }}}
+   *
+   * @param cols list of columns to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): 
DataFrame = {
+    replace(cols.toSeq, replacement.toMap)  // forwards to the Seq / Scala Map overload
+  }
+
+  /**
+   * (Scala-specific) Replaces values matching keys in `replacement` map.
+   * Key and value of `replacement` map must have the same type, and can only 
be doubles or strings.
+   * If `col` is "*", then the replacement is applied on all string columns or 
numeric columns.
+   *
+   * {{{
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
+   *   df.na.replace("height", Map(1.0 -> 2.0))
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column 
"name".
+   *   df.na.replace("name", Map("UNKNOWN" -> "unnamed"))
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string 
columns.
+   *   df.na.replace("*", Map("UNKNOWN" -> "unnamed"))
+   * }}}
+   *
+   * @param col name of the column to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
+    if (col == "*") {
+      replace0(df.columns, replacement)  // "*" means every column of df
+    } else {
+      replace0(Seq(col), replacement)  // a single named column
+    }
+  }
+
+  /**
+   * (Scala-specific) Replaces values matching keys in `replacement` map.
+   * Key and value of `replacement` map must have the same type, and can only 
be doubles or strings.
+   *
+   * {{{
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and 
"weight".
+   *   df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column 
"firstname" and "lastname".
+   *   df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> 
"unnamed"));
+   * }}}
+   *
+   * @param cols list of columns to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = 
replace0(cols, replacement)
+
+  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): 
DataFrame = {
+    if (replacement.isEmpty || cols.isEmpty) {
+      return df  // nothing to replace: hand back the original DataFrame
+    }
+
+    // Map[String, String] or Map[Double, Double], chosen from the first entry's value type
+    val replacementMap: Map[_, _] = replacement.head._2 match {
+      case v: String => replacement
+      case _ => replacement.map { case (k, v) => (convertToDouble(k), 
convertToDouble(v)) }
+    }
+
+    // DoubleType or StringType, chosen from the first key; other key types hit no case here
+    val targetColumnType = replacement.head._1 match {
+      case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => 
DoubleType
+      case _: String => StringType
+    }  // NOTE(review): no default case, so an unlisted key type would throw MatchError
+
+    val columnEquals = df.sqlContext.analyzer.resolver  // analyzer's column-name matcher
+    val projections = df.schema.fields.map { f =>
+      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
+      if (f.dataType.isInstanceOf[NumericType] && targetColumnType == 
DoubleType && shouldReplace) {
+        replaceCol(f, replacementMap)  // any numeric column accepts a numeric map
+      } else if (f.dataType == targetColumnType && shouldReplace) {
+        replaceCol(f, replacementMap)  // column type equals the map's target type
+      } else {
+        df.col(f.name)  // not requested, or type mismatch: column passes through
+      }
+    }
+    df.select(projections : _*)
+  }
+
   private def fill0(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
     values.foreach { case (colName, replaceValue) =>
@@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
   private def fillCol[T](col: StructField, replacement: T): Column = {
     coalesce(df.col(col.name), 
lit(replacement).cast(col.dataType)).as(col.name)
   }
+
+  /**
+   * Returns a [[Column]] expression that replaces value matching key in 
`replacementMap` with
+   * value in `replacementMap`, using [[CaseWhen]].
+   *
+   * TODO: This can be optimized to use broadcast join when replacementMap is 
large.
+   */
+  private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column 
= {
+    val branches: Seq[Expression] = replacementMap.flatMap { case (source, 
target) =>
+      df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::  // condition
+        lit(target).cast(col.dataType).expr :: Nil  // value, cast to the column's type
+    }.toSeq
+    new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)  // keeps name
+  }
+
+  private def convertToDouble(v: Any): Double = v match {  // widens numerics to Double
+    case v: Float => v.toDouble
+    case v: Double => v
+    case v: Long => v.toDouble
+    case v: Int => v.toDouble
+    case v => throw new IllegalArgumentException(  // every other type is rejected
+      s"Unsupported value type ${v.getClass.getName} ($v).")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/68d1faa3/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 0896f17..41b4f02 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -154,4 +154,38 @@ class DataFrameNaFunctionsSuite extends QueryTest {
       ))),
       Row("test", null, 1, 2.2))
   }
+
+  test("replace") {
+    val input = createDF()
+
+    // Replace two numeric columns: age and height
+    val out = input.na.replace(Seq("age", "height"), Map(
+      16 -> 61,
+      60 -> 6,
+      164.3 -> 461.3  // Alice is really tall
+    ))
+
+    checkAnswer(
+      out,
+      Row("Bob", 61, 176.5) ::
+        Row("Alice", null, 461.3) ::
+        Row("David", 6, null) ::
+        Row("Amy", null, null) ::
+        Row(null, null, null) :: Nil)
+
+    // Replace only the age column; height must stay untouched
+    val out1 = input.na.replace("age", Map(
+      16 -> 61,
+      60 -> 6,
+      164.3 -> 461.3  // must not apply: height is not a target column
+    ))
+
+    checkAnswer(
+      out1,
+      Row("Bob", 61, 176.5) ::
+        Row("Alice", null, 164.3) ::
+        Row("David", 6, null) ::
+        Row("Amy", null, null) ::
+        Row(null, null, null) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to