This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a9c1189  [SPARK-34649][SQL][DOCS] org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name having a dot
a9c1189 is described below

commit a9c11896a5db3cd6844d5e444ad59e65d9441e7c
Author: Amandeep Sharma <happyama...@gmail.com>
AuthorDate: Tue Mar 9 11:47:01 2021 +0000

    [SPARK-34649][SQL][DOCS] org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name having a dot

    ### What changes were proposed in this pull request?
    Use resolved attributes instead of data-frame fields for replacing values.

    ### Why are the changes needed?
    `dataframe.na.replace()` does not work for a column having a dot in its name.

    ### Does this PR introduce _any_ user-facing change?
    None

    ### How was this patch tested?
    Added unit tests for the same.

    Closes #31769 from amandeep-sharma/master.

    Authored-by: Amandeep Sharma <happyama...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 docs/sql-migration-guide.md                        |  2 +
 .../apache/spark/sql/DataFrameNaFunctions.scala    | 42 ++++++++++++---------
 .../spark/sql/DataFrameNaFunctionsSuite.scala      | 43 ++++++++++++++++++++--
 3 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 0e96c6d..5551d56 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -66,6 +66,8 @@ license: |
 
   - In Spark 3.2, the output schema of `SHOW TBLPROPERTIES` becomes `key: string, value: string` whether you specify the table property key or not. In Spark 3.1 and earlier, the output schema of `SHOW TBLPROPERTIES` is `value: string` when you specify the table property key. To restore the old schema with the builtin catalog, you can set `spark.sql.legacy.keepCommandOutputSchema` to `true`.
 
   - In Spark 3.2, we support typed literals in the partition spec of INSERT and ADD/DROP/RENAME PARTITION. For example, `ADD PARTITION(dt = date'2020-01-01')` adds a partition with date value `2020-01-01`. In Spark 3.1 and earlier, the partition value will be parsed as string value `date '2020-01-01'`, which is an illegal date value, and we add a partition with null value at the end.
+
+  - In Spark 3.2, `DataFrameNaFunctions.replace()` no longer uses exact string match for the input column names, to match the SQL syntax and support qualified column names. An input column name having a dot in it (not nested) needs to be escaped with a backtick \`. Now it throws `AnalysisException` if the column is not found in the data frame schema, and `UnsupportedOperationException` if the input column name refers to a nested column. In Spark 3.1 and earlier, it used to ignore invali [...]
 
 ## Upgrading from Spark SQL 3.0 to 3.1
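Concretely, the new behavior reads as follows. This is a minimal, hypothetical spark-shell sketch (the session setup is assumed and not part of this patch; the column names and values mirror the tests added below):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")

    // A dot in a (non-nested) column name must now be escaped with backticks.
    df.na.replace("`Col.1`", Map("n/a" -> "unknown")).show()

    // A column missing from the schema now fails instead of being silently
    // ignored as in Spark 3.1 and earlier:
    // df.na.replace("aa", Map("n/a" -> "unknown"))  // throws AnalysisException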
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 308bb96..91905f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -327,9 +327,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
     if (col == "*") {
-      replace0(df.columns, replacement)
+      replace0(df.logicalPlan.output, replacement)
     } else {
-      replace0(Seq(col), replacement)
+      replace(Seq(col), replacement)
     }
   }
 
@@ -352,10 +352,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
+  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+    val attrs = cols.map { colName =>
+      // Check column name exists
+      val attr = df.resolve(colName) match {
+        case a: Attribute => a
+        case _ => throw new UnsupportedOperationException(
+          s"Nested field ${colName} is not supported.")
+      }
+      attr
+    }
+    replace0(attrs, replacement)
+  }
 
-  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
-    if (replacement.isEmpty || cols.isEmpty) {
+  private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]): DataFrame = {
+    if (replacement.isEmpty || attrs.isEmpty) {
       return df
     }
 
@@ -379,15 +390,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       case _: String => StringType
     }
 
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
-      if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
-        replaceCol(f, replacementMap)
-      } else if (f.dataType == targetColumnType && shouldReplace) {
-        replaceCol(f, replacementMap)
+    val output = df.queryExecution.analyzed.output
+    val projections = output.map { attr =>
+      if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
+        (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
+        replaceCol(attr, replacementMap)
       } else {
-        df.col(f.name)
+        Column(attr)
       }
     }
     df.select(projections : _*)
@@ -453,13 +462,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
-  private def replaceCol[K, V](col: StructField, replacementMap: Map[K, V]): Column = {
-    val keyExpr = df.col(col.name).expr
-    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+  private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]): Column = {
+    def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(Literal(source), buildExpr(target))
     }.toSeq
-    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
+    new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
   }
 
   private def convertToDouble(v: Any): Double = v match {
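For intuition, the rewritten replaceCol() keys a CaseKeyWhen on the resolved attribute, roughly CASE `Col.1` WHEN 'n/a' THEN 'unknown' ELSE `Col.1` END AS `Col.1`. A rough equivalent using only the public Column API (illustrative only, not the internal code; assumes the same hypothetical session and implicits as the sketch above):

    import org.apache.spark.sql.functions.{col, when}

    val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")

    // Hand-written version of the projection na.replace builds internally.
    val replaced = when(col("`Col.1`") === "n/a", "unknown")
      .otherwise(col("`Col.1`"))
      .as("Col.1")
    df.select(replaced, col("`Col.2`")).show()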
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 23c2349..20ae995 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -461,7 +461,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
@@ -469,7 +469,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for qualified-column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
@@ -477,7 +477,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+  test("SPARK-34417: test fillMap() for column without a dot in the name" +
     " and dataframe with another column having a dot in the name") {
     val na = "n/a"
     checkAnswer(
@@ -485,4 +485,41 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       .na.fill(Map("Col" -> na)),
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
+
+  test("SPARK-34649: replace value of a column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("`Col.1`", Map( "n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a qualified-column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2").as("testDf")
+        .na.replace("testDf.`Col.1`", Map( "n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a dataframe having dot in the all column names") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("*", Map( "n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a column not present in the dataframe") {
+    val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+    val exception = intercept[AnalysisException] {
+      df.na.replace("aa", Map( "n/a" -> "unknown"))
+    }
+    assert(exception.getMessage.equals("Cannot resolve column name \"aa\" among (Col.1, Col.2)"))
+  }
+
+  test("SPARK-34649: replace value of a nested column") {
+    val df = createDFWithNestedColumns
+    val exception = intercept[UnsupportedOperationException] {
+      df.na.replace("c1.c1-1", Map("b1" -> "a1"))
+    }
+    assert(exception.getMessage.equals("Nested field c1.c1-1 is not supported."))
+  }
 }
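One consequence of the replace0() condition above: a column is rewritten when its type equals the replacement map's target type, or when it is any NumericType and the target type is DoubleType. A hedged sketch, assuming numeric keys and values are widened to double internally, which this excerpt does not show in full (hypothetical data, same session as the earlier sketches):

    // The Int column "Col.2" qualifies via the NumericType/DoubleType branch,
    // so an Int-keyed map still applies to it.
    val nums = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
    nums.na.replace("`Col.2`", Map(0 -> -1)).show()
    // (Nested struct fields remain unsupported; see the last test above.)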
"unknown")) + } + assert(exception.getMessage.equals("Cannot resolve column name \"aa\" among (Col.1, Col.2)")) + } + + test("SPARK-34649: replace value of a nested column") { + val df = createDFWithNestedColumns + val exception = intercept[UnsupportedOperationException] { + df.na.replace("c1.c1-1", Map("b1" ->"a1")) + } + assert(exception.getMessage.equals("Nested field c1.c1-1 is not supported.")) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org