This is an automated email from the ASF dual-hosted git repository. yao pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push: new a4b69a2f1 [CORE] Add decimal precision tests (#5752) a4b69a2f1 is described below commit a4b69a2f141bbc9eede669cc337112845cd712ef Author: Xiduo You <ulyssesyo...@gmail.com> AuthorDate: Fri May 17 14:06:03 2024 +0800 [CORE] Add decimal precision tests (#5752) * Add decimal precision tests * fix ck test * fix * fix --------- Co-authored-by: Kent Yao <y...@apache.org> --- .../expression/CHExpressionTransformer.scala | 10 +- .../gluten/expression/ExpressionTransformer.scala | 10 +- .../apache/spark/sql/expression/UDFResolver.scala | 32 +++-- .../gluten/backendsapi/SparkPlanExecApi.scala | 8 -- .../expression/ArrayExpressionTransformer.scala | 4 +- .../gluten/expression/ConditionalTransformer.scala | 4 +- .../DateTimeExpressionsTransformer.scala | 10 +- .../gluten/expression/ExpressionConverter.scala | 7 -- .../gluten/expression/ExpressionTransformer.scala | 9 ++ .../expression/GenericExpressionTransformer.scala | 2 +- .../expression/HashExpressionTransformer.scala | 2 +- .../JsonTupleExpressionTransformer.scala | 2 +- .../expression/LambdaFunctionTransformer.scala | 2 +- .../gluten/expression/LiteralTransformer.scala | 4 +- .../expression/MapExpressionTransformer.scala | 4 +- .../expression/NamedExpressionsTransformer.scala | 2 +- .../PredicateExpressionTransformer.scala | 9 +- .../expression/ScalarSubqueryTransformer.scala | 3 +- .../expression/StringExpressionTransformer.scala | 2 +- .../expression/StructExpressionTransformer.scala | 2 +- .../expression/UnaryExpressionTransformer.scala | 17 +-- .../utils/clickhouse/ClickHouseTestSettings.scala | 1 + .../gluten/utils/velox/VeloxTestSettings.scala | 1 + .../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++ .../extension/CustomerExpressionTransformer.scala | 4 +- .../utils/clickhouse/ClickHouseTestSettings.scala | 1 + .../gluten/utils/velox/VeloxTestSettings.scala | 3 +- .../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++ .../utils/clickhouse/ClickHouseTestSettings.scala | 1 + .../gluten/utils/velox/VeloxTestSettings.scala | 3 +- .../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++ .../utils/clickhouse/ClickHouseTestSettings.scala | 1 + .../gluten/utils/velox/VeloxTestSettings.scala | 3 +- .../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++ 34 files changed, 639 insertions(+), 76 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala index 7d9dbaddc..98cc4a930 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala @@ -34,7 +34,7 @@ case class CHSizeExpressionTransformer( substraitExprName: String, child: ExpressionTransformer, original: Size) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // Pass legacyLiteral as second argument in substrait function @@ -51,7 +51,7 @@ case class CHTruncTimestampTransformer( timestamp: ExpressionTransformer, timeZoneId: Option[String] = None, original: TruncTimestamp) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // The format must be constant string in the function date_trunc of ch. @@ -126,7 +126,7 @@ case class CHStringTranslateTransformer( matchingExpr: ExpressionTransformer, replaceExpr: ExpressionTransformer, original: StringTranslate) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // In CH, translateUTF8 requires matchingExpr and replaceExpr argument have the same length @@ -158,7 +158,7 @@ case class CHPosExplodeTransformer( child: ExpressionTransformer, original: PosExplode, attributeSeq: Seq[Attribute]) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode: ExpressionNode = child.doTransform(args) @@ -202,7 +202,7 @@ case class CHRegExpReplaceTransformer( substraitExprName: String, children: Seq[ExpressionTransformer], original: RegExpReplace) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // In CH: replaceRegexpAll(subject, regexp, rep), which is equivalent diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index 75a2c3a62..da8433fa2 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -35,7 +35,7 @@ case class VeloxAliasTransformer( substraitExprName: String, child: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { child.doTransform(args) @@ -46,7 +46,7 @@ case class VeloxNamedStructTransformer( substraitExprName: String, original: CreateNamedStruct, attributeSeq: Seq[Attribute]) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: Object): ExpressionNode = { val expressionNodes = Lists.newArrayList[ExpressionNode]() original.valExprs.foreach( @@ -67,7 +67,7 @@ case class VeloxGetStructFieldTransformer( childTransformer: ExpressionTransformer, ordinal: Int, original: GetStructField) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: Object): ExpressionNode = { val childNode = childTransformer.doTransform(args) childNode match { @@ -86,7 +86,7 @@ case class VeloxHashExpressionTransformer( substraitExprName: String, exps: Seq[ExpressionTransformer], original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // As of Spark 3.3, there are 3 kinds of HashExpression. // HiveHash is not supported in native backend and will fail native validation. @@ -121,7 +121,7 @@ case class VeloxStringSplitTransformer( regexExpr: ExpressionTransformer, limitExpr: ExpressionTransformer, original: StringSplit) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { if ( diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index bdfd24ed5..847e5a2e6 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.velox.VeloxBackendSettings import org.apache.gluten.exception.GlutenException import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, Transformable} import org.apache.gluten.expression.ConverterUtils.FunctionConfig -import org.apache.gluten.substrait.expression.ExpressionBuilder +import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.gluten.udf.UdfJniWrapper import org.apache.gluten.vectorized.JniWorkspace @@ -110,18 +110,24 @@ case class UDFExpression( this.getClass.getSimpleName + ": getTransformer called before children transformer initialized.") } - (args: Object) => { - val transformers = childrenTransformers.map(_.doTransform(args)) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, - ConverterUtils.makeFuncName(name, children.map(_.dataType), FunctionConfig.REQ)) - - val typeNode = ConverterUtils.getTypeNode(dataType, nullable) - ExpressionBuilder.makeScalarFunction( - functionId, - Lists.newArrayList(transformers: _*), - typeNode) + + val localDataType = dataType + new ExpressionTransformer { + override def doTransform(args: Object): ExpressionNode = { + val transformers = childrenTransformers.map(_.doTransform(args)) + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionId = ExpressionBuilder.newScalarFunction( + functionMap, + ConverterUtils.makeFuncName(name, children.map(_.dataType), FunctionConfig.REQ)) + + val typeNode = ConverterUtils.getTypeNode(dataType, nullable) + ExpressionBuilder.makeScalarFunction( + functionId, + Lists.newArrayList(transformers: _*), + typeNode) + } + + override def dataType: DataType = localDataType } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 8df74bb88..aa27d1ce1 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -451,14 +451,6 @@ trait SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, children, original) } - def genEqualNullSafeTransformer( - substraitExprName: String, - left: ExpressionTransformer, - right: ExpressionTransformer, - original: EqualNullSafe): ExpressionTransformer = { - GenericExpressionTransformer(substraitExprName, Seq(left, right), original) - } - def genMd5Transformer( substraitExprName: String, child: ExpressionTransformer, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala index 85a1f58fb..68a464f13 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala @@ -33,7 +33,7 @@ case class CreateArrayTransformer( children: Seq[ExpressionTransformer], useStringTypeWhenEmpty: Boolean, original: CreateArray) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // If children is empty, @@ -62,7 +62,7 @@ case class GetArrayItemTransformer( right: ExpressionTransformer, failOnError: Boolean, original: GetArrayItem) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // Ignore failOnError for clickhouse backend diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala index 18a46d7ca..0fdd68511 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala @@ -27,7 +27,7 @@ case class CaseWhenTransformer( branches: Seq[(ExpressionTransformer, ExpressionTransformer)], elseValue: Option[ExpressionTransformer], original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // generate branches nodes @@ -52,7 +52,7 @@ case class IfTransformer( trueValue: ExpressionTransformer, falseValue: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val ifNodes = new JArrayList[ExpressionNode] diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala index 797dc81d3..66004291a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala @@ -36,7 +36,7 @@ case class ExtractDateTransformer( substraitExprName: String, child: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode = child.doTransform(args) @@ -65,7 +65,7 @@ case class DateDiffTransformer( endDate: ExpressionTransformer, startDate: ExpressionTransformer, original: DateDiff) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val endDateNode = endDate.doTransform(args) @@ -99,7 +99,7 @@ case class ToUnixTimestampTransformer( timeZoneId: Option[String], failOnError: Boolean, original: ToUnixTimestamp) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val dataTypes = Seq(original.timeExp.dataType, StringType) @@ -124,7 +124,7 @@ case class TruncTimestampTransformer( timestamp: ExpressionTransformer, timeZoneId: Option[String] = None, original: TruncTimestamp) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val timestampNode = timestamp.doTransform(args) @@ -160,7 +160,7 @@ case class MonthsBetweenTransformer( roundOff: ExpressionTransformer, timeZoneId: Option[String] = None, original: MonthsBetween) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val date1Node = date1.doTransform(args) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b7b946268..e692890c4 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -415,13 +415,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { ), r ) - case equal: EqualNullSafe => - BackendsApiManager.getSparkPlanExecApiInstance.genEqualNullSafeTransformer( - substraitExprName, - replaceWithExpressionTransformerInternal(equal.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(equal.right, attributeSeq, expressionsMap), - equal - ) case md5: Md5 => BackendsApiManager.getSparkPlanExecApiInstance.genMd5Transformer( substraitExprName, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index 65badcbae..6b6587862 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -18,6 +18,15 @@ package org.apache.gluten.expression import org.apache.gluten.substrait.expression.ExpressionNode +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.DataType + trait ExpressionTransformer { def doTransform(args: java.lang.Object): ExpressionNode + def dataType: DataType +} + +trait ExpressionTransformerWithOrigin extends ExpressionTransformer { + def original: Expression + def dataType: DataType = original.dataType } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala index 62afcad28..8faf4965f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala @@ -27,7 +27,7 @@ case class GenericExpressionTransformer( substraitExprName: String, children: Seq[ExpressionTransformer], original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: Object): ExpressionNode = { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] val functionId = ExpressionBuilder.newScalarFunction( diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala index d813f8250..28f2dda01 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala @@ -25,7 +25,7 @@ case class HashExpressionTransformer( substraitExprName: String, exps: Seq[ExpressionTransformer], original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val nodes = new java.util.ArrayList[ExpressionNode]() diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala index 25e3e12a5..e8ff3d360 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala @@ -28,7 +28,7 @@ case class JsonTupleExpressionTransformer( substraitExprName: String, children: Seq[ExpressionTransformer], original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: Object): ExpressionNode = { val jsonExpr = children.head diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala index 492de2b76..ce6d13a95 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala @@ -27,7 +27,7 @@ case class LambdaFunctionTransformer( arguments: Seq[ExpressionTransformer], hidden: Boolean = false, original: LambdaFunction) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: Object): ExpressionNode = { // Need to fallback when hidden be true as it's not supported in Velox diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala index 05787858e..8fb9943d6 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala @@ -20,9 +20,9 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode import org.apache.spark.sql.catalyst.expressions._ -case class LiteralTransformer(lit: Literal) extends ExpressionTransformer { +case class LiteralTransformer(original: Literal) extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { - ExpressionBuilder.makeLiteral(lit.value, lit.dataType, lit.nullable) + ExpressionBuilder.makeLiteral(original.value, original.dataType, original.nullable) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala index e136f1b3a..c09afaebc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala @@ -30,7 +30,7 @@ case class CreateMapTransformer( children: Seq[ExpressionTransformer], useStringTypeWhenEmpty: Boolean, original: CreateMap) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { // If children is empty, @@ -64,7 +64,7 @@ case class GetMapValueTransformer( key: ExpressionTransformer, failOnError: Boolean, original: GetMapValue) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { if (BackendsApiManager.getSettings.alwaysFailOnMapExpression()) { diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala index 70ad13584..2af4a5fa2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala @@ -28,7 +28,7 @@ case class AliasTransformer( substraitExprName: String, child: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode = child.doTransform(args) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala index dfa4ceed6..7d34466e5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala @@ -32,7 +32,7 @@ case class InTransformer( list: Seq[Expression], valueType: DataType, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { assert(list.forall(_.foldable)) // Stores the values in a List Literal. @@ -46,7 +46,7 @@ case class InSetTransformer( hset: Set[Any], valueType: DataType, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { InExpressionTransformer.toTransformer(value.doTransform(args), hset, valueType) } @@ -74,7 +74,7 @@ case class LikeTransformer( left: ExpressionTransformer, right: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val leftNode = left.doTransform(args) val rightNode = right.doTransform(args) @@ -108,7 +108,8 @@ case class DecimalArithmeticExpressionTransformer( right: ExpressionTransformer, resultType: DecimalType, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { + override def dataType: DataType = resultType override def doTransform(args: java.lang.Object): ExpressionNode = { val leftNode = left.doTransform(args) val rightNode = right.doTransform(args) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala index 534bde3b3..4f5a43d47 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BaseSubqueryExec, ScalarSubquery} case class ScalarSubqueryTransformer(plan: BaseSubqueryExec, exprId: ExprId, query: ScalarSubquery) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { + override def original: Expression = query override def doTransform(args: java.lang.Object): ExpressionNode = { // don't trigger collect when in validation phase diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala index da021be24..b31d66b68 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala @@ -28,7 +28,7 @@ case class String2TrimExpressionTransformer( trimStr: Option[ExpressionTransformer], srcStr: ExpressionTransformer, original: Expression) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val trimStrNode = trimStr.map(_.doTransform(args)) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala index c70395a7d..616971b6d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala @@ -29,7 +29,7 @@ case class GetStructFieldTransformer( childTransformer: ExpressionTransformer, ordinal: Int, original: GetStructField) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode = childTransformer.doTransform(args) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala index 2d3840ce4..d0ac19b4a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala @@ -32,17 +32,18 @@ case class ChildTransformer(child: ExpressionTransformer) extends ExpressionTran override def doTransform(args: java.lang.Object): ExpressionNode = { child.doTransform(args) } + override def dataType: DataType = child.dataType } case class CastTransformer( child: ExpressionTransformer, - datatype: DataType, + dataType: DataType, timeZoneId: Option[String], original: Cast) extends ExpressionTransformer { override def doTransform(args: java.lang.Object): ExpressionNode = { - val typeNode = ConverterUtils.getTypeNode(datatype, original.nullable) + val typeNode = ConverterUtils.getTypeNode(dataType, original.nullable) ExpressionBuilder.makeCast(typeNode, child.doTransform(args), original.ansiEnabled) } } @@ -51,7 +52,7 @@ case class ExplodeTransformer( substraitExprName: String, child: ExpressionTransformer, original: Explode) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode: ExpressionNode = child.doTransform(args) @@ -79,7 +80,7 @@ case class PosExplodeTransformer( child: ExpressionTransformer, original: PosExplode, attributeSeq: Seq[Attribute]) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode: ExpressionNode = child.doTransform(args) @@ -154,7 +155,7 @@ case class CheckOverflowTransformer( child: ExpressionTransformer, childResultType: DataType, original: CheckOverflow) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { BackendsApiManager.getTransformerApiInstance.createCheckOverflowExprNode( @@ -172,7 +173,7 @@ case class MakeDecimalTransformer( substraitExprName: String, child: ExpressionTransformer, original: MakeDecimal) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode = child.doTransform(args) @@ -202,7 +203,7 @@ case class RandTransformer( substraitExprName: String, explicitSeed: ExpressionTransformer, original: Rand) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { if (!original.hideSeed) { @@ -226,7 +227,7 @@ case class GetArrayStructFieldsTransformer( numFields: Int, containsNull: Boolean, original: GetArrayStructFields) - extends ExpressionTransformer { + extends ExpressionTransformerWithOrigin { override def doTransform(args: java.lang.Object): ExpressionNode = { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index bc0410834..afc427cd3 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -756,6 +756,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeGlutenTest("to_unix_timestamp") .excludeGlutenTest("Hour") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] .exclude("sha2") .exclude("murmur3/xxHash64/hive hash: struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>") diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 366796a57..5e3591203 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -226,6 +226,7 @@ class VeloxTestSettings extends BackendTestSettings { // Replaced by a gluten test to pass timezone through config. .exclude("from_unixtime") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenStringFunctionsSuite] enableSuite[GlutenRegexpExpressionsSuite] enableSuite[GlutenNullExpressionsSuite] diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala new file mode 100644 index 000000000..97e752d7d --- /dev/null +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.gluten.expression._ + +import org.apache.spark.sql.GlutenTestsTrait +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types._ + +class GlutenDecimalPrecisionSuite extends GlutenTestsTrait { + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry) + private val analyzer = new Analyzer(catalog) + + private val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() + ) + + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()), relation)) + assert(plan.isInstanceOf[Project]) + val expr = plan.asInstanceOf[Project].projectList.head + assert(expr.dataType == expectedType) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.dataType == expectedType) + } + + private def stripAlias(expr: Expression): Expression = { + expr match { + case a: Alias => stripAlias(a.child) + case _ => expr + } + } + + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil, relation)) + assert(plan.isInstanceOf[Project]) + val expr = stripAlias(plan.asInstanceOf[Project].projectList.head) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.isInstanceOf[GenericExpressionTransformer]) + val binaryComparison = transformedExpr.asInstanceOf[GenericExpressionTransformer] + assert(binaryComparison.original.isInstanceOf[BinaryComparison]) + assert(binaryComparison.children.size == 2) + assert(binaryComparison.children.forall(_.dataType == expectedType)) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + } + + test("Comparison operations") { + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) + } + + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) + + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + } + } +} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala index a3720fc62..c27159ceb 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import com.google.common.collect.Lists -class CustomAddExpressionTransformer( +case class CustomAddExpressionTransformer( substraitExprName: String, left: ExpressionTransformer, right: ExpressionTransformer, original: Expression) - extends ExpressionTransformer + extends ExpressionTransformerWithOrigin with Logging { override def doTransform(args: java.lang.Object): ExpressionNode = { val leftNode = left.doTransform(args) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 6a403204f..85f3f94cc 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -800,6 +800,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeGlutenTest("to_unix_timestamp") .excludeGlutenTest("Hour") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] .exclude("sha2") .exclude("murmur3/xxHash64/hive hash: struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 128e52a79..1d796aa1b 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenAnsiCastSuiteWithAnsiModeOff, GlutenAnsiCastSuiteWithAnsiModeOn, GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCastSuiteWithAnsiModeOn, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExp [...] +import org.apache.spark.sql.catalyst.expressions.{GlutenAnsiCastSuiteWithAnsiModeOff, GlutenAnsiCastSuiteWithAnsiModeOn, GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCastSuiteWithAnsiModeOn, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpre [...] import org.apache.spark.sql.connector._ import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution._ @@ -141,6 +141,7 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("from_unixtime") .exclude("test timestamp add") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala new file mode 100644 index 000000000..97e752d7d --- /dev/null +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.gluten.expression._ + +import org.apache.spark.sql.GlutenTestsTrait +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types._ + +class GlutenDecimalPrecisionSuite extends GlutenTestsTrait { + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry) + private val analyzer = new Analyzer(catalog) + + private val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() + ) + + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()), relation)) + assert(plan.isInstanceOf[Project]) + val expr = plan.asInstanceOf[Project].projectList.head + assert(expr.dataType == expectedType) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.dataType == expectedType) + } + + private def stripAlias(expr: Expression): Expression = { + expr match { + case a: Alias => stripAlias(a.child) + case _ => expr + } + } + + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil, relation)) + assert(plan.isInstanceOf[Project]) + val expr = stripAlias(plan.asInstanceOf[Project].projectList.head) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.isInstanceOf[GenericExpressionTransformer]) + val binaryComparison = transformedExpr.asInstanceOf[GenericExpressionTransformer] + assert(binaryComparison.original.isInstanceOf[BinaryComparison]) + assert(binaryComparison.children.size == 2) + assert(binaryComparison.children.forall(_.dataType == expectedType)) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + } + + test("Comparison operations") { + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) + } + + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) + + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + } + } +} diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 37e4c68f7..069d697bd 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -639,6 +639,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeGlutenTest("to_unix_timestamp") .excludeGlutenTest("Hour") enableSuite[GlutenDecimalExpressionSuite].exclude("MakeDecimal") + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] .exclude("sha2") .exclude("murmur3/xxHash64/hive hash: struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 6ea29847b..7c8509f80 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, Glu [...] +import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, Glu [...] import org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite, GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite, GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter, GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite, GlutenDeleteFromTableSuite, GlutenDeltaBasedDeleteFromTableSuite, GlutenFileDataSourceV2FallBackSuite, GlutenGroupBasedDeleteFromTableSuite, GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, G [...] import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite} @@ -121,6 +121,7 @@ class VeloxTestSettings extends BackendTestSettings { // Replaced by a gluten test to pass timezone through config. .exclude("from_unixtime") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala new file mode 100644 index 000000000..97e752d7d --- /dev/null +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.gluten.expression._ + +import org.apache.spark.sql.GlutenTestsTrait +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types._ + +class GlutenDecimalPrecisionSuite extends GlutenTestsTrait { + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry) + private val analyzer = new Analyzer(catalog) + + private val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() + ) + + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()), relation)) + assert(plan.isInstanceOf[Project]) + val expr = plan.asInstanceOf[Project].projectList.head + assert(expr.dataType == expectedType) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.dataType == expectedType) + } + + private def stripAlias(expr: Expression): Expression = { + expr match { + case a: Alias => stripAlias(a.child) + case _ => expr + } + } + + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil, relation)) + assert(plan.isInstanceOf[Project]) + val expr = stripAlias(plan.asInstanceOf[Project].projectList.head) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.isInstanceOf[GenericExpressionTransformer]) + val binaryComparison = transformedExpr.asInstanceOf[GenericExpressionTransformer] + assert(binaryComparison.original.isInstanceOf[BinaryComparison]) + assert(binaryComparison.children.size == 2) + assert(binaryComparison.children.forall(_.dataType == expectedType)) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + } + + test("Comparison operations") { + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) + } + + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) + + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + } + } +} diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 37e4c68f7..069d697bd 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -639,6 +639,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeGlutenTest("to_unix_timestamp") .excludeGlutenTest("Hour") enableSuite[GlutenDecimalExpressionSuite].exclude("MakeDecimal") + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] .exclude("sha2") .exclude("murmur3/xxHash64/hive hash: struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index e6e42acb3..40ecc3c35 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, Glu [...] +import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, Glu [...] import org.apache.spark.sql.connector._ import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution._ @@ -122,6 +122,7 @@ class VeloxTestSettings extends BackendTestSettings { // Replaced by a gluten test to pass timezone through config. .exclude("from_unixtime") enableSuite[GlutenDecimalExpressionSuite] + enableSuite[GlutenDecimalPrecisionSuite] enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala new file mode 100644 index 000000000..97e752d7d --- /dev/null +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.gluten.expression._ + +import org.apache.spark.sql.GlutenTestsTrait +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types._ + +class GlutenDecimalPrecisionSuite extends GlutenTestsTrait { + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry) + private val analyzer = new Analyzer(catalog) + + private val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() + ) + + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()), relation)) + assert(plan.isInstanceOf[Project]) + val expr = plan.asInstanceOf[Project].projectList.head + assert(expr.dataType == expectedType) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.dataType == expectedType) + } + + private def stripAlias(expr: Expression): Expression = { + expr match { + case a: Alias => stripAlias(a.child) + case _ => expr + } + } + + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil, relation)) + assert(plan.isInstanceOf[Project]) + val expr = stripAlias(plan.asInstanceOf[Project].projectList.head) + val transformedExpr = + ExpressionConverter.replaceWithExpressionTransformer(expr, plan.inputSet.toSeq) + assert(transformedExpr.isInstanceOf[GenericExpressionTransformer]) + val binaryComparison = transformedExpr.asInstanceOf[GenericExpressionTransformer] + assert(binaryComparison.original.isInstanceOf[BinaryComparison]) + assert(binaryComparison.children.size == 2) + assert(binaryComparison.children.forall(_.dataType == expectedType)) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + } + + test("Comparison operations") { + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) + } + + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) + + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org For additional commands, e-mail: commits-h...@gluten.apache.org