Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21650#discussion_r205872386 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala --- @@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private type EvalType = Int + private type EvalTypeChecker = EvalType => Boolean + + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { e.children match { // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + case children => !children.exists(hasScalarPythonUDF) } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = { + // Eval type checker is set once when we find the first evaluable UDF and its value + // shouldn't change later. + // Used to check if subsequent UDFs are of the same type as the first UDF. 
(since we can only + // extract UDFs of the same eval type) + var evalTypeChecker: Option[EvalTypeChecker] = None + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.isEmpty => + evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType) + Seq(udf) --- End diff -- @HyukjinKwon In your version of the code, this line is `collectEvaluableUDFs(udf)`. I think we should just return `Seq(udf)` here, to avoid checking the same expression twice.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org