This is an automated email from the ASF dual-hosted git repository. yao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fddf25a4dd8 [SPARK-43302][SQL][FOLLOWUP] Code cleanup for PythonUDAF fddf25a4dd8 is described below commit fddf25a4dd8029db89287416de39adb27e8643c8 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Wed May 17 13:39:27 2023 +0800 [SPARK-43302][SQL][FOLLOWUP] Code cleanup for PythonUDAF ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/40739 to do some code cleanup: 1. remove the pattern `PYTHON_UDAF` as it's not used by any rule. 2. add `PythonFuncExpression.evalType` for convenience: catalyst rules (including third-party extensions) may want to get the eval type of a Python function, regardless of whether it is a UDF or a UDAF. 3. update the Python profiler to use `PythonUDAF.resultId` instead of `AggregateExpression.resultId`, to be consistent with `PythonUDF` ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #41142 from cloud-fan/follow. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Kent Yao <y...@apache.org> --- python/pyspark/sql/column.py | 4 ++++ python/pyspark/sql/udf.py | 12 +++++++----- .../apache/spark/sql/catalyst/expressions/PythonUDF.scala | 11 ++++++----- .../org/apache/spark/sql/catalyst/trees/TreePatterns.scala | 1 - .../sql/execution/python/UserDefinedPythonFunction.scala | 11 +++++++++-- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 49a42406048..3cf59989965 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -69,6 +69,10 @@ def _to_java_column(col: "ColumnOrName") -> JavaObject: return jcol +def _to_java_expr(col: "ColumnOrName") -> JavaObject: + return _to_java_column(col).expr() + + def _to_seq( sc: SparkContext, cols: Iterable["ColumnOrName"], diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 45828187295..87d53266edf 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -30,7 +30,7 @@ from py4j.java_gateway import JavaObject from pyspark import SparkContext from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType -from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq from pyspark.sql.types import ( ArrayType, BinaryType, @@ -419,8 +419,9 @@ class UserDefinedFunction: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) - id = jPythonUDF.expr().resultId().id() + jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr)) + jPythonUDF = judf.fromUDFExpr(jUDFExpr) + id = jUDFExpr.resultId().id() sc.profiler_collector.add_profiler(id, profiler) else: # memory_profiler_enabled f = self.func @@ -436,8 +437,9 @@ class 
UserDefinedFunction: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) - id = jPythonUDF.expr().resultId().id() + jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr)) + jPythonUDF = judf.fromUDFExpr(jUDFExpr) + id = jUDFExpr.resultId().id() sc.profiler_collector.add_profiler(id, memory_profiler) else: judf = self._judf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 08ffbea5510..8636eb61034 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -22,7 +22,7 @@ import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDAF, PYTHON_UDF, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -53,9 +53,12 @@ object PythonUDF { trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression { self: Expression => def name: String def func: PythonFunction + def evalType: Int def udfDeterministic: Boolean def resultId: ExprId + final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF) + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", 
")})#${resultId.id}$typeSuffix" @@ -80,8 +83,6 @@ case class PythonUDF( lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( exprId = resultId) - final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF) - override lazy val canonicalized: Expression = { val canonicalizedChildren = children.map(_.canonicalized) // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result. @@ -119,6 +120,8 @@ case class PythonUDAF( resultId: ExprId = NamedExpression.newExprId) extends UnevaluableAggregateFunc with PythonFuncExpression { + override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + override def sql(isDistinct: Boolean): String = { val distinct = if (isDistinct) "DISTINCT " else "" s"$name($distinct${children.mkString(", ")})" @@ -129,8 +132,6 @@ case class PythonUDAF( name + children.mkString(start, ", ", ")") + s"#${resultId.id}$typeSuffix" } - final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDAF) - override lazy val canonicalized: Expression = { val canonicalizedChildren = children.map(_.canonicalized) // `resultId` can be seen as cosmetic variation in PythonUDAF, as it doesn't affect the result. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 000c2be306c..3f01b5561a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -76,7 +76,6 @@ object TreePattern extends Enumeration { val PARAMETERIZED_QUERY: Value = Value val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value - val PYTHON_UDAF: Value = Value val PYTHON_UDF: Value = Value val REGEXP_EXTRACT_FAMILY: Value = Value val REGEXP_REPLACE: Value = Value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index e1b6586bc9a..bc76eaed04b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -42,9 +42,16 @@ case class UserDefinedPythonFunction( /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ def apply(exprs: Column*): Column = { - builder(exprs.map(_.expr)) match { + fromUDFExpr(builder(exprs.map(_.expr))) + } + + /** + * Returns a [[Column]] that will evaluate the UDF expression with the given input. + */ + def fromUDFExpr(expr: Expression): Column = { + expr match { case udaf: PythonUDAF => Column(udaf.toAggregateExpression()) - case udf => Column(udf) + case _ => Column(expr) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org