Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21650#discussion_r205262719 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala --- @@ -94,36 +95,94 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private case class LazyEvalType(var evalType: Int = -1) { + + def isSet: Boolean = evalType >= 0 + + def set(evalType: Int): Unit = { + if (isSet) { + throw new IllegalStateException("Eval type has already been set") + } else { + this.evalType = evalType + } + } + + def get(): Int = { + if (!isSet) { + throw new IllegalStateException("Eval type is not set") + } else { + evalType + } + } + } + + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } - private def canEvaluateInPython(e: PythonUDF): Boolean = { - e.children match { - // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) - // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + /** + * Check whether a PythonUDF expression can be evaluated in Python. + * + * If the lazy eval type is not set, this method checks for either Batched Python UDF and Scalar + * Pandas UDF. If the lazy eval type is set, this method checks for the expression of the + * specified eval type. + * + * This method will also set the lazy eval type to be the type of the first evaluable expression, + * i.e., if lazy eval type is not set and we find a evaluable Python UDF expression, lazy eval + * type will be set to the eval type of the expression. 
+ * + */ + private def canEvaluateInPython(e: PythonUDF, lazyEvalType: LazyEvalType): Boolean = { --- End diff -- Bryan, I tried to apply your implementation and the simple test also fails:

```
@udf('int')
def f1(x):
    assert type(x) == int
    return x + 1

@pandas_udf('int')
def f2(x):
    assert type(x) == pd.Series
    return x + 10

df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
```

Do you mind trying this too? Hopefully I didn't do something silly here.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org