This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d9f0c44e7f24 [SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for DataFrame.drop

d9f0c44e7f24 is described below

commit d9f0c44e7f24cba95f7bf1737bb52ff73a7b9094
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Nov 14 12:09:37 2023 +0800

    [SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for DataFrame.drop

    ### What changes were proposed in this pull request?

    Backport https://github.com/apache/spark/pull/43683 to 3.5.

    ### Why are the changes needed?

    To fix a Spark Connect bug.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    CI.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43776 from zhengruifeng/sql_drop_plan_35.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_dataframe.py         | 37 ++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  1 +
 .../analysis/ResolveDataFrameDropColumns.scala     | 49 ++++++++++++++++++++++
 .../plans/logical/basicLogicalOperators.scala      | 14 +++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 15 +------
 6 files changed, 104 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 33049233dee9..5907c8c09fb4 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -106,6 +106,43 @@ class DataFrameTestsMixin:
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
 
+    def test_drop_join(self):
+        left_df = self.spark.createDataFrame(
+            [(1, "a"), (2, "b"), (3, "c")],
+            ["join_key", "value1"],
+        )
+        right_df = self.spark.createDataFrame(
+            [(1, "aa"), (2, "bb"), (4, "dd")],
+            ["join_key", "value2"],
+        )
+        joined_df = left_df.join(
+            right_df,
+            on=left_df["join_key"] == right_df["join_key"],
+            how="left",
+        )
+
+        dropped_1 = joined_df.drop(left_df["join_key"])
+        self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"])
+        self.assertEqual(
+            dropped_1.sort("value1").collect(),
+            [
+                Row(value1="a", join_key=1, value2="aa"),
+                Row(value1="b", join_key=2, value2="bb"),
+                Row(value1="c", join_key=None, value2=None),
+            ],
+        )
+
+        dropped_2 = joined_df.drop(right_df["join_key"])
+        self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"])
+        self.assertEqual(
+            dropped_2.sort("value1").collect(),
+            [
+                Row(join_key=1, value1="a", value2="aa"),
+                Row(join_key=2, value1="b", value2="bb"),
+                Row(join_key=3, value1="c", value2=None),
+            ],
+        )
+
     def test_with_columns_renamed(self):
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
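For reference, a rough spark-shell equivalent of the new Python test above.
This is a sketch only, assuming a plain spark-shell session (implicits in
scope); it is not part of the commit:

    val left = Seq((1, "a"), (2, "b"), (3, "c")).toDF("join_key", "value1")
    val right = Seq((1, "aa"), (2, "bb"), (4, "dd")).toDF("join_key", "value2")
    val joined = left.join(right, left("join_key") === right("join_key"), "left")

    // Dropping the left-side key keeps the right-side column of the same
    // name, as the new test asserts.
    joined.drop(left("join_key")).columns
    // expected: Array(value1, join_key, value2)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8e3c9b30c61b..80cb5d8c6087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -307,6 +307,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveWindowFrame ::
       ResolveNaturalAndUsingJoin ::
       ResolveOutputRelation ::
+      new ResolveDataFrameDropColumns(catalogManager) ::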
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
new file mode 100644
index 000000000000..2642b4a1c5da
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
+import org.apache.spark.sql.connector.catalog.CatalogManager
+
+/**
+ * A rule that rewrites DataFrameDropColumns to Project.
+ * Note that DataFrameDropColumns allows and ignores non-existing columns.
+ */
+class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
+  extends Rule[LogicalPlan] with ColumnResolutionHelper {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(DF_DROP_COLUMNS)) {
+    case d: DataFrameDropColumns if d.childrenResolved =>
+      // expressions in dropList can be unresolved, e.g.
+      //   df.drop(col("non-existing-column"))
+      val dropped = d.dropList.map {
+        case u: UnresolvedAttribute =>
+          resolveExpressionByPlanChildren(u, d.child)
+        case e => e
+      }
+      val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
+      if (remaining.size == d.child.output.size) {
+        d.child
+      } else {
+        Project(remaining, d.child)
+      }
+  }
+}
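To make the rule's behavior concrete, here is a minimal, self-contained
sketch of its core filtering step. Plain strings stand in for attributes and
expressions; the names are illustrative, not Spark API:

    // Sketch of the rewrite: resolve drop targets against the child's output,
    // silently ignore targets that do not resolve, and project the remainder.
    object DropColumnsSketch {
      def remaining(childOutput: Seq[String], dropList: Seq[String]): Seq[String] = {
        val dropped = dropList.filter(childOutput.contains)  // unresolvable names are ignored
        childOutput.filterNot(dropped.contains)              // becomes Project(remaining, child)
      }

      def main(args: Array[String]): Unit = {
        // Prints List(a, c): "b" is dropped, "missing" is tolerated,
        // mirroring df.drop(col("non-existing-column")) above.
        println(remaining(Seq("a", "b", "c"), Seq("b", "missing")))
      }
    }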
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 96b67fc52e0d..0e460706fc5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -235,6 +235,20 @@ object Project {
   }
 }
 
+case class DataFrameDropColumns(dropList: Seq[Expression], child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = Nil
+
+  override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(DF_DROP_COLUMNS)
+
+  override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): DataFrameDropColumns =
+    copy(child = newChild)
+}
+
 /**
  * Applies a [[Generator]] to a stream of input rows, combining the
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b806ebbed52d..bf7b2db1719f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -105,6 +105,7 @@ object TreePattern extends Enumeration {
   val AS_OF_JOIN: Value = Value
   val COMMAND: Value = Value
   val CTE: Value = Value
+  val DF_DROP_COLUMNS: Value = Value
   val DISTINCT_LIKE: Value = Value
   val EVAL_PYTHON_UDF: Value = Value
   val EVAL_PYTHON_UDTF: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e047b927b905..f53c6ddaa388 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3013,19 +3013,8 @@ class Dataset[T] private[sql](
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def drop(col: Column, cols: Column*): DataFrame = {
-    val allColumns = col +: cols
-    val expressions = (for (col <- allColumns) yield col match {
-      case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(
-          u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
-      case Column(expr: Expression) => expr
-    })
-    val attrs = this.logicalPlan.output
-    val colsAfterDrop = attrs.filter { attr =>
-      expressions.forall(expression => !attr.semanticEquals(expression))
-    }.map(attr => Column(attr))
-    select(colsAfterDrop : _*)
+  def drop(col: Column, cols: Column*): DataFrame = withPlan {
+    DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
   }
 
   /**

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org