This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new cee7dcc48b [GLUTEN-11088][VL] Add raise_error restriction in spark 4.0
and enable ScalarFunctionsValidateSuite (#11170)
cee7dcc48b is described below
commit cee7dcc48b27519819e58d39a0e30c5f5ca281c8
Author: Rong Ma <[email protected]>
AuthorDate: Wed Nov 26 09:26:16 2025 +0000
[GLUTEN-11088][VL] Add raise_error restriction in spark 4.0 and enable
ScalarFunctionsValidateSuite (#11170)
Add raise_error restriction in spark 4.0 and fix
ScalarFunctionsValidateSuite
---
.../gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala | 10 ++++++++++
.../org/apache/gluten/expression/ExpressionRestrictions.scala | 9 +++++++++
.../apache/gluten/functions/ScalarFunctionsValidateSuite.scala | 3 +--
.../scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala | 4 ++++
.../org/apache/gluten/expression/ExpressionConverter.scala | 7 +++++++
.../main/scala/org/apache/gluten/sql/shims/SparkShims.scala | 5 ++++-
.../org/apache/gluten/sql/shims/spark32/Spark32Shims.scala | 5 ++++-
.../org/apache/gluten/sql/shims/spark33/Spark33Shims.scala | 3 +++
.../org/apache/gluten/sql/shims/spark34/Spark34Shims.scala | 4 ++++
.../org/apache/gluten/sql/shims/spark35/Spark35Shims.scala | 3 +++
.../org/apache/gluten/sql/shims/spark40/Spark40Shims.scala | 10 ++++++++++
11 files changed, 59 insertions(+), 4 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 20f52997ab..fcd9b682a4 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -1073,4 +1073,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
original: MonthsBetween): ExpressionTransformer = {
MonthsBetweenTransformer(substraitExprName, date1, date2, roundOff,
original)
}
+
+ override def getErrorMessage(raiseError: RaiseError): Expression = {
+ SparkShimLoader.getSparkShims.getErrorMessage(raiseError) match {
+ case Some(msg) => msg
+ case None =>
+ GlutenExceptionUtil.throwsNotFullySupported(
+ ExpressionNames.RAISE_ERROR,
+ RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE)
+ }
+ }
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionRestrictions.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionRestrictions.scala
index af16e5ed17..04f776fa42 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionRestrictions.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionRestrictions.scala
@@ -92,6 +92,15 @@ object Base64Restrictions extends ExpressionRestrictions {
override val restrictionMessages: Array[String] =
Array(NOT_SUPPORT_DISABLE_CHUNK_BASE64_STRING)
}
+object RaiseErrorRestrictions extends ExpressionRestrictions {
+  val ONLY_SUPPORT_ERROR_MESSAGE: String =
+    s"Only 'errorMessage' is supported as the second argument in ${ExpressionNames.RAISE_ERROR}"
+
+ override val functionName: String = ExpressionNames.RAISE_ERROR
+
+  override val restrictionMessages: Array[String] = Array(ONLY_SUPPORT_ERROR_MESSAGE)
+}
+
object ExpressionRestrictions {
// Called by gen-function-support-docs.py to get all restrictions.
def listAllRestrictions(): Array[ExpressionRestrictions] = {
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/functions/ScalarFunctionsValidateSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/functions/ScalarFunctionsValidateSuite.scala
index febd7da4f6..71a0f88fc5 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/functions/ScalarFunctionsValidateSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/functions/ScalarFunctionsValidateSuite.scala
@@ -522,8 +522,7 @@ abstract class ScalarFunctionsValidateSuite extends
FunctionsValidateSuite {
}
}
- // TODO: fix on spark-4.0
- testWithMaxSparkVersion("raise_error, assert_true", "3.5") {
+ test("raise_error, assert_true") {
runQueryAndCompare("""SELECT assert_true(l_orderkey >= 1), l_orderkey
| from lineitem limit 100""".stripMargin) {
checkGlutenPlan[ProjectExecTransformer]
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 2ee41f0d8d..a8ffca51ba 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -812,4 +812,8 @@ trait SparkPlanExecApi {
def isRowIndexMetadataColumn(columnName: String): Boolean = {
SparkShimLoader.getSparkShims.isRowIndexMetadataColumn(columnName)
}
+
+ def getErrorMessage(raiseError: RaiseError): Expression = {
+    throw new GlutenNotSupportException(s"${ExpressionNames.RAISE_ERROR} is not supported")
+ }
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 41a2a2ff82..c7ee159e01 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -842,6 +842,13 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
ce,
attributeSeq,
expressionsMap)
+ case re: RaiseError =>
+ val errorMessage =
+ BackendsApiManager.getSparkPlanExecApiInstance.getErrorMessage(re)
+ GenericExpressionTransformer(
+ substraitExprName,
+      replaceWithExpressionTransformer0(errorMessage, attributeSeq, expressionsMap),
+ re)
case expr =>
GenericExpressionTransformer(
substraitExprName,
diff --git
a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index bbbd665cd8..bb7b55fdae 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression, UnBase64}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression, RaiseError, UnBase64}
import
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -325,4 +325,7 @@ trait SparkShims {
def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType
def getRewriteCreateTableAsSelect(session: SparkSession): SparkStrategy = _
=> Seq.empty
+
+  /** Shim method to get the "errorMessage" value for Spark 4.0 and above. */
+  def getErrorMessage(raiseError: RaiseError): Option[Expression]
}
diff --git
a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
index baf7868f32..007a962884 100644
---
a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
+++
b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, BinaryExpression, Expression, InputFileBlockLength,
InputFileBlockStart, InputFileName}
+import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -315,4 +315,7 @@ class Spark32Shims extends SparkShims {
DecimalPrecision.widerDecimalType(d1, d2)
}
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ Some(raiseError.child)
+ }
}
diff --git
a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index 0ad1d94e79..2c345c65aa 100644
---
a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++
b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -409,4 +409,7 @@ class Spark33Shims extends SparkShims {
DecimalPrecision.widerDecimalType(d1, d2)
}
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ Some(raiseError.child)
+ }
}
diff --git
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index c097833435..890ee8c592 100644
---
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -651,4 +651,8 @@ class Spark34Shims extends SparkShims {
override def getRewriteCreateTableAsSelect(session: SparkSession):
SparkStrategy = {
RewriteCreateTableAsSelect(session)
}
+
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ Some(raiseError.child)
+ }
}
diff --git
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 58759949f7..ee52e8b2b9 100644
---
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -715,4 +715,7 @@ class Spark35Shims extends SparkShims {
DecimalPrecision.widerDecimalType(d1, d2)
}
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ Some(raiseError.child)
+ }
}
diff --git
a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index 9445eaf9cc..636babc981 100644
---
a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++
b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -727,4 +727,14 @@ class Spark40Shims extends SparkShims {
DecimalPrecisionTypeCoercion.widerDecimalType(d1, d2)
}
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ raiseError.errorParms match {
+ case CreateMap(children, _)
+ if children.size == 2 && children.head.isInstanceOf[Literal]
+          && children.head.asInstanceOf[Literal].value.toString == "errorMessage" =>
+ Some(children(1))
+ case _ =>
+ None
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]