This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 032e78297b0 [SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respect the dict ordering 032e78297b0 is described below commit 032e78297b02adb4266818776b55e09057705084 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Dec 6 17:16:07 2023 +0900 [SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respect the dict ordering ### What changes were proposed in this pull request? Make `DataFrame.withColumnsRenamed` respect the dict ordering. ### Why are the changes needed? The ordering in `withColumnsRenamed` matters in Scala, because the renames are applied one after another, so a later rename can see the result of an earlier one: ``` scala> val df = spark.range(1000) val df: org.apache.spark.sql.Dataset[Long] = [id: bigint] scala> df.withColumnsRenamed(Map("id" -> "a", "a" -> "b")) val res0: org.apache.spark.sql.DataFrame = [b: bigint] scala> df.withColumnsRenamed(Map("a" -> "b", "id" -> "a")) val res1: org.apache.spark.sql.DataFrame = [a: bigint] ``` However, in Py4J the conversion from a Python `dict` to a JVM `Map` cannot guarantee the ordering. ### Does this PR introduce _any_ user-facing change? Yes, this is a behavior change. Before this PR: ``` In [1]: df = spark.range(10) In [2]: df.withColumnsRenamed({"id": "a", "a": "b"}) Out[2]: DataFrame[a: bigint] In [3]: df.withColumnsRenamed({"a": "b", "id": "a"}) Out[3]: DataFrame[a: bigint] ``` After this PR: ``` In [1]: df = spark.range(10) In [2]: df.withColumnsRenamed({"id": "a", "a": "b"}) Out[2]: DataFrame[b: bigint] In [3]: df.withColumnsRenamed({"a": "b", "id": "a"}) Out[3]: DataFrame[a: bigint] ``` ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44177 from zhengruifeng/sql_withColumnsRenamed_sql. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/dataframe.py | 13 ++++++++++- .../sql/tests/connect/test_parity_dataframe.py | 5 ++++ python/pyspark/sql/tests/test_dataframe.py | 9 ++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 27 +++++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++++ 5 files changed, 52 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5211d874ba3..1419d1f3cb6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6272,7 +6272,18 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__}, ) - return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession) + col_names: List[str] = [] + new_col_names: List[str] = [] + for k, v in colsMap.items(): + col_names.append(k) + new_col_names.append(v) + + return DataFrame( + self._jdf.withColumnsRenamed( + _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names) + ), + self.sparkSession, + ) def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame": """Returns a new :class:`DataFrame` by updating an existing column with metadata. 
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index b7b4fdcd287..fbef282e0b9 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -77,6 +77,11 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase): def test_toDF_with_string(self): super().test_toDF_with_string() + # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering + @unittest.skip("Fails in Spark Connect, should enable.") + def test_ordering_of_with_columns_renamed(self): + super().test_ordering_of_with_columns_renamed() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 52806f4f4a3..c25fe60ad17 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -163,6 +163,15 @@ class DataFrameTestsMixin: message_parameters={"arg_name": "colsMap", "arg_type": "tuple"}, ) + def test_ordering_of_with_columns_renamed(self): + df = self.spark.range(10) + + df1 = df.withColumnsRenamed({"id": "a", "a": "b"}) + self.assertEqual(df1.columns, ["b"]) + + df2 = df.withColumnsRenamed({"a": "b", "id": "a"}) + self.assertEqual(df2.columns, ["a"]) + def test_drop_duplicates(self): # SPARK-36034 test that drop duplicates throws a type error when in correct type provided df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 293f20c453a..cacc193885d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2922,18 +2922,29 @@ class Dataset[T] private[sql]( */ @throws[AnalysisException] def withColumnsRenamed(colsMap: 
Map[String, String]): DataFrame = withOrigin { + val (colNames, newColNames) = colsMap.toSeq.unzip + withColumnsRenamed(colNames, newColNames) + } + + private def withColumnsRenamed( + colNames: Seq[String], + newColNames: Seq[String]): DataFrame = withOrigin { + require(colNames.size == newColNames.size, + s"The size of existing column names: ${colNames.size} isn't equal to " + + s"the size of new column names: ${newColNames.size}") + val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output - val projectList = colsMap.foldLeft(output) { + val projectList = colNames.zip(newColNames).foldLeft(output) { case (attrs, (existingName, newName)) => - attrs.map(attr => - if (resolver(attr.name, existingName)) { - Alias(attr, newName)() - } else { - attr - } - ) + attrs.map(attr => + if (resolver(attr.name, existingName)) { + Alias(attr, newName)() + } else { + attr + } + ) } SchemaUtils.checkColumnNameDuplication( projectList.map(_.name), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b732f6631a7..25ecefd28cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp} import java.util.{Locale, UUID} import java.util.concurrent.atomic.AtomicLong +import scala.collection.immutable.ListMap import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -987,6 +988,12 @@ class DataFrameSuite extends QueryTest parameters = Map("columnName" -> "`age`")) } + test("SPARK-46260: withColumnsRenamed should respect the Map ordering") { + val df = spark.range(10).toDF() + assert(df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b")).columns === Array("b")) + assert(df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a")).columns === Array("a")) + } + 
test("SPARK-20384: Value class filter") { val df = spark.sparkContext .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c"))) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org