This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 0b7385aa82b [SPARK-42367][CONNECT][PYTHON] `DataFrame.drop` should handle duplicated columns properly 0b7385aa82b is described below commit 0b7385aa82b1d08ee89d96f9241701b845c6b5f5 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Tue Feb 28 11:25:56 2023 +0800 [SPARK-42367][CONNECT][PYTHON] `DataFrame.drop` should handle duplicated columns properly ### What changes were proposed in this pull request? Match https://github.com/apache/spark/pull/40135 ### Why are the changes needed? `DataFrame.drop` should handle duplicated columns properly. We cannot always convert column names to columns when there are multiple columns with the same name. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Re-enabled the previously skipped tests. Closes #40013 from zhengruifeng/connect_drop_duplicate. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> (cherry picked from commit 3d900b70c5593326ddc96f094d9abe796308b0e4) Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 13 ++- .../main/protobuf/spark/connect/relations.proto | 9 +- .../query-tests/queries/drop_multiple_strings.json | 14 +-- .../queries/drop_multiple_strings.proto.bin | Bin 71 -> 59 bytes .../query-tests/queries/drop_single_string.json | 6 +- .../queries/drop_single_string.proto.bin | Bin 56 -> 52 bytes .../resources/query-tests/queries/unionByName.json | 12 +- .../query-tests/queries/unionByName.proto.bin | Bin 143 -> 135 bytes .../org/apache/spark/sql/connect/dsl/package.scala | 10 +- .../sql/connect/planner/SparkConnectPlanner.scala | 18 +-- python/pyspark/sql/connect/dataframe.py | 3 - python/pyspark/sql/connect/plan.py | 18 +-- python/pyspark/sql/connect/proto/relations_pb2.py | 126 ++++++++++----------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 23 ++-- python/pyspark/sql/dataframe.py | 4 +- 
.../pyspark/sql/tests/connect/test_connect_plan.py | 10 +- .../sql/tests/connect/test_parity_dataframe.py | 5 - 17 files changed, 122 insertions(+), 149 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 1015d61a9c2..c4f54e493ee 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2099,7 +2099,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def drop(colName: String): DataFrame = { - drop(functions.col(colName)) + drop(Seq(colName): _*) } /** @@ -2113,7 +2113,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ @scala.annotation.varargs - def drop(colNames: String*): DataFrame = buildDrop(colNames.map(functions.col)) + def drop(colNames: String*): DataFrame = buildDropByNames(colNames) /** * Returns a new Dataset with column dropped. @@ -2144,7 +2144,14 @@ class Dataset[T] private[sql] ( private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder => builder.getDropBuilder .setInput(plan.getRoot) - .addAllCols(cols.map(_.expr).asJava) + .addAllColumns(cols.map(_.expr).asJava) + } + + private def buildDropByNames(cols: Seq[String]): DataFrame = sparkSession.newDataFrame { + builder => + builder.getDropBuilder + .setInput(plan.getRoot) + .addAllColumnNames(cols.asJava) } /** diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 2221b4e3982..3e4a4daeb36 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -321,10 +321,11 @@ message Drop { // (Required) The input relation. Relation input = 1; - // (Required) columns to drop. 
- // - // Should contain at least 1 item. - repeated Expression cols = 2; + // (Optional) columns to drop. + repeated Expression columns = 2; + + // (Optional) names of columns to drop. + repeated string column_names = 3; } diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.json index 69c31ec5858..dcda09236f4 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.json @@ -11,18 +11,6 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "cols": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "id" - } - }, { - "unresolvedAttribute": { - "unparsedIdentifier": "a" - } - }, { - "unresolvedAttribute": { - "unparsedIdentifier": "b" - } - }] + "columnNames": ["id", "a", "b"] } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.proto.bin index c085f69fc54..e5be859b708 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_strings.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.json index 7e4c4e8feb1..8f849d0346d 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.json @@ -11,10 +11,6 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "cols": [{ - 
"unresolvedAttribute": { - "unparsedIdentifier": "a" - } - }] + "columnNames": ["a"] } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.proto.bin index 9d704de48c8..12013543c46 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_string.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unionByName.json b/connector/connect/common/src/test/resources/query-tests/queries/unionByName.json index 9244eb08790..181d681b7f1 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/unionByName.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/unionByName.json @@ -16,11 +16,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "cols": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "b" - } - }] + "columnNames": ["b"] } }, "rightInput": { @@ -36,11 +32,7 @@ "schema": "struct\u003ca:int,id:bigint,payload:binary\u003e" } }, - "cols": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "payload" - } - }] + "columnNames": ["payload"] } }, "setOpType": "SET_OP_TYPE_UNION", diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unionByName.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/unionByName.proto.bin index 64d9fb901d2..519fbc8edaa 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/unionByName.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/unionByName.proto.bin differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 4c1fbb877f4..840f43abf49 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -703,21 +703,13 @@ package object dsl { def drop(columns: String*): Relation = { assert(columns.nonEmpty) - val cols = columns.map(col => - Expression.newBuilder - .setUnresolvedAttribute( - Expression.UnresolvedAttribute.newBuilder - .setUnparsedIdentifier(col) - .build()) - .build()) - Relation .newBuilder() .setDrop( Drop .newBuilder() .setInput(logicalPlan) - .addAllCols(cols.asJava) + .addAllColumnNames(columns.toSeq.asJava) .build()) .build() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 887379ab80d..1925a41b916 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1361,16 +1361,16 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformDrop(rel: proto.Drop): LogicalPlan = { - assert(rel.getColsCount > 0, s"cols must contains at least 1 item!") - - val cols = rel.getColsList.asScala.toArray.map { expr => - Column(transformExpression(expr)) + var output = Dataset.ofRows(session, transformRelation(rel.getInput)) + if (rel.getColumnsCount > 0) { + val cols = rel.getColumnsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) + output = output.drop(cols.head, cols.tail: _*) } - - Dataset - .ofRows(session, transformRelation(rel.getInput)) - .drop(cols.head, cols.tail: _*) - .logicalPlan + if (rel.getColumnNamesCount > 0) { + val colNames = 
rel.getColumnNamesList.asScala.toSeq + output = output.drop(colNames: _*) + } + output.logicalPlan } private def transformAggregate(rel: proto.Aggregate): LogicalPlan = { diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index b2253c21b66..0d501b0bc4d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1716,9 +1716,6 @@ def _test() -> None: del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__ del pyspark.sql.connect.dataframe.DataFrame.repartitionByRange.__doc__ - # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns - del pyspark.sql.connect.dataframe.DataFrame.drop.__doc__ - # TODO(SPARK-41625): Support Structured Streaming del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__ diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 857cca64c6f..f82cf9167cb 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -607,23 +607,17 @@ class Drop(LogicalPlan): ) -> None: super().__init__(child) assert len(columns) > 0 and all(isinstance(c, (Column, str)) for c in columns) - self.columns = columns - - def _convert_to_expr( - self, col: Union[Column, str], session: "SparkConnectClient" - ) -> proto.Expression: - expr = proto.Expression() - if isinstance(col, Column): - expr.CopyFrom(col.to_plan(session)) - else: - expr.CopyFrom(self.unresolved_attr(col)) - return expr + self._columns = columns def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.drop.input.CopyFrom(self._child.plan(session)) - plan.drop.cols.extend([self._convert_to_expr(c, session) for c in self.columns]) + for c in self._columns: + if isinstance(c, Column): + plan.drop.columns.append(c.to_plan(session)) + else: + plan.drop.column_names.append(c) return plan diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py 
b/python/pyspark/sql/connect/proto/relations_pb2.py index c6d9616e44c..3b4573177a2 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] 
) @@ -690,66 +690,66 @@ if _descriptor._USE_C_DESCRIPTORS == False: _AGGREGATE_GROUPTYPE._serialized_end = 5236 _SORT._serialized_start = 5239 _SORT._serialized_end = 5399 - _DROP._serialized_start = 5401 - _DROP._serialized_end = 5501 - _DEDUPLICATE._serialized_start = 5504 - _DEDUPLICATE._serialized_end = 5675 - _LOCALRELATION._serialized_start = 5677 - _LOCALRELATION._serialized_end = 5766 - _SAMPLE._serialized_start = 5769 - _SAMPLE._serialized_end = 6042 - _RANGE._serialized_start = 6045 - _RANGE._serialized_end = 6190 - _SUBQUERYALIAS._serialized_start = 6192 - _SUBQUERYALIAS._serialized_end = 6306 - _REPARTITION._serialized_start = 6309 - _REPARTITION._serialized_end = 6451 - _SHOWSTRING._serialized_start = 6454 - _SHOWSTRING._serialized_end = 6596 - _STATSUMMARY._serialized_start = 6598 - _STATSUMMARY._serialized_end = 6690 - _STATDESCRIBE._serialized_start = 6692 - _STATDESCRIBE._serialized_end = 6773 - _STATCROSSTAB._serialized_start = 6775 - _STATCROSSTAB._serialized_end = 6876 - _STATCOV._serialized_start = 6878 - _STATCOV._serialized_end = 6974 - _STATCORR._serialized_start = 6977 - _STATCORR._serialized_end = 7114 - _STATAPPROXQUANTILE._serialized_start = 7117 - _STATAPPROXQUANTILE._serialized_end = 7281 - _STATFREQITEMS._serialized_start = 7283 - _STATFREQITEMS._serialized_end = 7408 - _STATSAMPLEBY._serialized_start = 7411 - _STATSAMPLEBY._serialized_end = 7720 - _STATSAMPLEBY_FRACTION._serialized_start = 7612 - _STATSAMPLEBY_FRACTION._serialized_end = 7711 - _NAFILL._serialized_start = 7723 - _NAFILL._serialized_end = 7857 - _NADROP._serialized_start = 7860 - _NADROP._serialized_end = 7994 - _NAREPLACE._serialized_start = 7997 - _NAREPLACE._serialized_end = 8293 - _NAREPLACE_REPLACEMENT._serialized_start = 8152 - _NAREPLACE_REPLACEMENT._serialized_end = 8293 - _TODF._serialized_start = 8295 - _TODF._serialized_end = 8383 - _WITHCOLUMNSRENAMED._serialized_start = 8386 - _WITHCOLUMNSRENAMED._serialized_end = 8625 - 
_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8558 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8625 - _WITHCOLUMNS._serialized_start = 8627 - _WITHCOLUMNS._serialized_end = 8746 - _HINT._serialized_start = 8749 - _HINT._serialized_end = 8881 - _UNPIVOT._serialized_start = 8884 - _UNPIVOT._serialized_end = 9211 - _UNPIVOT_VALUES._serialized_start = 9141 - _UNPIVOT_VALUES._serialized_end = 9200 - _TOSCHEMA._serialized_start = 9213 - _TOSCHEMA._serialized_end = 9319 - _REPARTITIONBYEXPRESSION._serialized_start = 9322 - _REPARTITIONBYEXPRESSION._serialized_end = 9525 - _FRAMEMAP._serialized_start = 9527 - _FRAMEMAP._serialized_end = 9652 + _DROP._serialized_start = 5402 + _DROP._serialized_end = 5543 + _DEDUPLICATE._serialized_start = 5546 + _DEDUPLICATE._serialized_end = 5717 + _LOCALRELATION._serialized_start = 5719 + _LOCALRELATION._serialized_end = 5808 + _SAMPLE._serialized_start = 5811 + _SAMPLE._serialized_end = 6084 + _RANGE._serialized_start = 6087 + _RANGE._serialized_end = 6232 + _SUBQUERYALIAS._serialized_start = 6234 + _SUBQUERYALIAS._serialized_end = 6348 + _REPARTITION._serialized_start = 6351 + _REPARTITION._serialized_end = 6493 + _SHOWSTRING._serialized_start = 6496 + _SHOWSTRING._serialized_end = 6638 + _STATSUMMARY._serialized_start = 6640 + _STATSUMMARY._serialized_end = 6732 + _STATDESCRIBE._serialized_start = 6734 + _STATDESCRIBE._serialized_end = 6815 + _STATCROSSTAB._serialized_start = 6817 + _STATCROSSTAB._serialized_end = 6918 + _STATCOV._serialized_start = 6920 + _STATCOV._serialized_end = 7016 + _STATCORR._serialized_start = 7019 + _STATCORR._serialized_end = 7156 + _STATAPPROXQUANTILE._serialized_start = 7159 + _STATAPPROXQUANTILE._serialized_end = 7323 + _STATFREQITEMS._serialized_start = 7325 + _STATFREQITEMS._serialized_end = 7450 + _STATSAMPLEBY._serialized_start = 7453 + _STATSAMPLEBY._serialized_end = 7762 + _STATSAMPLEBY_FRACTION._serialized_start = 7654 + _STATSAMPLEBY_FRACTION._serialized_end = 
7753 + _NAFILL._serialized_start = 7765 + _NAFILL._serialized_end = 7899 + _NADROP._serialized_start = 7902 + _NADROP._serialized_end = 8036 + _NAREPLACE._serialized_start = 8039 + _NAREPLACE._serialized_end = 8335 + _NAREPLACE_REPLACEMENT._serialized_start = 8194 + _NAREPLACE_REPLACEMENT._serialized_end = 8335 + _TODF._serialized_start = 8337 + _TODF._serialized_end = 8425 + _WITHCOLUMNSRENAMED._serialized_start = 8428 + _WITHCOLUMNSRENAMED._serialized_end = 8667 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8600 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8667 + _WITHCOLUMNS._serialized_start = 8669 + _WITHCOLUMNS._serialized_end = 8788 + _HINT._serialized_start = 8791 + _HINT._serialized_end = 8923 + _UNPIVOT._serialized_start = 8926 + _UNPIVOT._serialized_end = 9253 + _UNPIVOT_VALUES._serialized_start = 9183 + _UNPIVOT_VALUES._serialized_end = 9242 + _TOSCHEMA._serialized_start = 9255 + _TOSCHEMA._serialized_end = 9361 + _REPARTITIONBYEXPRESSION._serialized_start = 9364 + _REPARTITIONBYEXPRESSION._serialized_end = 9567 + _FRAMEMAP._serialized_start = 9569 + _FRAMEMAP._serialized_end = 9694 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 27fd07a192e..b60fd5a1a61 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1266,32 +1266,39 @@ class Drop(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor INPUT_FIELD_NUMBER: builtins.int - COLS_FIELD_NUMBER: builtins.int + COLUMNS_FIELD_NUMBER: builtins.int + COLUMN_NAMES_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: """(Required) The input relation.""" @property - def cols( + def columns( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ pyspark.sql.connect.proto.expressions_pb2.Expression ]: - """(Required) 
columns to drop. - - Should contain at least 1 item. - """ + """(Optional) columns to drop.""" + @property + def column_names( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Optional) names of columns to drop.""" def __init__( self, *, input: global___Relation | None = ..., - cols: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] + columns: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] | None = ..., + column_names: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal["input", b"input"] ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["cols", b"cols", "input", b"input"] + self, + field_name: typing_extensions.Literal[ + "column_names", b"column_names", "columns", b"columns", "input", b"input" + ], ) -> None: ... global___Drop = Drop diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1cd28f0e8b2..a6357a7c137 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -4915,12 +4915,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): Drop the column that joined both DataFrames on. 
- >>> df.join(df2, df.name == df2.name, 'inner').drop('name').show() + >>> df.join(df2, df.name == df2.name, 'inner').drop('name').sort('age').show() +---+------+ |age|height| +---+------+ - | 16| 85| | 14| 80| + | 16| 85| +---+------+ """ column_names: List[str] = [] diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index a152ae0e8c3..2de51189c4d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -443,14 +443,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture): plan = df.filter(df.col_name > 3).drop("col_a", "col_b")._plan.to_proto(self.connect) self.assertEqual( - [f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.cols], + plan.root.drop.column_names, ["col_a", "col_b"], ) plan = df.filter(df.col_name > 3).drop(df.col_x, "col_b")._plan.to_proto(self.connect) self.assertEqual( - [f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.cols], - ["col_x", "col_b"], + [f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.columns], + ["col_x"], + ) + self.assertEqual( + plan.root.drop.column_names, + ["col_b"], ) def test_deduplicate(self): diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 97c0f473ce8..da1172086cc 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -118,11 +118,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase): def test_to_pandas_with_duplicated_column_names(self): self.check_to_pandas_with_duplicated_column_names() - # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns properly - @unittest.skip("Fails in Spark Connect, should enable.") - def test_drop_duplicates_with_ambiguous_reference(self): - 
super().test_drop_duplicates_with_ambiguous_reference() - if __name__ == "__main__": import unittest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org