This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c3f6714d0 chore: Remove redundant shims for getFailOnError (#1608)
c3f6714d0 is described below
commit c3f6714d07c359565647ac2109ac22397937974d
Author: Andy Grove <[email protected]>
AuthorDate: Fri Apr 4 09:50:09 2025 -0600
chore: Remove redundant shims for getFailOnError (#1608)
* remove redundant shim for average failOnError
* remove redundant shims
* remove redundant shims
* revert rename
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 18 ++++-----
.../scala/org/apache/comet/serde/aggregates.scala | 9 ++---
.../apache/comet/shims/ShimQueryPlanSerde.scala | 47 ----------------------
.../apache/comet/shims/ShimQueryPlanSerde.scala | 34 ----------------
4 files changed, 13 insertions(+), 95 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 50b418737..53f96f445 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -51,12 +51,12 @@ import org.apache.comet.expressions._
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType =>
ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType._
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode =>
CometAggregateMode, BuildSide, JoinType, Operator}
-import org.apache.comet.shims.{CometExprShim, ShimQueryPlanSerde}
+import org.apache.comet.shims.CometExprShim
/**
* An utility object for query plan and expression serialization.
*/
-object QueryPlanSerde extends Logging with ShimQueryPlanSerde with
CometExprShim {
+object QueryPlanSerde extends Logging with CometExprShim {
def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
}
@@ -564,7 +564,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
add.dataType,
- getFailOnError(add),
+ add.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setAdd(mathExpr))
case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
@@ -579,7 +579,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
sub.dataType,
- getFailOnError(sub),
+ sub.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setSubtract(mathExpr))
case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
@@ -594,7 +594,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
mul.dataType,
- getFailOnError(mul),
+ mul.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setMultiply(mathExpr))
case mul @ Multiply(left, _, _) =>
@@ -616,7 +616,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
div.dataType,
- getFailOnError(div),
+ div.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setDivide(mathExpr))
case div @ Divide(left, _, _) =>
@@ -643,7 +643,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
dataType,
- getFailOnError(div),
+ div.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setIntegralDivide(mathExpr))
if (divideExpr.isDefined) {
@@ -651,7 +651,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
// check overflow for decimal type
val builder = ExprOuterClass.CheckOverflow.newBuilder()
builder.setChild(divideExpr.get)
- builder.setFailOnError(getFailOnError(div))
+ builder.setFailOnError(div.evalMode == EvalMode.ANSI)
builder.setDatatype(serializeDataType(dataType).get)
Some(
ExprOuterClass.Expr
@@ -684,7 +684,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs,
binding,
rem.dataType,
- getFailOnError(rem),
+ rem.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setRemainder(mathExpr))
case rem @ Remainder(left, _, _) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 0284e553c..5f41364c4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.types.{ByteType, DecimalType,
IntegerType, LongType,
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
-import org.apache.comet.shims.ShimQueryPlanSerde
object CometMin extends CometAggregateExpressionSerde {
@@ -126,7 +125,7 @@ object CometCount extends CometAggregateExpressionSerde {
}
}
-object CometAverage extends CometAggregateExpressionSerde with
ShimQueryPlanSerde {
+object CometAverage extends CometAggregateExpressionSerde {
override def convert(
aggExpr: AggregateExpression,
expr: Expression,
@@ -164,7 +163,7 @@ object CometAverage extends CometAggregateExpressionSerde
with ShimQueryPlanSerd
val builder = ExprOuterClass.Avg.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setFailOnError(getFailOnError(avg))
+ builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
builder.setSumDatatype(sumDataType.get)
Some(
@@ -181,7 +180,7 @@ object CometAverage extends CometAggregateExpressionSerde
with ShimQueryPlanSerd
}
}
}
-object CometSum extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+object CometSum extends CometAggregateExpressionSerde {
override def convert(
aggExpr: AggregateExpression,
expr: Expression,
@@ -207,7 +206,7 @@ object CometSum extends CometAggregateExpressionSerde with
ShimQueryPlanSerde {
val builder = ExprOuterClass.Sum.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setFailOnError(getFailOnError(sum))
+ builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
Some(
ExprOuterClass.AggExpr
diff --git
a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimQueryPlanSerde.scala
b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimQueryPlanSerde.scala
deleted file mode 100644
index c47b399cf..000000000
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.comet.shims
-
-import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
-
-trait ShimQueryPlanSerde {
- def getFailOnError(b: BinaryArithmetic): Boolean =
- b.getClass.getMethod("failOnError").invoke(b).asInstanceOf[Boolean]
-
- def getFailOnError(aggregate: DeclarativeAggregate): Boolean = {
- val failOnError = aggregate.getClass.getDeclaredMethods.flatMap(m =>
- m.getName match {
- case "failOnError" | "useAnsiAdd" =>
Some(m.invoke(aggregate).asInstanceOf[Boolean])
- case _ => None
- })
- if (failOnError.isEmpty) {
- aggregate.getClass.getDeclaredMethods
- .flatMap(m =>
- m.getName match {
- case "initQueryContext" =>
Some(m.invoke(aggregate).asInstanceOf[Option[_]].isDefined)
- case _ => None
- })
- .head
- } else {
- failOnError.head
- }
- }
-}
diff --git
a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
deleted file mode 100644
index 10821881b..000000000
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.comet.shims
-
-import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic,
BinaryExpression, BloomFilterMightContain}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
-
-trait ShimQueryPlanSerde {
- protected def getFailOnError(b: BinaryArithmetic): Boolean =
- b.getClass.getMethod("failOnError").invoke(b).asInstanceOf[Boolean]
-
- protected def getFailOnError(aggregate: Sum): Boolean =
aggregate.initQueryContext().isDefined
- protected def getFailOnError(aggregate: Average): Boolean =
aggregate.initQueryContext().isDefined
-
- protected def isBloomFilterMightContain(binary: BinaryExpression): Boolean =
- binary.isInstanceOf[BloomFilterMightContain]
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]