Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21650#discussion_r204429892
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala ---
    @@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
      */
     object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     
    -  private def hasPythonUDF(e: Expression): Boolean = {
    +  private def hasScalarPythonUDF(e: Expression): Boolean = {
         e.find(PythonUDF.isScalarPythonUDF).isDefined
       }
     
    -  private def canEvaluateInPython(e: PythonUDF): Boolean = {
    -    e.children match {
    -      // single PythonUDF child could be chained and evaluated in Python
    -      case Seq(u: PythonUDF) => canEvaluateInPython(u)
    -      // Python UDF can't be evaluated directly in JVM
    -      case children => !children.exists(hasPythonUDF)
    +  private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = {
    +    if (e.evalType != evalType) {
    +      false
    +    } else {
    +      e.children match {
    +        // single PythonUDF child could be chained and evaluated in Python
    +        case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType)
    +        // Python UDF can't be evaluated directly in JVM
    +        case children => !children.exists(hasScalarPythonUDF)
    +      }
         }
       }
     
    -  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
    -    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
    -    case e => e.children.flatMap(collectEvaluatableUDF)
    +  private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match {
    +    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
    +      Seq(udf)
    +    case e => e.children.flatMap(collectEvaluableUDF(_, evalType))
    +  }
    +
    +  /**
    +   * Collect evaluable UDFs from the current node.
    +   *
    +   * This function collects Python UDFs or Scalar Pandas UDFs from the expressions of the
    +   * input node, and returns a list of UDFs of the same eval type.
    +   *
    +   * If the expressions contain UDFs of both eval types, this function returns only the
    +   * Python UDFs.
    +   *
    +   * The caller should call this function multiple times until all evaluable UDFs are collected.
    +   */
    +  private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = {
    +    val pythonUDFs =
    +      plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF))
    +
    +    if (pythonUDFs.isEmpty) {
    +      plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF))
    +    } else {
    +      pythonUDFs
    --- End diff --
    
    What you said makes sense, and that was actually my first attempt, but it ended up being pretty complicated. The issue is that it is hard to find the UDFs in a single traversal of the expression tree, because we need to pass the evalType down to every subtree, and the result of one subtree can affect the result of another (i.e., if we find one type of UDF in one subtree, we have to propagate that type to all the other subtrees, because they must agree on the evalType). That makes the code noticeably more complicated.
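    
    To illustrate (this is just an untested sketch, not code from the PR; `collectUDFsOnePass` is a hypothetical name), the single-traversal version needs the chosen eval type as state shared across sibling subtrees:
    
    ```scala
    // Hypothetical sketch of the single-traversal approach (not in this PR):
    // the eval type accepted in one subtree must constrain every later
    // subtree, so it becomes mutable state shared across the whole traversal.
    private def collectUDFsOnePass(plan: SparkPlan): Seq[PythonUDF] = {
      var chosenEvalType: Option[Int] = None // set by the first UDF we accept
      def collect(expr: Expression): Seq[PythonUDF] = expr match {
        case udf: PythonUDF
            if PythonUDF.isScalarPythonUDF(udf) &&
              chosenEvalType.forall(_ == udf.evalType) &&
              canEvaluateInPython(udf, udf.evalType) =>
          chosenEvalType = Some(udf.evalType)
          Seq(udf)
        case e => e.children.flatMap(collect)
      }
      // Note: this locks onto whichever eval type happens to appear first,
      // which is not even the batched-first preference we want; handling
      // that ordering is exactly where this version gets messy.
      plan.expressions.flatMap(collect)
    }
    ```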
    
    Another way is to do two traversals: in the first traversal we look for the eval type, and in the second traversal we collect the UDFs of that eval type. That isn't much different from what I have now in terms of efficiency, though, and having tried both approaches, I find the current logic the simplest to implement and the least likely to have bugs.
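    
    For reference, the two-traversal version would look roughly like this (again an untested sketch; `firstEvalType` and `collectEvaluableUDFsTwoPass` are hypothetical names, reusing `collectEvaluableUDF` from this PR):
    
    ```scala
    // Hypothetical sketch of the two-traversal alternative (not in this PR).
    // First traversal: find the eval type of the first scalar Python UDF.
    private def firstEvalType(expr: Expression): Option[Int] = expr match {
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) => Some(udf.evalType)
      case e => e.children.view.flatMap(firstEvalType).headOption
    }
    
    // Second traversal: collect only the UDFs of that eval type.
    private def collectEvaluableUDFsTwoPass(plan: SparkPlan): Seq[PythonUDF] = {
      plan.expressions.view.flatMap(firstEvalType).headOption match {
        case Some(evalType) =>
          plan.expressions.flatMap(collectEvaluableUDF(_, evalType))
        case None => Nil
      }
    }
    ```
    
    Either way we walk the expressions twice in the worst case, which is what the current fallback does too, so I don't think it buys us anything.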
    
    WDYT?


