This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fddf25a4dd8 [SPARK-43302][SQL][FOLLOWUP] Code cleanup for PythonUDAF
fddf25a4dd8 is described below

commit fddf25a4dd8029db89287416de39adb27e8643c8
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed May 17 13:39:27 2023 +0800

    [SPARK-43302][SQL][FOLLOWUP] Code cleanup for PythonUDAF
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/40739 to do some code cleanup:
    1. remove the pattern `PYTHON_UDAF` as it's not used by any rule.
    2. add `PythonFuncExpression.evalType` for convenience: catalyst rules (including third-party extensions) may want to get the eval type of a Python function, whether it's a UDF or a UDAF (see the sketch after this list).
    3. update the Python profiler to use `PythonUDAF.resultId` instead of `AggregateExpression.resultId`, to be consistent with `PythonUDF`.
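
    For illustration, a third-party rule could rely on the new shared field like
    this (a minimal sketch; the rule name and the pass-through body are
    hypothetical):

    ```scala
    import org.apache.spark.api.python.PythonEvalType
    import org.apache.spark.sql.catalyst.expressions.PythonFuncExpression
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    object InspectPandasAggregates extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = {
        plan.transformAllExpressions {
          // `evalType` lives on the shared trait, so this matches both
          // PythonUDF and PythonUDAF without casting to a concrete subclass.
          case f: PythonFuncExpression
              if f.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF => f
        }
      }
    }
    ```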
    
    ### Why are the changes needed?
    
    code cleanup
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests
    
    Closes #41142 from cloud-fan/follow.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 python/pyspark/sql/column.py                                 |  4 ++++
 python/pyspark/sql/udf.py                                    | 12 +++++++-----
 .../apache/spark/sql/catalyst/expressions/PythonUDF.scala    | 11 ++++++-----
 .../org/apache/spark/sql/catalyst/trees/TreePatterns.scala   |  1 -
 .../sql/execution/python/UserDefinedPythonFunction.scala     | 11 +++++++++--
 5 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 49a42406048..3cf59989965 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -69,6 +69,10 @@ def _to_java_column(col: "ColumnOrName") -> JavaObject:
     return jcol
 
 
+def _to_java_expr(col: "ColumnOrName") -> JavaObject:
+    return _to_java_column(col).expr()
+
+
 def _to_seq(
     sc: SparkContext,
     cols: Iterable["ColumnOrName"],
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 45828187295..87d53266edf 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -30,7 +30,7 @@ from py4j.java_gateway import JavaObject
 from pyspark import SparkContext
 from pyspark.profiler import Profiler
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -419,8 +419,9 @@ class UserDefinedFunction:
 
                 func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
                 judf = self._create_judf(func)
-                jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
-                id = jPythonUDF.expr().resultId().id()
+                jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+                jPythonUDF = judf.fromUDFExpr(jUDFExpr)
+                id = jUDFExpr.resultId().id()
                 sc.profiler_collector.add_profiler(id, profiler)
             else:  # memory_profiler_enabled
                 f = self.func
@@ -436,8 +437,9 @@ class UserDefinedFunction:
 
                 func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
                 judf = self._create_judf(func)
-                jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
-                id = jPythonUDF.expr().resultId().id()
+                jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+                jPythonUDF = judf.fromUDFExpr(jUDFExpr)
+                id = jUDFExpr.resultId().id()
                 sc.profiler_collector.add_profiler(id, memory_profiler)
         else:
             judf = self._judf
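
The switch from `jPythonUDF.expr().resultId()` to `jUDFExpr.resultId()` above
matters for UDAFs: the Column returned to Python wraps the `PythonUDAF` in an
`AggregateExpression` that carries its own, different `resultId`. A minimal
sketch of the distinction (the helper name is hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.PythonUDAF

// The profiler keys its results by the UDAF's own resultId; wrapping the
// UDAF for use in a Column mints a separate id on the wrapper.
def profilerId(udaf: PythonUDAF): Long = {
  val wrapped = udaf.toAggregateExpression() // AggregateExpression with a fresh resultId
  assert(wrapped.resultId != udaf.resultId)
  udaf.resultId.id
}
```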
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 08ffbea5510..8636eb61034 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -22,7 +22,7 @@ import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDAF, PYTHON_UDF, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern}
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -53,9 +53,12 @@ object PythonUDF {
 trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression { self: Expression =>
   def name: String
   def func: PythonFunction
+  def evalType: Int
   def udfDeterministic: Boolean
   def resultId: ExprId
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
+
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
   override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"
@@ -80,8 +83,6 @@ case class PythonUDF(
   lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
     exprId = resultId)
 
-  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
-
   override lazy val canonicalized: Expression = {
     val canonicalizedChildren = children.map(_.canonicalized)
     // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
@@ -119,6 +120,8 @@ case class PythonUDAF(
     resultId: ExprId = NamedExpression.newExprId)
   extends UnevaluableAggregateFunc with PythonFuncExpression {
 
+  override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
+
   override def sql(isDistinct: Boolean): String = {
     val distinct = if (isDistinct) "DISTINCT " else ""
     s"$name($distinct${children.mkString(", ")})"
@@ -129,8 +132,6 @@ case class PythonUDAF(
     name + children.mkString(start, ", ", ")") + s"#${resultId.id}$typeSuffix"
   }
 
-  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDAF)
-
   override lazy val canonicalized: Expression = {
     val canonicalizedChildren = children.map(_.canonicalized)
     // `resultId` can be seen as cosmetic variation in PythonUDAF, as it doesn't affect the result.
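
With `nodePatterns` hoisted into the shared trait, both `PythonUDF` and
`PythonUDAF` now advertise `PYTHON_UDF`, so a single pattern check covers both
in pruned traversals. A minimal sketch (the helper name and the pass-through
case are hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.PythonFuncExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF

def touchPythonFuncs(plan: LogicalPlan): LogicalPlan = {
  // The pruning condition skips subtrees that contain no Python function
  // expressions; the single case then matches UDFs and UDAFs alike.
  plan.transformAllExpressionsWithPruning(_.containsPattern(PYTHON_UDF)) {
    case f: PythonFuncExpression => f // a real rule would rewrite f here
  }
}
```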
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 000c2be306c..3f01b5561a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -76,7 +76,6 @@ object TreePattern extends Enumeration  {
   val PARAMETERIZED_QUERY: Value = Value
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
-  val PYTHON_UDAF: Value = Value
   val PYTHON_UDF: Value = Value
   val REGEXP_EXTRACT_FAMILY: Value = Value
   val REGEXP_REPLACE: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index e1b6586bc9a..bc76eaed04b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -42,9 +42,16 @@ case class UserDefinedPythonFunction(
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    builder(exprs.map(_.expr)) match {
+    fromUDFExpr(builder(exprs.map(_.expr)))
+  }
+
+  /**
+   * Returns a [[Column]] that will evaluate the UDF expression with the given input.
+   */
+  def fromUDFExpr(expr: Expression): Column = {
+    expr match {
       case udaf: PythonUDAF => Column(udaf.toAggregateExpression())
-      case udf => Column(udf)
+      case _ => Column(expr)
     }
   }
 }
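
The `apply`/`fromUDFExpr` split lets callers that need the raw UDF expression,
like the PySpark profiler path above, build it once via `builder` and still
obtain a properly wrapped Column. A minimal sketch of the two-step flow (the
helper name is hypothetical):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction

// Mirrors the py4j calls in udf.py: build the expression, read what you
// need from it (e.g. resultId), then wrap it into a Column.
def buildColumn(udpf: UserDefinedPythonFunction, inputs: Seq[Expression]): Column = {
  val udfExpr = udpf.builder(inputs) // PythonUDF, or PythonUDAF for aggregates
  udpf.fromUDFExpr(udfExpr)          // wraps a PythonUDAF in an AggregateExpression
}
```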

