Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21650#discussion_r205872386
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 ---
    @@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends 
Rule[LogicalPlan] {
      */
     object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     
    -  private def hasPythonUDF(e: Expression): Boolean = {
    +  private type EvalType = Int
    +  private type EvalTypeChecker = EvalType => Boolean
    +
    +  private def hasScalarPythonUDF(e: Expression): Boolean = {
         e.find(PythonUDF.isScalarPythonUDF).isDefined
       }
     
       private def canEvaluateInPython(e: PythonUDF): Boolean = {
         e.children match {
           // single PythonUDF child could be chained and evaluated in Python
    -      case Seq(u: PythonUDF) => canEvaluateInPython(u)
    +      case Seq(u: PythonUDF) => e.evalType == u.evalType && 
canEvaluateInPython(u)
           // Python UDF can't be evaluated directly in JVM
    -      case children => !children.exists(hasPythonUDF)
    +      case children => !children.exists(hasScalarPythonUDF)
         }
       }
     
    -  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = 
expr match {
    -    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && 
canEvaluateInPython(udf) => Seq(udf)
    -    case e => e.children.flatMap(collectEvaluatableUDF)
    +  private def collectEvaluableUDFsFromExpressions(expressions: 
Seq[Expression]): Seq[PythonUDF] = {
    +    // Eval type checker is set once when we find the first evaluable UDF 
and its value
    +    // shouldn't change later.
    +    // Used to check if subsequent UDFs are of the same type as the first 
UDF. (since we can only
    +    // extract UDFs of the same eval type)
    +    var evalTypeChecker: Option[EvalTypeChecker] = None
    +
    +    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr 
match {
    +      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && 
canEvaluateInPython(udf)
    +        && evalTypeChecker.isEmpty =>
    +        evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType 
== udf.evalType)
    +        Seq(udf)
    --- End diff --
    
    @HyukjinKwon In your version this line is `collectEvaluableUDFs(udf)`. I 
think we should just return `Seq(udf)` here: the guard above has already matched this UDF, so recursing into it would check the same expression twice.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to