This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d9f0c44e7f24 [SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for Dataframe.drop
d9f0c44e7f24 is described below

commit d9f0c44e7f24cba95f7bf1737bb52ff73a7b9094
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Nov 14 12:09:37 2023 +0800

    [SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for Dataframe.drop
    
    ### What changes were proposed in this pull request?
    backport https://github.com/apache/spark/pull/43683 to 3.5
    
    ### Why are the changes needed?
    to fix a connect bug
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43776 from zhengruifeng/sql_drop_plan_35.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_dataframe.py         | 37 ++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  1 +
 .../analysis/ResolveDataFrameDropColumns.scala     | 49 ++++++++++++++++++++++
 .../plans/logical/basicLogicalOperators.scala      | 14 +++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 15 +------
 6 files changed, 104 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 33049233dee9..5907c8c09fb4 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -106,6 +106,43 @@ class DataFrameTestsMixin:
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
 
+    def test_drop_join(self):
+        left_df = self.spark.createDataFrame(
+            [(1, "a"), (2, "b"), (3, "c")],
+            ["join_key", "value1"],
+        )
+        right_df = self.spark.createDataFrame(
+            [(1, "aa"), (2, "bb"), (4, "dd")],
+            ["join_key", "value2"],
+        )
+        joined_df = left_df.join(
+            right_df,
+            on=left_df["join_key"] == right_df["join_key"],
+            how="left",
+        )
+
+        dropped_1 = joined_df.drop(left_df["join_key"])
+        self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"])
+        self.assertEqual(
+            dropped_1.sort("value1").collect(),
+            [
+                Row(value1="a", join_key=1, value2="aa"),
+                Row(value1="b", join_key=2, value2="bb"),
+                Row(value1="c", join_key=None, value2=None),
+            ],
+        )
+
+        dropped_2 = joined_df.drop(right_df["join_key"])
+        self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"])
+        self.assertEqual(
+            dropped_2.sort("value1").collect(),
+            [
+                Row(join_key=1, value1="a", value2="aa"),
+                Row(join_key=2, value1="b", value2="bb"),
+                Row(join_key=3, value1="c", value2=None),
+            ],
+        )
+
     def test_with_columns_renamed(self):
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8e3c9b30c61b..80cb5d8c6087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -307,6 +307,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveWindowFrame ::
       ResolveNaturalAndUsingJoin ::
       ResolveOutputRelation ::
+      new ResolveDataFrameDropColumns(catalogManager) ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
new file mode 100644
index 000000000000..2642b4a1c5da
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
+import org.apache.spark.sql.connector.catalog.CatalogManager
+
+/**
+ * A rule that rewrites DataFrameDropColumns to Project.
+ * Note that DataFrameDropColumns allows and ignores non-existing columns.
+ */
+class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
+  extends Rule[LogicalPlan] with ColumnResolutionHelper  {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(DF_DROP_COLUMNS)) {
+    case d: DataFrameDropColumns if d.childrenResolved =>
+      // expressions in dropList can be unresolved, e.g.
+      //   df.drop(col("non-existing-column"))
+      val dropped = d.dropList.map {
+        case u: UnresolvedAttribute =>
+          resolveExpressionByPlanChildren(u, d.child)
+        case e => e
+      }
+      val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
+      if (remaining.size == d.child.output.size) {
+        d.child
+      } else {
+        Project(remaining, d.child)
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 96b67fc52e0d..0e460706fc5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -235,6 +235,20 @@ object Project {
   }
 }
 
+case class DataFrameDropColumns(dropList: Seq[Expression], child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = Nil
+
+  override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(DF_DROP_COLUMNS)
+
+  override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): DataFrameDropColumns =
+    copy(child = newChild)
+}
+
 /**
  * Applies a [[Generator]] to a stream of input rows, combining the
  * output of each into a new stream of rows.  This operation is similar to a `flatMap` in functional
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b806ebbed52d..bf7b2db1719f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -105,6 +105,7 @@ object TreePattern extends Enumeration  {
   val AS_OF_JOIN: Value = Value
   val COMMAND: Value = Value
   val CTE: Value = Value
+  val DF_DROP_COLUMNS: Value = Value
   val DISTINCT_LIKE: Value = Value
   val EVAL_PYTHON_UDF: Value = Value
   val EVAL_PYTHON_UDTF: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e047b927b905..f53c6ddaa388 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3013,19 +3013,8 @@ class Dataset[T] private[sql](
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def drop(col: Column, cols: Column*): DataFrame = {
-    val allColumns = col +: cols
-    val expressions = (for (col <- allColumns) yield col match {
-      case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(
-          u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
-      case Column(expr: Expression) => expr
-    })
-    val attrs = this.logicalPlan.output
-    val colsAfterDrop = attrs.filter { attr =>
-      expressions.forall(expression => !attr.semanticEquals(expression))
-    }.map(attr => Column(attr))
-    select(colsAfterDrop : _*)
+  def drop(col: Column, cols: Column*): DataFrame = withPlan {
+    DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to