This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0b9ed26e482 [SPARK-42444][PYTHON] `DataFrame.drop` should handle duplicated columns properly 0b9ed26e482 is described below commit 0b9ed26e48248aa58642b3626a02dd8c89a01afb Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Feb 24 08:03:06 2023 +0800 [SPARK-42444][PYTHON] `DataFrame.drop` should handle duplicated columns properly ### What changes were proposed in this pull request? The existing implementation always converts the inputs (either a column or a column name) to columns; this causes an `AMBIGUOUS_REFERENCE` issue, since there may be several columns with the same name. On the JVM side, the logic of drop(column: Column) and drop(columnName: String) differs, so we cannot simply always convert a column name to a column via the col() method. When there are multiple columns with the same name (e.g., `name`), users can: 1, `drop('name')` --- drop all the columns; 2, `drop(df1.name)` --- drop the column from the specific dataframe `df1`; But if users call `drop(col('name'))`, it will fail due to the ambiguity issue. In PySpark, it is a bit more complex, in that the user can pass in both column names and columns. This PR drops the columns first, and then the column names. ### Why are the changes needed? bug fix ``` >>> from pyspark.sql import Row >>> df1 = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df3 = df1.join(df2, df1.name == df2.name, 'inner') >>> df3.show() +---+----+------+----+ |age|name|height|name| +---+----+------+----+ | 16| Bob| 85| Bob| | 14| Tom| 80| Tom| +---+----+------+----+ ``` BEFORE ``` >>> df3.drop("name", "age").columns Traceback (most recent call last): ... pyspark.errors.exceptions.captured.AnalysisException: [AMBIGUOUS_REFERENCE] Reference `name` is ambiguous, could be: [`name`, `name`]. 
``` AFTER ``` >>> df3.drop("name", "age").columns ['height'] ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added tests Closes #40135 from zhengruifeng/py_fix_drop. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/dataframe.py | 27 +++++++++++++--------- .../sql/tests/connect/test_parity_dataframe.py | 5 ++++ python/pyspark/sql/tests/test_dataframe.py | 9 ++++++++ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fa25d148060..1cd28f0e8b2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -4923,21 +4923,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): | 14| 80| +---+------+ """ - if len(cols) == 1: - col = cols[0] - if isinstance(col, str): - jdf = self._jdf.drop(col) - elif isinstance(col, Column): - jdf = self._jdf.drop(col._jc) + column_names: List[str] = [] + java_columns: List[JavaObject] = [] + + for c in cols: + if isinstance(c, str): + column_names.append(c) + elif isinstance(c, Column): + java_columns.append(c._jc) else: raise PySparkTypeError( error_class="NOT_COLUMN_OR_STR", - message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, + message_parameters={"arg_name": "col", "arg_type": type(c).__name__}, ) - else: - jcols = [_to_java_column(c) for c in cols] - first_column, *remaining_columns = jcols - jdf = self._jdf.drop(first_column, self._jseq(remaining_columns)) + + jdf = self._jdf + if len(java_columns) > 0: + first_column, *remaining_columns = java_columns + jdf = jdf.drop(first_column, self._jseq(remaining_columns)) + if len(column_names) > 0: + jdf = jdf.drop(self._jseq(column_names)) return DataFrame(jdf, self.sparkSession) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 
07cae0fb27d..25fdbebd991 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -142,6 +142,11 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase): def test_to_pandas_with_duplicated_column_names(self): super().test_to_pandas_with_duplicated_column_names() + # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns properly + @unittest.skip("Fails in Spark Connect, should enable.") + def test_drop_duplicates_with_ambiguous_reference(self): + super().test_drop_duplicates_with_ambiguous_reference() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 1d52602a96f..610edc0926d 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -144,6 +144,15 @@ class DataFrameTestsMixin: message_parameters={"arg_name": "subset", "arg_type": "str"}, ) + def test_drop_duplicates_with_ambiguous_reference(self): + df1 = self.spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + df2 = self.spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) + df3 = df1.join(df2, df1.name == df2.name, "inner") + + self.assertEqual(df3.drop("name", "age").columns, ["height"]) + self.assertEqual(df3.drop("name", df3.age, "unknown").columns, ["height"]) + self.assertEqual(df3.drop("name", "age", df3.height).columns, []) + def test_dropna(self): schema = StructType( [ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org