This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 5b496e86b6 [GLUTEN-7760] Fix udf implicit cast & update doc (#7852)
5b496e86b6 is described below
commit 5b496e86b6e1fbca55866616fa1519a532b07b51
Author: Rong Ma <[email protected]>
AuthorDate: Fri Nov 8 18:44:50 2024 +0800
[GLUTEN-7760] Fix udf implicit cast & update doc (#7852)
---
.../spark/sql/hive/VeloxHiveUDFTransformer.scala | 8 +-
.../apache/gluten/expression/VeloxUdfSuite.scala | 27 +++--
docs/developers/VeloxUDF.md | 19 ++--
.../execution/WholeStageTransformerSuite.scala | 108 +-------------------
.../org/apache/spark/sql/GlutenQueryTest.scala | 109 ++++++++++++++++++++-
5 files changed, 142 insertions(+), 129 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
index d895faa317..b3524e20f0 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
@@ -37,11 +37,11 @@ object VeloxHiveUDFTransformer {
}
if (UDFResolver.UDFNames.contains(udfClassName)) {
- UDFResolver
+ val udfExpression = UDFResolver
.getUdfExpression(udfClassName, udfName)(expr.children)
- .getTransformer(
- ExpressionConverter.replaceWithExpressionTransformer(expr.children,
attributeSeq)
- )
+ udfExpression.getTransformer(
+
ExpressionConverter.replaceWithExpressionTransformer(udfExpression.children,
attributeSeq)
+ )
} else {
HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr,
attributeSeq)
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index f85103deb8..61ba927cd4 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -16,11 +16,13 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.execution.ProjectExecTransformer
import org.apache.gluten.tags.{SkipTestTags, UDFTest}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.expression.UDFResolver
import java.nio.file.Paths
@@ -158,16 +160,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
|AS 'org.apache.spark.sql.hive.execution.UDFStringString'
|""".stripMargin)
- val nativeResultWithImplicitConversion =
- spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM
$tbl""").collect()
- val nativeResult =
- spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM
$tbl""").collect()
+ val offloadWithImplicitConversionDF =
+ spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
+
checkGlutenOperatorMatch[ProjectExecTransformer](offloadWithImplicitConversionDF)
+ val offloadWithImplicitConversionResult =
offloadWithImplicitConversionDF.collect()
+
+ val offloadDF =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+ checkGlutenOperatorMatch[ProjectExecTransformer](offloadDF)
+ val offloadResult = offloadDF.collect() // collect the col2 (no implicit cast) offload result so it is actually compared with fallbackResult below
+
// Unregister native hive udf to fallback.
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
- val fallbackResult =
- spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM
$tbl""").collect()
-
assert(nativeResultWithImplicitConversion.sameElements(fallbackResult))
- assert(nativeResult.sameElements(fallbackResult))
+ val fallbackDF =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+ checkSparkOperatorMatch[ProjectExec](fallbackDF)
+ val fallbackResult = fallbackDF.collect()
+
assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
+ assert(offloadResult.sameElements(fallbackResult))
// Add an unimplemented udf to the map to test fallback of
registered native hive udf.
UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
@@ -176,6 +186,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
|AS
'org.apache.spark.sql.hive.execution.UDFIntegerToString'
|""".stripMargin)
val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+ checkSparkOperatorMatch[ProjectExec](df)
checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
} finally {
spark.sql(s"DROP TABLE IF EXISTS $tbl")
diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md
index 4f685cc41e..4cbdcfa992 100644
--- a/docs/developers/VeloxUDF.md
+++ b/docs/developers/VeloxUDF.md
@@ -172,22 +172,23 @@ or
Start `spark-sql` and run query. You need to add jar
"spark-hive_2.12-<spark.version>-tests.jar" to the classpath for hive udf
`org.apache.spark.sql.hive.execution.UDFStringString`
```
+spark-sql (default)> create table tbl as select * from values ('hello');
+Time taken: 3.656 seconds
spark-sql (default)> CREATE TEMPORARY FUNCTION hive_string_string AS
'org.apache.spark.sql.hive.execution.UDFStringString';
-Time taken: 0.808 seconds
-spark-sql (default)> select hive_string_string("hello", "world");
+Time taken: 0.047 seconds
+spark-sql (default)> select hive_string_string(col1, 'world') from tbl;
hello world
-Time taken: 3.208 seconds, Fetched 1 row(s)
+Time taken: 1.217 seconds, Fetched 1 row(s)
```
You can verify the offload with "explain".
```
-spark-sql (default)> explain select hive_string_string("hello", "world");
-== Physical Plan ==
-VeloxColumnarToRowExec
-+- ^(2) ProjectExecTransformer [hello world AS hive_string_string(hello,
world)#8]
- +- ^(2) InputIteratorTransformer[fake_column#9]
+spark-sql (default)> explain select hive_string_string(col1, 'world') from tbl;
+VeloxColumnarToRow
++- ^(2) ProjectExecTransformer
[HiveSimpleUDF#org.apache.spark.sql.hive.execution.UDFStringString(col1#11,world)
AS hive_string_string(col1, world)#12]
+ +- ^(2) InputIteratorTransformer[col1#11]
+- RowToVeloxColumnar
- +- *(1) Scan OneRowRelation[fake_column#9]
+ +- Scan hive spark_catalog.default.tbl [col1#11], HiveTableRelation
[`spark_catalog`.`default`.`tbl`,
org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [col1#11],
Partition Cols: []]
```
## Configurations
diff --git
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index 146d6fde58..fd250834d0 100644
---
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -17,15 +17,13 @@
package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
-import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.test.FallbackUtil
import org.apache.gluten.utils.Arm
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
-import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan,
UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
AdaptiveSparkPlanHelper, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.DoubleType
@@ -33,7 +31,6 @@ import java.io.File
import java.util.concurrent.atomic.AtomicBoolean
import scala.io.Source
-import scala.reflect.ClassTag
case class Table(name: String, partitionColumns: Seq[String])
@@ -179,109 +176,6 @@ abstract class WholeStageTransformerSuite
result
}
- def checkLengthAndPlan(df: DataFrame, len: Int = 100): Unit = {
- assert(df.collect().length == len)
- val executedPlan = getExecutedPlan(df)
- assert(executedPlan.exists(plan =>
plan.find(_.isInstanceOf[TransformSupport]).isDefined))
- }
-
- /**
- * Get all the children plan of plans.
- * @param plans:
- * the input plans.
- * @return
- */
- def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = {
- if (plans.isEmpty) {
- return Seq()
- }
-
- val inputPlans: Seq[SparkPlan] = plans.map {
- case stage: ShuffleQueryStageExec => stage.plan
- case plan => plan
- }
-
- var newChildren: Seq[SparkPlan] = Seq()
- inputPlans.foreach {
- plan =>
- newChildren = newChildren ++ getChildrenPlan(plan.children)
- // To avoid duplication of WholeStageCodegenXXX and its children.
- if (!plan.nodeName.startsWith("WholeStageCodegen")) {
- newChildren = newChildren :+ plan
- }
- }
- newChildren
- }
-
- /**
- * Get the executed plan of a data frame.
- * @param df:
- * dataframe.
- * @return
- * A sequence of executed plans.
- */
- def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = {
- df.queryExecution.executedPlan match {
- case exec: AdaptiveSparkPlanExec =>
- getChildrenPlan(Seq(exec.executedPlan))
- case cmd: CommandResultExec =>
- getChildrenPlan(Seq(cmd.commandPhysicalPlan))
- case plan =>
- getChildrenPlan(Seq(plan))
- }
- }
-
- /**
- * Check whether the executed plan of a dataframe contains the expected plan.
- * @param df:
- * the input dataframe.
- * @param tag:
- * class of the expected plan.
- * @tparam T:
- * type of the expected plan.
- */
- def checkGlutenOperatorMatch[T <: GlutenPlan](df: DataFrame)(implicit tag:
ClassTag[T]): Unit = {
- val executedPlan = getExecutedPlan(df)
- assert(
- executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)),
- s"Expect ${tag.runtimeClass.getSimpleName} exists " +
- s"in executedPlan:\n ${executedPlan.last}"
- )
- }
-
- def checkSparkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag:
ClassTag[T]): Unit = {
- val executedPlan = getExecutedPlan(df)
- assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
- }
-
- /**
- * Check whether the executed plan of a dataframe contains the expected plan
chain.
- *
- * @param df
- * : the input dataframe.
- * @param tag
- * : class of the expected plan.
- * @param childTag
- * : class of the expected plan's child.
- * @tparam T
- * : type of the expected plan.
- * @tparam PT
- * : type of the expected plan's child.
- */
- def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
- df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit
= {
- val executedPlan = getExecutedPlan(df)
- assert(
- executedPlan.exists(
- plan =>
- tag.runtimeClass.isInstance(plan)
- && childTag.runtimeClass.isInstance(plan.children.head)),
- s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
- + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
- + s"${executedPlan.last}"
- )
- }
-
/**
* run a query with native engine as well as vanilla spark then compare the
result set for
* correctness check
diff --git
a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
index 53abaa9ac2..164083a8d8 100644
--- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
+++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
@@ -21,6 +21,8 @@ package org.apache.spark.sql
* 1. We need to modify the way org.apache.spark.sql.CHQueryTest#compare
compares double
*/
import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.execution.TransformSupport
+import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SPARK_VERSION_SHORT
@@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan,
SQLExecution, UnaryExecNode}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
ShuffleQueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.storage.StorageLevel
@@ -38,6 +41,7 @@ import org.scalatest.Assertions
import java.util.TimeZone
import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import scala.reflect.runtime.universe
abstract class GlutenQueryTest extends PlanTest {
@@ -306,6 +310,109 @@ abstract class GlutenQueryTest extends PlanTest {
query.queryExecution.executedPlan.missingInput.isEmpty,
s"The physical plan has missing
inputs:\n${query.queryExecution.executedPlan}")
}
+
+ def checkLengthAndPlan(df: DataFrame, len: Int = 100): Unit = {
+ assert(df.collect().length == len)
+ val executedPlan = getExecutedPlan(df)
+ assert(executedPlan.exists(plan =>
plan.find(_.isInstanceOf[TransformSupport]).isDefined))
+ }
+
+ /**
+ * Get all the children plan of plans.
+ * @param plans:
+ * the input plans.
+ * @return
+ */
+ def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = {
+ if (plans.isEmpty) {
+ return Seq()
+ }
+
+ val inputPlans: Seq[SparkPlan] = plans.map {
+ case stage: ShuffleQueryStageExec => stage.plan
+ case plan => plan
+ }
+
+ var newChildren: Seq[SparkPlan] = Seq()
+ inputPlans.foreach {
+ plan =>
+ newChildren = newChildren ++ getChildrenPlan(plan.children)
+ // To avoid duplication of WholeStageCodegenXXX and its children.
+ if (!plan.nodeName.startsWith("WholeStageCodegen")) {
+ newChildren = newChildren :+ plan
+ }
+ }
+ newChildren
+ }
+
+ /**
+ * Get the executed plan of a data frame.
+ * @param df:
+ * dataframe.
+ * @return
+ * A sequence of executed plans.
+ */
+ def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = {
+ df.queryExecution.executedPlan match {
+ case exec: AdaptiveSparkPlanExec =>
+ getChildrenPlan(Seq(exec.executedPlan))
+ case cmd: CommandResultExec =>
+ getChildrenPlan(Seq(cmd.commandPhysicalPlan))
+ case plan =>
+ getChildrenPlan(Seq(plan))
+ }
+ }
+
+ /**
+ * Check whether the executed plan of a dataframe contains the expected plan
chain.
+ *
+ * @param df
+ * : the input dataframe.
+ * @param tag
+ * : class of the expected plan.
+ * @param childTag
+ * : class of the expected plan's child.
+ * @tparam T
+ * : type of the expected plan.
+ * @tparam PT
+ * : type of the expected plan's child.
+ */
+ def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
+ df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit
= {
+ val executedPlan = getExecutedPlan(df)
+ assert(
+ executedPlan.exists(
+ plan =>
+ tag.runtimeClass.isInstance(plan)
+ && childTag.runtimeClass.isInstance(plan.children.head)),
+ s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
+ + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
+ + s"${executedPlan.last}"
+ )
+ }
+
+ /**
+ * Check whether the executed plan of a dataframe contains the expected plan.
+ * @param df:
+ * the input dataframe.
+ * @param tag:
+ * class of the expected plan.
+ * @tparam T:
+ * type of the expected plan.
+ */
+ def checkGlutenOperatorMatch[T <: GlutenPlan](df: DataFrame)(implicit tag:
ClassTag[T]): Unit = {
+ val executedPlan = getExecutedPlan(df)
+ assert(
+ executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)),
+ s"Expect ${tag.runtimeClass.getSimpleName} exists " +
+ s"in executedPlan:\n ${executedPlan.last}"
+ )
+ }
+
+ def checkSparkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag:
ClassTag[T]): Unit = {
+ val executedPlan = getExecutedPlan(df)
+ assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
+ }
}
object GlutenQueryTest extends Assertions {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]