Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21650#discussion_r202861241 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala --- @@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } - private def canEvaluateInPython(e: PythonUDF): Boolean = { - e.children match { - // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) - // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = { + if (e.evalType != evalType) { + false + } else { + e.children match { + // single PythonUDF child could be chained and evaluated in Python + case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType) + // Python UDF can't be evaluated directly in JVM + case children => !children.exists(hasScalarPythonUDF) + } } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDF(_, evalType)) + } + + /** + * Collect evaluable UDFs from the current node. + * + * This function collects Python UDFs or Scalar Python UDFs from expressions of the input node, + * and returns a list of UDFs of the same eval type. 
+ * + * If expressions contain UDFs of both eval types, this function will only return Python UDFs. + * + * The caller should call this function multiple times until all evaluable UDFs are collected. + */ + private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = { + val pythonUDFs = + plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF)) + + if (pythonUDFs.isEmpty) { + plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF)) + } else { + pythonUDFs --- End diff -- I think it would be better to loop through the expressions and find the first scalar Python UDF, either `SQL_BATCHED_UDF` or `SQL_SCALAR_PANDAS_UDF`, and then collect the rest of that type. This is really what is happening here, so I think it would be more straightforward to do this in a single loop instead of 2 `flatMap`s.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org