Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21650#discussion_r199203674 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala --- @@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } - private def canEvaluateInPython(e: PythonUDF): Boolean = { - e.children match { - // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) - // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = { + if (e.evalType != evalType) { + false + } else { + e.children match { + // single PythonUDF child could be chained and evaluated in Python + case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType) + // Python UDF can't be evaluated directly in JVM + case children => !children.exists(hasScalarPythonUDF) + } } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDF(_, evalType)) + } + + /** + * Collect evaluable UDFs from the current node. + * + * This function collects Python UDFs or Scalar Python UDFs from expressions of the input node, + * and returns a list of UDFs of the same eval type. 
--- End diff -- I tried this on master and got the same exception: ``` >>> foo = pandas_udf(lambda x: x, 'v int', PandasUDFType.GROUPED_MAP) >>> df.select(foo(df['v'])).show() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/dataframe.py", line 353, in show print(self._jdf.showString(n, 20, vertical)) File "/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__ File "/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o257.showString. : java.lang.UnsupportedOperationException: Cannot evaluate expression: <lambda>(input[0, bigint, false]) at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261) at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105) at scala.Option.getOrElse(Option.scala:121) ... ``` Therefore, this PR doesn't change that behavior: neither master nor this PR extracts non-scalar UDFs from the expression.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org