This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2c629020592  [SPARK-40311][SQL][PYTHON] Add `withColumnsRenamed` to scala and pyspark API
2c629020592 is described below

commit 2c6290205928521e8d7404bb9a9cbccff0d35674
Author: santosh <3813695+santosh-d3vp...@users.noreply.github.com>
AuthorDate: Thu Oct 6 00:29:07 2022 -0700

[SPARK-40311][SQL][PYTHON] Add `withColumnsRenamed` to scala and pyspark API

### What changes were proposed in this pull request?
This change adds the ability to rename multiple columns in a single call.

**Scala:**
```scala
withColumnsRenamed(colsMap: Map[String, String]): DataFrame
```
**Java:**
```java
withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame
```
**Python:**
```python
withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame"
```

### Why are the changes needed?
We have seen that the Catalyst optimizer struggles with large plans. In our setup, the largest contribution to these plans comes from `withColumnRenamed`, `drop`, and `withColumn` being called in a for loop by unsuspecting users. The `master` branch of Spark already handles multiple columns in `withColumns` and `drop`; the missing piece of the puzzle is `withColumnRenamed`. With a very large number of columns, either the JVM gets killed or a StackOverflowError occurs, so I am skipping those cases here and focusing the benchmark on a column count that works in both the old and the new implementation. The following example shows the performance impact with 100 columns:

**Old fashioned with 100 columns**
```python
import datetime

import numpy as np
import pandas as pd

num_rows = 2
num_columns = 100
data = np.zeros((num_rows, num_columns))
columns = map(str, range(num_columns))
raw = spark.createDataFrame(pd.DataFrame(data, columns=columns))

a = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
b = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
c = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
d = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
e = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
f = datetime.datetime.now()
for col in raw.columns:
    raw = raw.withColumnRenamed(col, f"prefix_{col}")
g = datetime.datetime.now()

g - a
# datetime.timedelta(seconds=12, microseconds=480021)
```

**New implementation with 100 columns**
```python
import datetime

import numpy as np
import pandas as pd

num_rows = 2
num_columns = 100
data = np.zeros((num_rows, num_columns))
columns = map(str, range(num_columns))
raw = spark.createDataFrame(pd.DataFrame(data, columns=columns))

a = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
b = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
c = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
d = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
e = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
f = datetime.datetime.now()
raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
g = datetime.datetime.now()

g - a
# datetime.timedelta(microseconds=210400)
```
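For reference, the equivalent batched call from the new Scala API looks like this. This is a minimal sketch, not part of the patch: it assumes a live `SparkSession` named `spark`, and the DataFrame construction and column names are illustrative.

```scala
// Minimal sketch (assumed setup): build a 100-column DataFrame and batch-rename it.
import org.apache.spark.sql.functions.col

val raw = spark.range(2).select((0 until 100).map(i => col("id").as(i.toString)): _*)
// One call renames every column at once, mirroring the Python benchmark above.
val renamed = raw.withColumnsRenamed(raw.columns.map(c => c -> s"prefix_$c").toMap)
renamed.columns.take(3)  // Array(prefix_0, prefix_1, prefix_2)
```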
### Does this PR introduce _any_ user-facing change?
Yes, it adds a method to efficiently rename multiple columns in a single batch.

### How was this patch tested?
Added unit tests.

Closes #37761 from santosh-d3vpl3x/master.

Lead-authored-by: santosh <3813695+santosh-d3vp...@users.noreply.github.com>
Co-authored-by: Santosh Pingale <3813695+santosh-d3vp...@users.noreply.github.com>
Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../source/reference/pyspark.sql/dataframe.rst    |  1 +
 python/pyspark/sql/dataframe.py                   | 40 +++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py        | 16 +++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala | 47 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala     | 51 ++++++++++++++++++++++
 5 files changed, 155 insertions(+)

diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index fdb79f72fc7..e647704158f 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -119,6 +119,7 @@ DataFrame
     DataFrame.withColumn
     DataFrame.withColumns
     DataFrame.withColumnRenamed
+    DataFrame.withColumnsRenamed
     DataFrame.withMetadata
     DataFrame.withWatermark
     DataFrame.write

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 23dfd4e7ec8..7c3cc92d393 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -4430,6 +4430,46 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)

+    def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` by renaming multiple columns.
+        This is a no-op if the schema doesn't contain the given column names.
+
+        .. versionadded:: 3.4.0
+           Added support for multiple columns renaming
+
+        Parameters
+        ----------
+        colsMap : dict
+            a dict of existing column names and corresponding desired column names.
+            Currently, only a single map is supported.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with renamed columns.
+
+        See Also
+        --------
+        :meth:`withColumnRenamed`
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
+        >>> df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3})
+        >>> df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'}).show()
+        +---+-----+----+----+
+        |age| name|age4|age5|
+        +---+-----+----+----+
+        |  2|Alice|   4|   5|
+        |  5|  Bob|   7|   8|
+        +---+-----+----+----+
+        """
+        if not isinstance(colsMap, dict):
+            raise TypeError("colsMap must be dict of existing column name and new column name.")
+
+        return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession)
+
     def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column with metadata.
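As the diff above shows, the Python method only validates the argument type and delegates to the JVM-side `withColumnsRenamed`. A minimal sketch of the missing-name semantics, shown from the Scala side (assuming a live `spark` session; the data and the `salary` key are illustrative):

```scala
// Minimal sketch (not part of the patch): map keys absent from the schema are ignored.
val df = spark.createDataFrame(Seq((2, "Alice"), (5, "Bob"))).toDF("age", "name")
// `salary` does not exist, so only `age` is renamed; no error is raised.
df.withColumnsRenamed(Map("age" -> "years", "salary" -> "pay")).columns
// -> Array(years, name)
```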
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d15ba442ab4..be5784114fb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -97,6 +97,22 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])

+    def test_with_columns_renamed(self):
+        df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
+
+        # rename both columns
+        renamed_df1 = df.withColumnsRenamed({"name": "naam", "age": "leeftijd"})
+        self.assertEqual(renamed_df1.columns, ["naam", "leeftijd"])
+
+        # rename one column with one missing name
+        renamed_df2 = df.withColumnsRenamed({"name": "naam", "address": "adres"})
+        self.assertEqual(renamed_df2.columns, ["naam", "age"])
+
+        # negative test for incorrect type
+        type_error_msg = "colsMap must be dict of existing column name and new column name."
+        with self.assertRaisesRegex(TypeError, type_error_msg):
+            df.withColumnsRenamed(("name", "x"))
+
     def test_drop_duplicates(self):
         # SPARK-36034 test that drop duplicates throws a type error when incorrect type provided
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 18aea40f556..6a07db71428 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2808,6 +2808,53 @@ class Dataset[T] private[sql](
     }
   }

+  /**
+   * (Scala-specific)
+   * Returns a new Dataset with columns renamed.
+   * This is a no-op if the schema doesn't contain the given column names.
+   *
+   * `colsMap` is a map of existing column name to new column name.
+   *
+   * @throws AnalysisException if there are duplicate names in the resulting projection
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @throws[AnalysisException]
+  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output: Seq[NamedExpression] = queryExecution.analyzed.output
+
+    val projectList = colsMap.foldLeft(output) {
+      case (attrs, (existingName, newName)) =>
+        attrs.map(attr =>
+          if (resolver(attr.name, existingName)) {
+            Alias(attr, newName)()
+          } else {
+            attr
+          }
+        )
+    }
+    SchemaUtils.checkColumnNameDuplication(
+      projectList.map(_.name),
+      "in given column names for withColumnsRenamed",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+    withPlan(Project(projectList, logicalPlan))
+  }
+
+  /**
+   * (Java-specific)
+   * Returns a new Dataset with columns renamed.
+   * This is a no-op if the schema doesn't contain the given column names.
+   *
+   * `colsMap` is a map of existing column name to new column name.
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame =
+    withColumnsRenamed(colsMap.asScala.toMap)
+
   /**
    * Returns a new Dataset by updating an existing column with metadata.
    *
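The fold above rewrites the analyzed output once and wraps it in a single `Project`, which is what keeps plans shallow compared to chained single renames. A minimal sketch of how the difference shows up in the plans (assuming a live `spark` session; column names are illustrative):

```scala
// Minimal sketch (not part of the patch): comparing plan shapes.
import org.apache.spark.sql.functions.col

val df = spark.range(2).select(col("id").as("a"), col("id").as("b"))
// Each withColumnRenamed call layers another Project onto the analyzed plan.
val chained = df.withColumnRenamed("a", "x").withColumnRenamed("b", "y")
// The batched variant emits a single Project covering both renames.
val batched = df.withColumnsRenamed(Map("a" -> "x", "b" -> "y"))
chained.explain(true)
batched.explain(true)
```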
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b29b5c2b341..0fcbbe6fa69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -895,6 +895,57 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }

+  test("SPARK-40311: withColumnsRenamed") {
+    val df = testData.toDF().withColumns(Seq("newCol1", "newCOL2"),
+      Seq(col("key") + 1, col("key") + 2))
+      .withColumnsRenamed(Map("newCol1" -> "renamed1", "newCol2" -> "renamed2"))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.columns === Array("key", "value", "renamed1", "renamed2"))
+  }
+
+  test("SPARK-40311: withColumnsRenamed case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL2"),
+        Seq(col("key") + 1, col("key") + 2))
+        .withColumnsRenamed(Map("newCol1" -> "renamed1", "newCol2" -> "renamed2"))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.columns === Array("key", "value", "renamed1", "newCOL2"))
+    }
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names simple") {
+    val e = intercept[AnalysisException] {
+      person.withColumnsRenamed(Map("id" -> "renamed", "name" -> "renamed"))
+    }
+    assert(e.getMessage.contains("Found duplicate column(s)"))
+    assert(e.getMessage.contains("in given column names for withColumnsRenamed:"))
+    assert(e.getMessage.contains("`renamed`"))
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names simple case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = person.withColumnsRenamed(Map("id" -> "renamed", "name" -> "Renamed"))
+      assert(df.columns === Array("renamed", "Renamed", "age"))
+    }
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names indirect") {
+    val e = intercept[AnalysisException] {
+      person.withColumnsRenamed(Map("id" -> "renamed1", "renamed1" -> "age"))
+    }
+    assert(e.getMessage.contains("Found duplicate column(s)"))
+    assert(e.getMessage.contains("in given column names for withColumnsRenamed:"))
+    assert(e.getMessage.contains("`age`"))
+  }
+
   test("SPARK-20384: Value class filter") {
     val df = spark.sparkContext
       .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org