This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 7e0ff1a46 Chore: Refactor serde for math expressions (#2259)
7e0ff1a46 is described below
commit 7e0ff1a468162fd768471d6d822c1fbd7c98daef
Author: Kazantsev Maksim <[email protected]>
AuthorDate: Fri Aug 29 12:56:23 2025 -0700
Chore: Refactor serde for math expressions (#2259)
* Maths expr refactor
* Fix
* Format
---------
Co-authored-by: Kazantsev Maksim <[email protected]>
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 79 ++------------
.../main/scala/org/apache/comet/serde/math.scala | 120 +++++++++++++++++++++
2 files changed, 128 insertions(+), 71 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index bef7e15e1..22bd6fd03 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -170,7 +170,14 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[DateSub] -> CometDateSub,
classOf[TruncDate] -> CometTruncDate,
classOf[TruncTimestamp] -> CometTruncTimestamp,
- classOf[Flatten] -> CometFlatten)
+ classOf[Flatten] -> CometFlatten,
+ classOf[Atan2] -> CometAtan2,
+ classOf[Ceil] -> CometCeil,
+ classOf[Floor] -> CometFloor,
+ classOf[Log] -> CometLog,
+ classOf[Log10] -> CometLog10,
+ classOf[Log2] -> CometLog2,
+ classOf[Pow] -> CometScalarFunction[Pow]("pow"))
/**
* Mapping of Spark aggregate expression class to Comet expression handler.
@@ -1108,12 +1115,6 @@ object QueryPlanSerde extends Logging with CometExprShim
{
// None
// }
- case Atan2(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs, binding)
- val rightExpr = exprToProtoInternal(right, inputs, binding)
- val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
- optExprWithInfo(optExpr, expr, left, right)
-
case Hex(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr =
@@ -1131,56 +1132,6 @@ object QueryPlanSerde extends Logging with CometExprShim
{
scalarFunctionExprToProtoWithReturnType("unhex", e.dataType,
childExpr, failOnErrorExpr)
optExprWithInfo(optExpr, expr, unHex._1)
- case e @ Ceil(child) =>
- val childExpr = exprToProtoInternal(child, inputs, binding)
- child.dataType match {
- case t: DecimalType if t.scale == 0 => // zero scale is no-op
- childExpr
- case t: DecimalType if t.scale < 0 => // Spark disallows negative
scale SPARK-30252
- withInfo(e, s"Decimal type $t has negative scale")
- None
- case _ =>
- val optExpr = scalarFunctionExprToProtoWithReturnType("ceil",
e.dataType, childExpr)
- optExprWithInfo(optExpr, expr, child)
- }
-
- case e @ Floor(child) =>
- val childExpr = exprToProtoInternal(child, inputs, binding)
- child.dataType match {
- case t: DecimalType if t.scale == 0 => // zero scale is no-op
- childExpr
- case t: DecimalType if t.scale < 0 => // Spark disallows negative
scale SPARK-30252
- withInfo(e, s"Decimal type $t has negative scale")
- None
- case _ =>
- val optExpr = scalarFunctionExprToProtoWithReturnType("floor",
e.dataType, childExpr)
- optExprWithInfo(optExpr, expr, child)
- }
-
- // The expression for `log` functions is defined as null on numbers less
than or equal
- // to 0. This matches Spark and Hive behavior, where non positive values
eval to null
- // instead of NaN or -Infinity.
- case Log(child) =>
- val childExpr = exprToProtoInternal(nullIfNegative(child), inputs,
binding)
- val optExpr = scalarFunctionExprToProto("ln", childExpr)
- optExprWithInfo(optExpr, expr, child)
-
- case Log10(child) =>
- val childExpr = exprToProtoInternal(nullIfNegative(child), inputs,
binding)
- val optExpr = scalarFunctionExprToProto("log10", childExpr)
- optExprWithInfo(optExpr, expr, child)
-
- case Log2(child) =>
- val childExpr = exprToProtoInternal(nullIfNegative(child), inputs,
binding)
- val optExpr = scalarFunctionExprToProto("log2", childExpr)
- optExprWithInfo(optExpr, expr, child)
-
- case Pow(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs, binding)
- val rightExpr = exprToProtoInternal(right, inputs, binding)
- val optExpr = scalarFunctionExprToProto("pow", leftExpr, rightExpr)
- optExprWithInfo(optExpr, expr, left, right)
-
case RegExpReplace(subject, pattern, replacement, startPosition) =>
if (!RegExp.isSupportedPattern(pattern.toString) &&
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
@@ -1265,15 +1216,6 @@ object QueryPlanSerde extends Logging with CometExprShim
{
None
}
- case BitwiseAnd(left, right) =>
- createBinaryExpr(
- expr,
- left,
- right,
- inputs,
- binding,
- (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
-
case n @ Not(In(_, _)) =>
CometNotIn.convert(n, inputs, binding)
@@ -1611,11 +1553,6 @@ object QueryPlanSerde extends Logging with CometExprShim
{
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
}
- private def nullIfNegative(expression: Expression): Expression = {
- val zero = Literal.default(expression.dataType)
- If(LessThanOrEqual(expression, zero), Literal.create(null,
expression.dataType), expression)
- }
-
/**
* Returns true if given datatype is supported as a key in DataFusion sort
merge join.
*/
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala
b/spark/src/main/scala/org/apache/comet/serde/math.scala
new file mode 100644
index 000000000..700b9bd44
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil,
Expression, Floor, If, LessThanOrEqual, Literal, Log, Log10, Log2}
+import org.apache.spark.sql.types.DecimalType
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal,
optExprWithInfo, scalarFunctionExprToProto,
scalarFunctionExprToProtoWithReturnType}
+
+object CometAtan2 extends CometExpressionSerde[Atan2] {
+
+  /**
+   * Serializes Spark's two-argument arctangent as a call to the native `atan2` scalar function.
+   *
+   * Both operands are serialized in order (`left` first, then `right`). The resulting call is
+   * passed through `optExprWithInfo` together with both operands, presumably so that a failed
+   * serialization is reported on `expr` (see `QueryPlanSerde.optExprWithInfo` to confirm).
+   */
+  override def convert(
+      expr: Atan2,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val y = expr.left
+    val x = expr.right
+    val serialized = scalarFunctionExprToProto(
+      "atan2",
+      exprToProtoInternal(y, inputs, binding),
+      exprToProtoInternal(x, inputs, binding))
+    optExprWithInfo(serialized, expr, y, x)
+  }
+}
+
+object CometCeil extends CometExpressionSerde[Ceil] with CeilFloorSerdeBase {
+  override def convert(
+      expr: Ceil,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] =
+    convertRounding("ceil", expr, expr.child, inputs, binding)
+}
+
+object CometFloor extends CometExpressionSerde[Floor] with CeilFloorSerdeBase {
+  override def convert(
+      expr: Floor,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] =
+    convertRounding("floor", expr, expr.child, inputs, binding)
+}
+
+/** Shared serde logic for Spark's unary rounding expressions `Ceil` and `Floor`. */
+sealed trait CeilFloorSerdeBase {
+
+  /**
+   * Serializes a `Ceil`/`Floor` expression as the native scalar function `funcName`.
+   *
+   * @param funcName native function name (`"ceil"` or `"floor"`)
+   * @param expr     the Spark expression being converted; its `dataType` is sent as the return type
+   * @param child    the single input of `expr`
+   * @param inputs   attributes used to bind column references in `child`
+   * @param binding  whether column references should be bound against `inputs`
+   * @return the serialized expression, or `None` when it cannot be converted
+   */
+  protected def convertRounding(
+      funcName: String,
+      expr: Expression,
+      child: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // Serialize the child up front (as before), even on the unsupported negative-scale path.
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+    child.dataType match {
+      case t: DecimalType if t.scale == 0 =>
+        // Rounding a zero-scale decimal is a no-op, so the child is used directly. Route it
+        // through optExprWithInfo like the other branches so that a failed child conversion is
+        // still reported on `expr` instead of silently dropping the fallback reason.
+        optExprWithInfo(childExpr, expr, child)
+      case t: DecimalType if t.scale < 0 =>
+        // Spark disallows negative decimal scales (SPARK-30252); do not guess the semantics.
+        withInfo(expr, s"Decimal type $t has negative scale")
+        None
+      case _ =>
+        val optExpr = scalarFunctionExprToProtoWithReturnType(funcName, expr.dataType, childExpr)
+        optExprWithInfo(optExpr, expr, child)
+    }
+  }
+}
+
+// Spark and Hive evaluate a logarithm of a non-positive input (<= 0) to null rather than NaN or
+// -Infinity. The `Log`, `Log10` and `Log2` serdes below therefore all wrap their child in a
+// "null if <= 0" guard before handing it to the native function.
+object CometLog extends CometExpressionSerde[Log] with MathExprBase {
+  override def convert(
+      expr: Log,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] =
+    convertLog("ln", expr, expr.child, inputs, binding)
+}
+
+object CometLog10 extends CometExpressionSerde[Log10] with MathExprBase {
+  override def convert(
+      expr: Log10,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] =
+    convertLog("log10", expr, expr.child, inputs, binding)
+}
+
+object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {
+  override def convert(
+      expr: Log2,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] =
+    convertLog("log2", expr, expr.child, inputs, binding)
+}
+
+/** Shared helpers for the logarithm serdes (`Log`, `Log10`, `Log2`). */
+sealed trait MathExprBase {
+
+  /**
+   * Serializes a unary logarithm as the native scalar function `funcName`, applied to `child`
+   * after non-positive values have been replaced with null (see [[nullIfNonPositive]]).
+   *
+   * @param funcName native function name (`"ln"`, `"log10"` or `"log2"`)
+   * @param expr     the Spark expression being converted
+   * @param child    the single input of `expr`
+   * @param inputs   attributes used to bind column references in `child`
+   * @param binding  whether column references should be bound against `inputs`
+   * @return the serialized expression, or `None` when it cannot be converted
+   */
+  protected def convertLog(
+      funcName: String,
+      expr: Expression,
+      child: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(nullIfNonPositive(child), inputs, binding)
+    optExprWithInfo(scalarFunctionExprToProto(funcName, childExpr), expr, child)
+  }
+
+  /**
+   * Wraps `expression` so that values less than or equal to zero evaluate to null; all other
+   * values pass through unchanged. (Previously named `nullIfNegative`, which was misleading
+   * because zero is nulled as well.)
+   */
+  protected def nullIfNonPositive(expression: Expression): Expression = {
+    val zero = Literal.default(expression.dataType)
+    If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]