This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 9bfb5c4b5 chore: Refactor operator serde - part 1 (#2738)
9bfb5c4b5 is described below
commit 9bfb5c4b5c02e1d798006cdfa2faea5c9aa58d8b
Author: Andy Grove <[email protected]>
AuthorDate: Sat Nov 8 11:19:17 2025 -0700
chore: Refactor operator serde - part 1 (#2738)
---
.../org/apache/comet/serde/CometAggregate.scala | 193 ++++++
.../serde/CometAggregateExpressionSerde.scala | 67 ++
.../scala/org/apache/comet/serde/CometExpand.scala | 60 ++
.../apache/comet/serde/CometExpressionSerde.scala | 66 ++
.../scala/org/apache/comet/serde/CometFilter.scala | 51 ++
.../org/apache/comet/serde/CometGlobalLimit.scala | 49 ++
.../org/apache/comet/serde/CometHashJoin.scala | 123 ++++
.../org/apache/comet/serde/CometLocalLimit.scala | 50 ++
.../org/apache/comet/serde/CometNativeScan.scala | 218 ++++++
.../apache/comet/serde/CometOperatorSerde.scala | 57 ++
.../apache/comet/serde/CometScalarFunction.scala | 34 +
.../apache/comet/serde/CometSortMergeJoin.scala | 144 ++++
.../scala/org/apache/comet/serde/CometWindow.scala | 120 ++++
.../org/apache/comet/serde/QueryPlanSerde.scala | 750 +--------------------
14 files changed, 1261 insertions(+), 721 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala
new file mode 100644
index 000000000..f0cf244f1
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.types.MapType
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
+import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
+
+trait CometBaseAggregate {
+
+ def doConvert(
+ aggregate: BaseAggregateExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ val groupingExpressions = aggregate.groupingExpressions
+ val aggregateExpressions = aggregate.aggregateExpressions
+ val aggregateAttributes = aggregate.aggregateAttributes
+ val resultExpressions = aggregate.resultExpressions
+ val child = aggregate.child
+
+ if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
+ withInfo(aggregate, "No group by or aggregation")
+ return None
+ }
+
+ // Aggregate expressions with filter are not supported yet.
+ if (aggregateExpressions.exists(_.filter.isDefined)) {
+ withInfo(aggregate, "Aggregate expression with filter is not supported")
+ return None
+ }
+
+ if (groupingExpressions.exists(expr =>
+ expr.dataType match {
+ case _: MapType => true
+ case _ => false
+ })) {
+ withInfo(aggregate, "Grouping on map types is not supported")
+ return None
+ }
+
+ val groupingExprsWithInput =
+ groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output))
+
+ val emptyExprs = groupingExprsWithInput.collect {
+ case (expr, proto) if proto.isEmpty => expr
+ }
+
+ if (emptyExprs.nonEmpty) {
+ withInfo(aggregate, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}")
+ return None
+ }
+
+ val groupingExprs = groupingExprsWithInput.map(_._2)
+
+ // In some of the cases, the aggregateExpressions could be empty.
+ // For example, if the aggregate functions only have group by or if the aggregate
+ // functions only have distinct aggregate functions:
+ //
+ // SELECT COUNT(distinct col2), col1 FROM test group by col1
+ // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
+ // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
+ // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
+ // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+ // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
+ // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+ // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
+ // If the aggregateExpressions is empty, we only want to build groupingExpressions,
+ // and skip processing of aggregateExpressions.
+ if (aggregateExpressions.isEmpty) {
+ val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
+ hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
+ val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
+ val resultExprs = resultExpressions.map(exprToProto(_, attributes))
+ if (resultExprs.exists(_.isEmpty)) {
+ withInfo(
+ aggregate,
+ s"Unsupported result expressions found in: $resultExpressions",
+ resultExpressions: _*)
+ return None
+ }
+ hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
+ Some(builder.setHashAgg(hashAggBuilder).build())
+ } else {
+ val modes = aggregateExpressions.map(_.mode).distinct
+
+ if (modes.size != 1) {
+ // This shouldn't happen as all aggregation expressions should share the same mode.
+ // Fallback to Spark nevertheless here.
+ withInfo(aggregate, "All aggregate expressions do not have the same mode")
+ return None
+ }
+
+ val mode = modes.head match {
+ case Partial => CometAggregateMode.Partial
+ case Final => CometAggregateMode.Final
+ case _ =>
+ withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
+ return None
+ }
+
+ // In final mode, the aggregate expressions are bound to the output of the
+ // child and partial aggregate expressions buffer attributes produced by partial
+ // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
+ // we don't have to do this because we don't use the merging expression.
+ val binding = mode != CometAggregateMode.Final
+ // `output` is only used when `binding` is true (i.e., non-Final)
+ val output = child.output
+
+ val aggExprs =
+ aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
+ if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
+ aggExprs.forall(_.isDefined)) {
+ val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
+ hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
+ hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
+ if (mode == CometAggregateMode.Final) {
+ val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
+ val resultExprs = resultExpressions.map(exprToProto(_, attributes))
+ if (resultExprs.exists(_.isEmpty)) {
+ withInfo(
+ aggregate,
+ s"Unsupported result expressions found in: $resultExpressions",
+ resultExpressions: _*)
+ return None
+ }
+ hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
+ }
+ hashAggBuilder.setModeValue(mode.getNumber)
+ Some(builder.setHashAgg(hashAggBuilder).build())
+ } else {
+ val allChildren: Seq[Expression] =
+ groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
+ withInfo(aggregate, allChildren: _*)
+ None
+ }
+ }
+
+ }
+
+}
+
+object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
+ CometConf.COMET_EXEC_AGGREGATE_ENABLED)
+
+ override def convert(
+ aggregate: HashAggregateExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ doConvert(aggregate, builder, childOp: _*)
+ }
+}
+
+object CometObjectHashAggregate
+ extends CometOperatorSerde[ObjectHashAggregateExec]
+ with CometBaseAggregate {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
+ CometConf.COMET_EXEC_AGGREGATE_ENABLED)
+
+ override def convert(
+ aggregate: ObjectHashAggregateExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ doConvert(aggregate, builder, childOp: _*)
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala
new file mode 100644
index 000000000..c0c2b0728
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Trait for providing serialization logic for aggregate expressions.
+ */
+trait CometAggregateExpressionSerde[T <: AggregateFunction] {
+
+ /**
+ * Get a short name for the expression that can be used as part of a config key related to the
+ * expression, such as enabling or disabling that expression.
+ *
+ * @param expr
+ * The Spark expression.
+ * @return
+ * Short name for the expression, defaulting to the Spark class name
+ */
+ def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
+
+ /**
+ * Convert a Spark expression into a protocol buffer representation that can be passed into
+ * native code.
+ *
+ * @param aggExpr
+ * The aggregate expression.
+ * @param expr
+ * The aggregate function.
+ * @param inputs
+ * The input attributes.
+ * @param binding
+ * Whether the attributes are bound (this is only relevant in aggregate expressions).
+ * @param conf
+ * SQLConf
+ * @return
+ * Protocol buffer representation, or None if the expression could not be converted. In this
+ * case it is expected that the input expression will have been tagged with reasons why it
+ * could not be converted.
+ */
+ def convert(
+ aggExpr: AggregateExpression,
+ expr: T,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr]
+}
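Illustrative note (not part of the commit): one way this trait might be consumed is a lookup from the concrete AggregateFunction class to its serde instance, followed by a call to convert. A minimal sketch, assuming the org.apache.comet.serde package; the registry parameter and helper name are hypothetical.

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
    import org.apache.spark.sql.internal.SQLConf

    // Hypothetical dispatch helper: `serdes` is an assumed registry keyed by the
    // AggregateFunction class; entries would be per-function serde objects.
    def convertAggExpr(
        serdes: Map[Class[_], CometAggregateExpressionSerde[AggregateFunction]],
        aggExpr: AggregateExpression,
        inputs: Seq[Attribute],
        binding: Boolean,
        conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
      val fn = aggExpr.aggregateFunction
      // Delegate to the registered serde, or return None so Comet falls back to Spark.
      serdes.get(fn.getClass).flatMap(_.convert(aggExpr, fn, inputs, binding, conf))
    }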
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala
new file mode 100644
index 000000000..5979eed4d
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution.ExpandExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+object CometExpand extends CometOperatorSerde[ExpandExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
+ CometConf.COMET_EXEC_EXPAND_ENABLED)
+
+ override def convert(
+ op: ExpandExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ var allProjExprs: Seq[Expression] = Seq()
+ val projExprs = op.projections.flatMap(_.map(e => {
+ allProjExprs = allProjExprs :+ e
+ exprToProto(e, op.child.output)
+ }))
+
+ if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
+ val expandBuilder = OperatorOuterClass.Expand
+ .newBuilder()
+ .addAllProjectList(projExprs.map(_.get).asJava)
+ .setNumExprPerProject(op.projections.head.size)
+ Some(builder.setExpand(expandBuilder).build())
+ } else {
+ withInfo(op, allProjExprs: _*)
+ None
+ }
+
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala
new file mode 100644
index 000000000..20c034303
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+
+/**
+ * Trait for providing serialization logic for expressions.
+ */
+trait CometExpressionSerde[T <: Expression] {
+
+ /**
+ * Get a short name for the expression that can be used as part of a config key related to the
+ * expression, such as enabling or disabling that expression.
+ *
+ * @param expr
+ * The Spark expression.
+ * @return
+ * Short name for the expression, defaulting to the Spark class name
+ */
+ def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
+
+ /**
+ * Determine the support level of the expression based on its attributes.
+ *
+ * @param expr
+ * The Spark expression.
+ * @return
+ * Support level (Compatible, Incompatible, or Unsupported).
+ */
+ def getSupportLevel(expr: T): SupportLevel = Compatible(None)
+
+ /**
+ * Convert a Spark expression into a protocol buffer representation that can be passed into
+ * native code.
+ *
+ * @param expr
+ * The Spark expression.
+ * @param inputs
+ * The input attributes.
+ * @param binding
+ * Whether the attributes are bound (this is only relevant in aggregate expressions).
+ * @return
+ * Protocol buffer representation, or None if the expression could not be converted. In this
+ * case it is expected that the input expression will have been tagged with reasons why it
+ * could not be converted.
+ */
+ def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr]
+}
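Illustrative note (not part of the commit): a caller might consult getSupportLevel before attempting conversion. A minimal sketch, assuming the org.apache.comet.serde package and that Compatible is the case class used above; handling of Incompatible expressions (e.g. allow-incompat settings) is omitted.

    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

    // Hypothetical caller-side helper: convert only expressions reported as Compatible,
    // otherwise return None so the expression stays on Spark.
    def tryConvertExpr[T <: Expression](
        serde: CometExpressionSerde[T],
        expr: T,
        inputs: Seq[Attribute],
        binding: Boolean): Option[ExprOuterClass.Expr] = {
      serde.getSupportLevel(expr) match {
        case _: Compatible => serde.convert(expr, inputs, binding)
        case _ => None
      }
    }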
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala
new file mode 100644
index 000000000..1638750b5
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.execution.FilterExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+object CometFilter extends CometOperatorSerde[FilterExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_FILTER_ENABLED)
+
+ override def convert(
+ op: FilterExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ val cond = exprToProto(op.condition, op.child.output)
+
+ if (cond.isDefined && childOp.nonEmpty) {
+ val filterBuilder = OperatorOuterClass.Filter
+ .newBuilder()
+ .setPredicate(cond.get)
+ Some(builder.setFilter(filterBuilder).build())
+ } else {
+ withInfo(op, op.condition, op.child)
+ None
+ }
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala
new file mode 100644
index 000000000..774e1ad77
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.execution.GlobalLimitExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+
+object CometGlobalLimit extends CometOperatorSerde[GlobalLimitExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED)
+
+ override def convert(
+ op: GlobalLimitExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ if (childOp.nonEmpty) {
+ val limitBuilder = OperatorOuterClass.Limit.newBuilder()
+
+ limitBuilder.setLimit(op.limit).setOffset(op.offset)
+
+ Some(builder.setLimit(limitBuilder).build())
+ } else {
+ withInfo(op, "No child operator")
+ None
+ }
+
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala
new file mode 100644
index 000000000..67fb67a2e
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec}
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.{BuildSide, JoinType, Operator}
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+trait CometHashJoin {
+
+ def doConvert(
+ join: HashJoin,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ // `HashJoin` has only two implementations in Spark, but we check the type of the join to
+ // make sure we are handling the correct join type.
+ if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(join.conf) &&
+ join.isInstanceOf[ShuffledHashJoinExec]) &&
+ !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(join.conf) &&
+ join.isInstanceOf[BroadcastHashJoinExec])) {
+ withInfo(join, s"Invalid hash join type ${join.nodeName}")
+ return None
+ }
+
+ if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
+ // https://github.com/apache/datafusion-comet/issues/457
+ withInfo(join, "BuildRight with LeftAnti is not supported")
+ return None
+ }
+
+ val condition = join.condition.map { cond =>
+ val condProto = exprToProto(cond, join.left.output ++ join.right.output)
+ if (condProto.isEmpty) {
+ withInfo(join, cond)
+ return None
+ }
+ condProto.get
+ }
+
+ val joinType = join.joinType match {
+ case Inner => JoinType.Inner
+ case LeftOuter => JoinType.LeftOuter
+ case RightOuter => JoinType.RightOuter
+ case FullOuter => JoinType.FullOuter
+ case LeftSemi => JoinType.LeftSemi
+ case LeftAnti => JoinType.LeftAnti
+ case _ =>
+ // Spark doesn't support other join types
+ withInfo(join, s"Unsupported join type ${join.joinType}")
+ return None
+ }
+
+ val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
+ val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
+
+ if (leftKeys.forall(_.isDefined) &&
+ rightKeys.forall(_.isDefined) &&
+ childOp.nonEmpty) {
+ val joinBuilder = OperatorOuterClass.HashJoin
+ .newBuilder()
+ .setJoinType(joinType)
+ .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
+ .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+ .setBuildSide(
+ if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight)
+ condition.foreach(joinBuilder.setCondition)
+ Some(builder.setHashJoin(joinBuilder).build())
+ } else {
+ val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
+ withInfo(join, allExprs: _*)
+ None
+ }
+ }
+}
+
+object CometBroadcastHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED)
+
+ override def convert(
+ join: HashJoin,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[Operator] =
+ doConvert(join, builder, childOp: _*)
+}
+
+object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED)
+
+ override def convert(
+ join: HashJoin,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[Operator] =
+ doConvert(join, builder, childOp: _*)
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala
new file mode 100644
index 000000000..1347b1290
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.execution.LocalLimitExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+
+object CometLocalLimit extends CometOperatorSerde[LocalLimitExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED)
+
+ override def convert(
+ op: LocalLimitExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ if (childOp.nonEmpty) {
+ // LocalLimit doesn't use offset, but it shares same operator serde class.
+ // Just set it to zero.
+ val limitBuilder = OperatorOuterClass.Limit
+ .newBuilder()
+ .setLimit(op.limit)
+ .setOffset(0)
+ Some(builder.setLimit(limitBuilder).build())
+ } else {
+ withInfo(op, "No child operator")
+ None
+ }
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala
new file mode 100644
index 000000000..476313a9d
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.mutable.ListBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
+import org.apache.spark.sql.comet.CometScanExec
+import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StructField, StructType}
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.objectstore.NativeConfig
+import org.apache.comet.parquet.CometParquetUtils
+import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+
+object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = None
+
+ override def convert(
+ scan: CometScanExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
+ nativeScanBuilder.setSource(scan.simpleStringWithNodeId())
+
+ val scanTypes = scan.output.flatten { attr =>
+ serializeDataType(attr.dataType)
+ }
+
+ if (scanTypes.length == scan.output.length) {
+ nativeScanBuilder.addAllFields(scanTypes.asJava)
+
+ // Sink operators don't have children
+ builder.clearChildren()
+
+ if (scan.conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) &&
+ CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(scan.conf)) {
+
+ val dataFilters = new ListBuffer[Expr]()
+ for (filter <- scan.dataFilters) {
+ exprToProto(filter, scan.output) match {
+ case Some(proto) => dataFilters += proto
+ case _ =>
+ logWarning(s"Unsupported data filter $filter")
+ }
+ }
+ nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
+ }
+
+ val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema)
+ if (possibleDefaultValues.exists(_ != null)) {
+ // Our schema has default values. Serialize two lists, one with the default values
+ // and another with the indexes in the schema so the native side can map missing
+ // columns to these default values.
+ val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex
+ .filter { case (expr, _) => expr != null }
+ .map { case (expr, index) =>
+ // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these
+ // expressions and they should now just be literals.
+ (Literal(expr), index.toLong.asInstanceOf[java.lang.Long])
+ }
+ .unzip
+ nativeScanBuilder.addAllDefaultValues(
+ defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava)
+ nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
+ }
+
+ // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
+ var firstPartition: Option[PartitionedFile] = None
+ scan.inputRDD match {
+ case rdd: DataSourceRDD =>
+ val partitions = rdd.partitions
+ partitions.foreach(p => {
+ val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
+ inputPartitions.foreach(partition => {
+ if (firstPartition.isEmpty) {
+ firstPartition = partition.asInstanceOf[FilePartition].files.headOption
+ }
+ partition2Proto(
+ partition.asInstanceOf[FilePartition],
+ nativeScanBuilder,
+ scan.relation.partitionSchema)
+ })
+ })
+ case rdd: FileScanRDD =>
+ rdd.filePartitions.foreach(partition => {
+ if (firstPartition.isEmpty) {
+ firstPartition = partition.files.headOption
+ }
+ partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
+ })
+ case _ =>
+ }
+
+ val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+ val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+ val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+ val dataSchemaIndexes = scan.requiredSchema.fields.map(field => {
+ scan.relation.dataSchema.fieldIndex(field.name)
+ })
+ val partitionSchemaIndexes = Array
+ .range(
+ scan.relation.dataSchema.fields.length,
+ scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+ val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx =>
+ idx.toLong.asInstanceOf[java.lang.Long])
+
+ nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
+
+ // In `CometScanRule`, we ensure partitionSchema is supported.
+ assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
+
+ nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+ nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
+ nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
+ nativeScanBuilder.setSessionTimezone(scan.conf.getConfString("spark.sql.session.timeZone"))
+ nativeScanBuilder.setCaseSensitive(scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE))
+
+ // Collect S3/cloud storage configurations
+ val hadoopConf = scan.relation.sparkSession.sessionState
+ .newHadoopConfWithOptions(scan.relation.options)
+
+ nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf))
+
+ firstPartition.foreach { partitionFile =>
+ val objectStoreOptions =
+ NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri)
+ objectStoreOptions.foreach { case (key, value) =>
+ nativeScanBuilder.putObjectStoreOptions(key, value)
+ }
+ }
+
+ Some(builder.setNativeScan(nativeScanBuilder).build())
+
+ } else {
+ // There are unsupported scan type
+ withInfo(
+ scan,
+ s"unsupported Comet operator: ${scan.nodeName}, due to unsupported data types above")
+ None
+ }
+
+ }
+
+ private def schema2Proto(
+ fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+ val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+ fields.map(field => {
+ fieldBuilder.setName(field.name)
+ fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+ fieldBuilder.setNullable(field.nullable)
+ fieldBuilder.build()
+ })
+ }
+
+ private def partition2Proto(
+ partition: FilePartition,
+ nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
+ partitionSchema: StructType): Unit = {
+ val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
+ partition.files.foreach(file => {
+ // Process the partition values
+ val partitionValues = file.partitionValues
+ assert(partitionValues.numFields == partitionSchema.length)
+ val partitionVals =
+ partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) =>
+ val attr = partitionSchema(i)
+ val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty)
+ // In `CometScanRule`, we have already checked that all partition values are
+ // supported. So, we can safely use `get` here.
+ assert(
+ valueProto.isDefined,
+ s"Unsupported partition value: $value, type: ${attr.dataType}")
+ valueProto.get
+ }
+
+ val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
+ partitionVals.foreach(fileBuilder.addPartitionValues)
+ fileBuilder
+ .setFilePath(file.filePath.toString)
+ .setStart(file.start)
+ .setLength(file.length)
+ .setFileSize(file.fileSize)
+ partitionBuilder.addPartitionedFile(fileBuilder.build())
+ })
+ nativeScanBuilder.addFilePartitions(partitionBuilder.build())
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
new file mode 100644
index 000000000..c6a95ec88
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.comet.ConfigEntry
+import org.apache.comet.serde.OperatorOuterClass.Operator
+
+/**
+ * Trait for providing serialization logic for operators.
+ */
+trait CometOperatorSerde[T <: SparkPlan] {
+
+ /**
+ * Convert a Spark operator into a protocol buffer representation that can be passed into native
+ * code.
+ *
+ * @param op
+ * The Spark operator.
+ * @param builder
+ * The protobuf builder for the operator.
+ * @param childOp
+ * Child operators that have already been converted to Comet.
+ * @return
+ * Protocol buffer representation, or None if the operator could not be converted. In this
+ * case it is expected that the input operator will have been tagged with reasons why it could
+ * not be converted.
+ */
+ def convert(
+ op: T,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[OperatorOuterClass.Operator]
+
+ /**
+ * Get the optional Comet configuration entry that is used to enable or disable native support
+ * for this operator.
+ */
+ def enabledConfig: Option[ConfigEntry[Boolean]]
+}
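Illustrative note (not part of the commit): enabledConfig is intended as a gate checked before calling convert, much like the CometConf guards in the rewritten QueryPlanSerde match arms further down. A minimal sketch, assuming the org.apache.comet.serde package; the helper name is hypothetical.

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.comet.serde.OperatorOuterClass.Operator

    // Hypothetical helper: convert only when the operator's enabling config (if any) is on.
    // enabledConfig is None for operators that are always eligible, e.g. CometNativeScan.
    def tryConvertOperator[T <: SparkPlan](
        serde: CometOperatorSerde[T],
        op: T,
        builder: Operator.Builder,
        childOp: Operator*): Option[Operator] = {
      if (serde.enabledConfig.forall(_.get(op.conf))) {
        serde.convert(op, builder, childOp: _*)
      } else {
        None
      }
    }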
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala
new file mode 100644
index 000000000..aa3bf775f
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+
+import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
+
+/** Serde for scalar function. */
+case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] {
+ override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
+ val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
+ val optExpr = scalarFunctionExprToProto(name, childExpr: _*)
+ optExprWithInfo(optExpr, expr, expr.children: _*)
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala
new file mode 100644
index 000000000..5f926f06e
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, ExpressionSet, SortOrder}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType}
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.{JoinType, Operator}
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+object CometSortMergeJoin extends CometOperatorSerde[SortMergeJoinExec] {
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
+ CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED)
+
+ override def convert(
+ join: SortMergeJoinExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
+ def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+ keys.map(SortOrder(_, Ascending))
+ }
+
+ def getKeyOrdering(
+ keys: Seq[Expression],
+ childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
+ val requiredOrdering = requiredOrders(keys)
+ if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
+ keys.zip(childOutputOrdering).map { case (key, childOrder) =>
+ val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
+ SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
+ }
+ } else {
+ requiredOrdering
+ }
+ }
+
+ if (join.condition.isDefined &&
+ !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED
+ .get(join.conf)) {
+ withInfo(
+ join,
+ s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled",
+ join.condition.get)
+ return None
+ }
+
+ val condition = join.condition.map { cond =>
+ val condProto = exprToProto(cond, join.left.output ++ join.right.output)
+ if (condProto.isEmpty) {
+ withInfo(join, cond)
+ return None
+ }
+ condProto.get
+ }
+
+ val joinType = join.joinType match {
+ case Inner => JoinType.Inner
+ case LeftOuter => JoinType.LeftOuter
+ case RightOuter => JoinType.RightOuter
+ case FullOuter => JoinType.FullOuter
+ case LeftSemi => JoinType.LeftSemi
+ case LeftAnti => JoinType.LeftAnti
+ case _ =>
+ // Spark doesn't support other join types
+ withInfo(join, s"Unsupported join type ${join.joinType}")
+ return None
+ }
+
+ // Checks if the join keys are supported by DataFusion SortMergeJoin.
+ val errorMsgs = join.leftKeys.flatMap { key =>
+ if (!supportedSortMergeJoinEqualType(key.dataType)) {
+ Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
+ } else {
+ None
+ }
+ }
+
+ if (errorMsgs.nonEmpty) {
+ withInfo(join, errorMsgs.flatten.mkString("\n"))
+ return None
+ }
+
+ val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
+ val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
+
+ val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering)
+ .map(exprToProto(_, join.left.output))
+
+ if (sortOptions.forall(_.isDefined) &&
+ leftKeys.forall(_.isDefined) &&
+ rightKeys.forall(_.isDefined) &&
+ childOp.nonEmpty) {
+ val joinBuilder = OperatorOuterClass.SortMergeJoin
+ .newBuilder()
+ .setJoinType(joinType)
+ .addAllSortOptions(sortOptions.map(_.get).asJava)
+ .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
+ .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+ condition.map(joinBuilder.setCondition)
+ Some(builder.setSortMergeJoin(joinBuilder).build())
+ } else {
+ val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
+ withInfo(join, allExprs: _*)
+ None
+ }
+
+ }
+
+ /**
+ * Returns true if given datatype is supported as a key in DataFusion sort merge join.
+ */
+ private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
+ case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+ _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
+ true
+ case TimestampNTZType => true
+ case _ => false
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala
new file mode 100644
index 000000000..7e963d632
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, SortOrder, WindowExpression}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.WindowExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, windowExprToProto}
+
+object CometWindow extends CometOperatorSerde[WindowExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
+ CometConf.COMET_EXEC_WINDOW_ENABLED)
+
+ override def convert(
+ op: WindowExec,
+ builder: Operator.Builder,
+ childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+ val output = op.child.output
+
+ val winExprs: Array[WindowExpression] = op.windowExpression.flatMap { expr =>
+ expr match {
+ case alias: Alias =>
+ alias.child match {
+ case winExpr: WindowExpression =>
+ Some(winExpr)
+ case _ =>
+ None
+ }
+ case _ =>
+ None
+ }
+ }.toArray
+
+ if (winExprs.length != op.windowExpression.length) {
+ withInfo(op, "Unsupported window expression(s)")
+ return None
+ }
+
+ if (op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty &&
+ !validatePartitionAndSortSpecsForWindowFunc(op.partitionSpec, op.orderSpec, op)) {
+ return None
+ }
+
+ val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
+ val partitionExprs = op.partitionSpec.map(exprToProto(_, op.child.output))
+
+ val sortOrders = op.orderSpec.map(exprToProto(_, op.child.output))
+
+ if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined)
+ && sortOrders.forall(_.isDefined)) {
+ val windowBuilder = OperatorOuterClass.Window.newBuilder()
+ windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
+ windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
+ windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
+ Some(builder.setWindow(windowBuilder).build())
+ } else {
+ None
+ }
+
+ }
+
+ private def validatePartitionAndSortSpecsForWindowFunc(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ op: SparkPlan): Boolean = {
+ if (partitionSpec.length != orderSpec.length) {
+ return false
+ }
+
+ val partitionColumnNames = partitionSpec.collect {
+ case a: AttributeReference => a.name
+ case other =>
+ withInfo(op, s"Unsupported partition expression: ${other.getClass.getSimpleName}")
+ return false
+ }
+
+ val orderColumnNames = orderSpec.collect { case s: SortOrder =>
+ s.child match {
+ case a: AttributeReference => a.name
+ case other =>
+ withInfo(op, s"Unsupported sort expression: ${other.getClass.getSimpleName}")
+ return false
+ }
+ }
+
+ if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) =>
+ partCol != orderCol
+ }) {
+ withInfo(op, "Partitioning and sorting specifications must be the same.")
+ return false
+ }
+
+ true
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 3f8de7693..3e0e837c9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,39 +19,31 @@
package org.apache.comet.serde
-import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
-import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
import org.apache.comet.expressions._
-import org.apache.comet.objectstore.NativeConfig
-import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
-import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
+import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.Types.{DataType => ProtoDataType}
import org.apache.comet.serde.Types.DataType._
import org.apache.comet.serde.literals.CometLiteral
@@ -911,17 +903,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
}
- /**
- * Returns true if given datatype is supported as a key in DataFusion sort merge join.
- */
- private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
- case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
- _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
- true
- case TimestampNTZType => true
- case _ => false
- }
-
/**
* Convert a Spark plan operator to a protobuf Comet operator.
*
@@ -943,514 +924,47 @@ object QueryPlanSerde extends Logging with CometExprShim {
// Fully native scan for V1
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
- val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
- nativeScanBuilder.setSource(op.simpleStringWithNodeId())
-
- val scanTypes = op.output.flatten { attr =>
- serializeDataType(attr.dataType)
- }
-
- if (scanTypes.length == op.output.length) {
- nativeScanBuilder.addAllFields(scanTypes.asJava)
-
- // Sink operators don't have children
- builder.clearChildren()
-
- if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) &&
- CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(conf)) {
-
- val dataFilters = new ListBuffer[Expr]()
- for (filter <- scan.dataFilters) {
- exprToProto(filter, scan.output) match {
- case Some(proto) => dataFilters += proto
- case _ =>
- logWarning(s"Unsupported data filter $filter")
- }
- }
- nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
- }
-
- val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema)
- if (possibleDefaultValues.exists(_ != null)) {
- // Our schema has default values. Serialize two lists, one with the default values
- // and another with the indexes in the schema so the native side can map missing
- // columns to these default values.
- val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex
- .filter { case (expr, _) => expr != null }
- .map { case (expr, index) =>
- // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these
- // expressions and they should now just be literals.
- (Literal(expr), index.toLong.asInstanceOf[java.lang.Long])
- }
- .unzip
- nativeScanBuilder.addAllDefaultValues(
- defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava)
- nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
- }
-
- // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
- var firstPartition: Option[PartitionedFile] = None
- scan.inputRDD match {
- case rdd: DataSourceRDD =>
- val partitions = rdd.partitions
- partitions.foreach(p => {
- val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
- inputPartitions.foreach(partition => {
- if (firstPartition.isEmpty) {
- firstPartition = partition.asInstanceOf[FilePartition].files.headOption
- }
- partition2Proto(
- partition.asInstanceOf[FilePartition],
- nativeScanBuilder,
- scan.relation.partitionSchema)
- })
- })
- case rdd: FileScanRDD =>
- rdd.filePartitions.foreach(partition => {
- if (firstPartition.isEmpty) {
- firstPartition = partition.files.headOption
- }
- partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
- })
- case _ =>
- }
-
- val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
- val requiredSchema = schema2Proto(scan.requiredSchema.fields)
- val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
-
- val dataSchemaIndexes = scan.requiredSchema.fields.map(field => {
- scan.relation.dataSchema.fieldIndex(field.name)
- })
- val partitionSchemaIndexes = Array
- .range(
- scan.relation.dataSchema.fields.length,
- scan.relation.dataSchema.length +
scan.relation.partitionSchema.fields.length)
-
- val projectionVector = (dataSchemaIndexes ++
partitionSchemaIndexes).map(idx =>
- idx.toLong.asInstanceOf[java.lang.Long])
-
- nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
-
- // In `CometScanRule`, we ensure partitionSchema is supported.
- assert(partitionSchema.length ==
scan.relation.partitionSchema.fields.length)
+ CometNativeScan.convert(scan, builder, childOp: _*)
- nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
- nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
- nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
- nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone"))
- nativeScanBuilder.setCaseSensitive(conf.getConf[Boolean](SQLConf.CASE_SENSITIVE))
+ case filter: FilterExec if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
+ CometFilter.convert(filter, builder, childOp: _*)
- // Collect S3/cloud storage configurations
- val hadoopConf = scan.relation.sparkSession.sessionState
- .newHadoopConfWithOptions(scan.relation.options)
-
- nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf))
-
- firstPartition.foreach { partitionFile =>
- val objectStoreOptions =
- NativeConfig.extractObjectStoreOptions(hadoopConf,
partitionFile.pathUri)
- objectStoreOptions.foreach { case (key, value) =>
- nativeScanBuilder.putObjectStoreOptions(key, value)
- }
- }
-
- Some(builder.setNativeScan(nativeScanBuilder).build())
-
- } else {
- // There are unsupported scan types
- withInfo(
- op,
- s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above")
- None
- }
-
- case FilterExec(condition, child) if
CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
- val cond = exprToProto(condition, child.output)
-
- if (cond.isDefined && childOp.nonEmpty) {
- val filterBuilder = OperatorOuterClass.Filter
- .newBuilder()
- .setPredicate(cond.get)
- Some(builder.setFilter(filterBuilder).build())
- } else {
- withInfo(op, condition, child)
- None
- }
-
- case LocalLimitExec(limit, _) if
CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
- if (childOp.nonEmpty) {
- // LocalLimit doesn't use offset, but it shares the same operator serde class.
- // Just set it to zero.
- val limitBuilder = OperatorOuterClass.Limit
- .newBuilder()
- .setLimit(limit)
- .setOffset(0)
- Some(builder.setLimit(limitBuilder).build())
- } else {
- withInfo(op, "No child operator")
- None
- }
+ case limit: LocalLimitExec if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
+ CometLocalLimit.convert(limit, builder, childOp: _*)
case globalLimitExec: GlobalLimitExec
if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) =>
- if (childOp.nonEmpty) {
- val limitBuilder = OperatorOuterClass.Limit.newBuilder()
+ CometGlobalLimit.convert(globalLimitExec, builder, childOp: _*)
- limitBuilder.setLimit(globalLimitExec.limit).setOffset(globalLimitExec.offset)
+ case expand: ExpandExec if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) =>
+ CometExpand.convert(expand, builder, childOp: _*)
- Some(builder.setLimit(limitBuilder).build())
- } else {
- withInfo(op, "No child operator")
- None
- }
-
- case ExpandExec(projections, _, child) if
CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) =>
- var allProjExprs: Seq[Expression] = Seq()
- val projExprs = projections.flatMap(_.map(e => {
- allProjExprs = allProjExprs :+ e
- exprToProto(e, child.output)
- }))
-
- if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
- val expandBuilder = OperatorOuterClass.Expand
- .newBuilder()
- .addAllProjectList(projExprs.map(_.get).asJava)
- .setNumExprPerProject(projections.head.size)
- Some(builder.setExpand(expandBuilder).build())
- } else {
- withInfo(op, allProjExprs: _*)
- None
- }
-
- case WindowExec(windowExpression, partitionSpec, orderSpec, child)
- if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
+ case _: WindowExec if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
withInfo(op, "Window expressions are not supported")
None
- /*
- val output = child.output
-
- val winExprs: Array[WindowExpression] = windowExpression.flatMap {
expr =>
- expr match {
- case alias: Alias =>
- alias.child match {
- case winExpr: WindowExpression =>
- Some(winExpr)
- case _ =>
- None
- }
- case _ =>
- None
- }
- }.toArray
- if (winExprs.length != windowExpression.length) {
- withInfo(op, "Unsupported window expression(s)")
- return None
- }
+ case aggregate: HashAggregateExec if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
+ CometHashAggregate.convert(aggregate, builder, childOp: _*)
- if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
- !validatePartitionAndSortSpecsForWindowFunc(partitionSpec,
orderSpec, op)) {
- return None
- }
+ case aggregate: ObjectHashAggregateExec
+ if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
+ CometObjectHashAggregate.convert(aggregate, builder, childOp: _*)
- val windowExprProto = winExprs.map(windowExprToProto(_, output,
op.conf))
- val partitionExprs = partitionSpec.map(exprToProto(_, child.output))
+ case join: BroadcastHashJoinExec
+ if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) =>
+ CometBroadcastHashJoin.convert(join, builder, childOp: _*)
- val sortOrders = orderSpec.map(exprToProto(_, child.output))
+ case join: ShuffledHashJoinExec if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
+ CometShuffleHashJoin.convert(join, builder, childOp: _*)
- if (windowExprProto.forall(_.isDefined) &&
partitionExprs.forall(_.isDefined)
- && sortOrders.forall(_.isDefined)) {
- val windowBuilder = OperatorOuterClass.Window.newBuilder()
- windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
- windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
- windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
- Some(builder.setWindow(windowBuilder).build())
+ case join: SortMergeJoinExec =>
+ if (CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf)) {
+ CometSortMergeJoin.convert(join, builder, childOp: _*)
} else {
+ withInfo(join, "SortMergeJoin is not enabled")
None
- } */
-
- case aggregate: BaseAggregateExec
- if (aggregate.isInstanceOf[HashAggregateExec] ||
- aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
- CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
- val groupingExpressions = aggregate.groupingExpressions
- val aggregateExpressions = aggregate.aggregateExpressions
- val aggregateAttributes = aggregate.aggregateAttributes
- val resultExpressions = aggregate.resultExpressions
- val child = aggregate.child
-
- if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
- withInfo(op, "No group by or aggregation")
- return None
}
- // Aggregate expressions with filter are not supported yet.
- if (aggregateExpressions.exists(_.filter.isDefined)) {
- withInfo(op, "Aggregate expression with filter is not supported")
- return None
- }
-
- if (groupingExpressions.exists(expr =>
- expr.dataType match {
- case _: MapType => true
- case _ => false
- })) {
- withInfo(op, "Grouping on map types is not supported")
- return None
- }
-
- val groupingExprsWithInput =
- groupingExpressions.map(expr => expr.name -> exprToProto(expr,
child.output))
-
- val emptyExprs = groupingExprsWithInput.collect {
- case (expr, proto) if proto.isEmpty => expr
- }
-
- if (emptyExprs.nonEmpty) {
- withInfo(op, s"Unsupported group expressions:
${emptyExprs.mkString(", ")}")
- return None
- }
-
- val groupingExprs = groupingExprsWithInput.map(_._2)
-
- // In some cases, the aggregateExpressions list can be empty. For example, the
- // aggregation may only have a group by, or may only have distinct aggregate functions:
- //
- // SELECT COUNT(distinct col2), col1 FROM test group by col1
- // +- HashAggregate (keys =[col1# 6], functions =[count (distinct
col2#7)] )
- // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS,
[plan_id = 36]
- // +- HashAggregate (keys =[col1#6], functions =[partial_count
(distinct col2#7)] )
- // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
- // +- Exchange hashpartitioning (col1#6, col2#7, 10),
ENSURE_REQUIREMENTS, ...
- // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
- // +- FileScan parquet spark_catalog.default.test[col1#6,
col2#7] ......
- // If aggregateExpressions is empty, we only want to build groupingExpressions
- // and skip processing of aggregateExpressions.
- if (aggregateExpressions.isEmpty) {
- val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
- hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
- val attributes = groupingExpressions.map(_.toAttribute) ++
aggregateAttributes
- val resultExprs = resultExpressions.map(exprToProto(_, attributes))
- if (resultExprs.exists(_.isEmpty)) {
- withInfo(
- op,
- s"Unsupported result expressions found in: $resultExpressions",
- resultExpressions: _*)
- return None
- }
- hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
- Some(builder.setHashAgg(hashAggBuilder).build())
- } else {
- val modes = aggregateExpressions.map(_.mode).distinct
-
- if (modes.size != 1) {
- // This shouldn't happen as all aggregation expressions should share the same mode.
- // Nevertheless, fall back to Spark here.
- withInfo(op, "All aggregate expressions do not have the same mode")
- return None
- }
-
- val mode = modes.head match {
- case Partial => CometAggregateMode.Partial
- case Final => CometAggregateMode.Final
- case _ =>
- withInfo(op, s"Unsupported aggregation mode ${modes.head}")
- return None
- }
-
- // In final mode, the aggregate expressions are bound to the output
of the
- // child and partial aggregate expressions buffer attributes
produced by partial
- // aggregation. This is done in Spark `HashAggregateExec`
internally. In Comet,
- // we don't have to do this because we don't use the merging
expression.
- val binding = mode != CometAggregateMode.Final
- // `output` is only used when `binding` is true (i.e., non-Final)
- val output = child.output
-
- val aggExprs =
- aggregateExpressions.map(aggExprToProto(_, output, binding,
op.conf))
- if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
- aggExprs.forall(_.isDefined)) {
- val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
- hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
- hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
- if (mode == CometAggregateMode.Final) {
- val attributes = groupingExpressions.map(_.toAttribute) ++
aggregateAttributes
- val resultExprs = resultExpressions.map(exprToProto(_,
attributes))
- if (resultExprs.exists(_.isEmpty)) {
- withInfo(
- op,
- s"Unsupported result expressions found in:
$resultExpressions",
- resultExpressions: _*)
- return None
- }
- hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
- }
- hashAggBuilder.setModeValue(mode.getNumber)
- Some(builder.setHashAgg(hashAggBuilder).build())
- } else {
- val allChildren: Seq[Expression] =
- groupingExpressions ++ aggregateExpressions ++
aggregateAttributes
- withInfo(op, allChildren: _*)
- None
- }
- }
-
- case join: HashJoin =>
- // `HashJoin` has only two implementations in Spark, but we check the
type of the join to
- // make sure we are handling the correct join type.
- if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
- join.isInstanceOf[ShuffledHashJoinExec]) &&
- !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
- join.isInstanceOf[BroadcastHashJoinExec])) {
- withInfo(join, s"Invalid hash join type ${join.nodeName}")
- return None
- }
-
- if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
- // https://github.com/apache/datafusion-comet/issues/457
- withInfo(join, "BuildRight with LeftAnti is not supported")
- return None
- }
-
- val condition = join.condition.map { cond =>
- val condProto = exprToProto(cond, join.left.output ++
join.right.output)
- if (condProto.isEmpty) {
- withInfo(join, cond)
- return None
- }
- condProto.get
- }
-
- val joinType = join.joinType match {
- case Inner => JoinType.Inner
- case LeftOuter => JoinType.LeftOuter
- case RightOuter => JoinType.RightOuter
- case FullOuter => JoinType.FullOuter
- case LeftSemi => JoinType.LeftSemi
- case LeftAnti => JoinType.LeftAnti
- case _ =>
- // Spark doesn't support other join types
- withInfo(join, s"Unsupported join type ${join.joinType}")
- return None
- }
-
- val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
- val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
-
- if (leftKeys.forall(_.isDefined) &&
- rightKeys.forall(_.isDefined) &&
- childOp.nonEmpty) {
- val joinBuilder = OperatorOuterClass.HashJoin
- .newBuilder()
- .setJoinType(joinType)
- .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
- .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
- .setBuildSide(
- if (join.buildSide == BuildLeft) BuildSide.BuildLeft else
BuildSide.BuildRight)
- condition.foreach(joinBuilder.setCondition)
- Some(builder.setHashJoin(joinBuilder).build())
- } else {
- val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
- withInfo(join, allExprs: _*)
- None
- }
-
- case join: SortMergeJoinExec if
CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
- // `requiredOrders` and `getKeyOrdering` are copied from Spark's
SortMergeJoinExec.
- def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
- keys.map(SortOrder(_, Ascending))
- }
-
- def getKeyOrdering(
- keys: Seq[Expression],
- childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
- val requiredOrdering = requiredOrders(keys)
- if (SortOrder.orderingSatisfies(childOutputOrdering,
requiredOrdering)) {
- keys.zip(childOutputOrdering).map { case (key, childOrder) =>
- val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
- SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
- }
- } else {
- requiredOrdering
- }
- }
-
- if (join.condition.isDefined &&
- !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED
- .get(conf)) {
- withInfo(
- join,
- s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled",
- join.condition.get)
- return None
- }
-
- val condition = join.condition.map { cond =>
- val condProto = exprToProto(cond, join.left.output ++
join.right.output)
- if (condProto.isEmpty) {
- withInfo(join, cond)
- return None
- }
- condProto.get
- }
-
- val joinType = join.joinType match {
- case Inner => JoinType.Inner
- case LeftOuter => JoinType.LeftOuter
- case RightOuter => JoinType.RightOuter
- case FullOuter => JoinType.FullOuter
- case LeftSemi => JoinType.LeftSemi
- case LeftAnti => JoinType.LeftAnti
- case _ =>
- // Spark doesn't support other join types
- withInfo(op, s"Unsupported join type ${join.joinType}")
- return None
- }
-
- // Checks if the join keys are supported by DataFusion SortMergeJoin.
- val errorMsgs = join.leftKeys.flatMap { key =>
- if (!supportedSortMergeJoinEqualType(key.dataType)) {
- Some(s"Unsupported join key type ${key.dataType} on key:
${key.sql}")
- } else {
- None
- }
- }
-
- if (errorMsgs.nonEmpty) {
- withInfo(op, errorMsgs.mkString("\n"))
- return None
- }
-
- val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
- val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
-
- val sortOptions = getKeyOrdering(join.leftKeys,
join.left.outputOrdering)
- .map(exprToProto(_, join.left.output))
-
- if (sortOptions.forall(_.isDefined) &&
- leftKeys.forall(_.isDefined) &&
- rightKeys.forall(_.isDefined) &&
- childOp.nonEmpty) {
- val joinBuilder = OperatorOuterClass.SortMergeJoin
- .newBuilder()
- .setJoinType(joinType)
- .addAllSortOptions(sortOptions.map(_.get).asJava)
- .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
- .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
- condition.map(joinBuilder.setCondition)
- Some(builder.setSortMergeJoin(joinBuilder).build())
- } else {
- val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
- withInfo(join, allExprs: _*)
- None
- }
-
- case join: SortMergeJoinExec if
!CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
- withInfo(join, "SortMergeJoin is not enabled")
- None
-
case op if isCometSink(op) =>
val supportedTypes =
op.output.forall(a => supportedDataType(a.dataType, allowComplex =
true))
@@ -1581,7 +1095,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
// scalastyle:off
-
/**
* Align w/ Arrow's
* [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and
@@ -1589,7 +1102,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
*
* TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]]
*/
- // scalastyle:off
+ // scalastyle:on
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
def canRank(dt: DataType): Boolean = {
dt match {
@@ -1626,83 +1139,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}
- private def validatePartitionAndSortSpecsForWindowFunc(
- partitionSpec: Seq[Expression],
- orderSpec: Seq[SortOrder],
- op: SparkPlan): Boolean = {
- if (partitionSpec.length != orderSpec.length) {
- return false
- }
-
- val partitionColumnNames = partitionSpec.collect {
- case a: AttributeReference => a.name
- case other =>
- withInfo(op, s"Unsupported partition expression:
${other.getClass.getSimpleName}")
- return false
- }
-
- val orderColumnNames = orderSpec.collect { case s: SortOrder =>
- s.child match {
- case a: AttributeReference => a.name
- case other =>
- withInfo(op, s"Unsupported sort expression:
${other.getClass.getSimpleName}")
- return false
- }
- }
-
- if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol,
orderCol) =>
- partCol != orderCol
- }) {
- withInfo(op, "Partitioning and sorting specifications must be the same.")
- return false
- }
-
- true
- }
-
- private def schema2Proto(
- fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField]
= {
- val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
- fields.map(field => {
- fieldBuilder.setName(field.name)
- fieldBuilder.setDataType(serializeDataType(field.dataType).get)
- fieldBuilder.setNullable(field.nullable)
- fieldBuilder.build()
- })
- }
-
- private def partition2Proto(
- partition: FilePartition,
- nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
- partitionSchema: StructType): Unit = {
- val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
- partition.files.foreach(file => {
- // Process the partition values
- val partitionValues = file.partitionValues
- assert(partitionValues.numFields == partitionSchema.length)
- val partitionVals =
- partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value,
i) =>
- val attr = partitionSchema(i)
- val valueProto = exprToProto(Literal(value, attr.dataType),
Seq.empty)
- // In `CometScanRule`, we have already checked that all partition
values are
- // supported. So, we can safely use `get` here.
- assert(
- valueProto.isDefined,
- s"Unsupported partition value: $value, type: ${attr.dataType}")
- valueProto.get
- }
-
- val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
- partitionVals.foreach(fileBuilder.addPartitionValues)
- fileBuilder
- .setFilePath(file.filePath.toString)
- .setStart(file.start)
- .setLength(file.length)
- .setFileSize(file.fileSize)
- partitionBuilder.addPartitionedFile(fileBuilder.build())
- })
- nativeScanBuilder.addFilePartitions(partitionBuilder.build())
- }
}
sealed trait SupportLevel
@@ -1726,131 +1162,3 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel
/** Comet does not support this feature */
case class Unsupported(notes: Option[String] = None) extends SupportLevel
-
-/**
- * Trait for providing serialization logic for operators.
- */
-trait CometOperatorSerde[T <: SparkPlan] {
-
- /**
- * Convert a Spark operator into a protocol buffer representation that can
be passed into native
- * code.
- *
- * @param op
- * The Spark operator.
- * @param builder
- * The protobuf builder for the operator.
- * @param childOp
- * Child operators that have already been converted to Comet.
- * @return
- * Protocol buffer representation, or None if the operator could not be
converted. In this
- * case it is expected that the input operator will have been tagged with
reasons why it could
- * not be converted.
- */
- def convert(
- op: T,
- builder: Operator.Builder,
- childOp: Operator*): Option[OperatorOuterClass.Operator]
-
- /**
- * Get the optional Comet configuration entry that is used to enable or
disable native support
- * for this operator.
- */
- def enabledConfig: Option[ConfigEntry[Boolean]]
-}
-
-/**
- * Trait for providing serialization logic for expressions.
- */
-trait CometExpressionSerde[T <: Expression] {
-
- /**
- * Get a short name for the expression that can be used as part of a config
key related to the
- * expression, such as enabling or disabling that expression.
- *
- * @param expr
- * The Spark expression.
- * @return
- * Short name for the expression, defaulting to the Spark class name
- */
- def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
-
- /**
- * Determine the support level of the expression based on its attributes.
- *
- * @param expr
- * The Spark expression.
- * @return
- * Support level (Compatible, Incompatible, or Unsupported).
- */
- def getSupportLevel(expr: T): SupportLevel = Compatible(None)
-
- /**
- * Convert a Spark expression into a protocol buffer representation that can
be passed into
- * native code.
- *
- * @param expr
- * The Spark expression.
- * @param inputs
- * The input attributes.
- * @param binding
- * Whether the attributes are bound (this is only relevant in aggregate
expressions).
- * @return
- * Protocol buffer representation, or None if the expression could not be
converted. In this
- * case it is expected that the input expression will have been tagged
with reasons why it
- * could not be converted.
- */
- def convert(expr: T, inputs: Seq[Attribute], binding: Boolean):
Option[ExprOuterClass.Expr]
-}
-
-/**
- * Trait for providing serialization logic for aggregate expressions.
- */
-trait CometAggregateExpressionSerde[T <: AggregateFunction] {
-
- /**
- * Get a short name for the expression that can be used as part of a config
key related to the
- * expression, such as enabling or disabling that expression.
- *
- * @param expr
- * The Spark expression.
- * @return
- * Short name for the expression, defaulting to the Spark class name
- */
- def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
-
- /**
- * Convert a Spark expression into a protocol buffer representation that can
be passed into
- * native code.
- *
- * @param aggExpr
- * The aggregate expression.
- * @param expr
- * The aggregate function.
- * @param inputs
- * The input attributes.
- * @param binding
- * Whether the attributes are bound (this is only relevant in aggregate
expressions).
- * @param conf
- * SQLConf
- * @return
- * Protocol buffer representation, or None if the expression could not be
converted. In this
- * case it is expected that the input expression will have been tagged
with reasons why it
- * could not be converted.
- */
- def convert(
- aggExpr: AggregateExpression,
- expr: T,
- inputs: Seq[Attribute],
- binding: Boolean,
- conf: SQLConf): Option[ExprOuterClass.AggExpr]
-}
-
-/** Serde for scalar function. */
-case class CometScalarFunction[T <: Expression](name: String) extends
CometExpressionSerde[T] {
- override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean):
Option[Expr] = {
- val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
- val optExpr = scalarFunctionExprToProto(name, childExpr: _*)
- optExprWithInfo(optExpr, expr, expr.children: _*)
- }
-}
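
For readers following the refactor: the per-operator objects referenced by the new match arms (CometNativeScan, CometFilter, CometLocalLimit, CometExpand, and the join and aggregate variants) live in the new files added by this commit and are not shown in this hunk. Purely as an illustrative sketch of the CometOperatorSerde contract removed above, the following recombines the old inline FilterExec handling with the trait. The object name CometFilterSketch and its exact placement are assumptions; the real CometFilter in this commit may differ.

package org.apache.comet.serde

import org.apache.spark.sql.execution.FilterExec

import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.exprToProto

// Hypothetical object name; the real implementation added by this commit is CometFilter.
object CometFilterSketch extends CometOperatorSerde[FilterExec] {

  // The CometConf flag that gates native execution of FilterExec.
  override def enabledConfig: Option[ConfigEntry[Boolean]] =
    Some(CometConf.COMET_EXEC_FILTER_ENABLED)

  override def convert(
      op: FilterExec,
      builder: Operator.Builder,
      childOp: Operator*): Option[Operator] = {
    // Serialize the filter predicate against the child's output attributes.
    val cond = exprToProto(op.condition, op.child.output)
    if (cond.isDefined && childOp.nonEmpty) {
      val filterBuilder = OperatorOuterClass.Filter
        .newBuilder()
        .setPredicate(cond.get)
      Some(builder.setFilter(filterBuilder).build())
    } else {
      // Record why this operator falls back to Spark.
      withInfo(op, op.condition, op.child)
      None
    }
  }
}

The enabledConfig hook exposes the CometConf flag that gates the operator; in this first refactoring step the dispatch above still checks those flags explicitly in its pattern guards.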
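
For expressions, the removed CometScalarFunction above shows the generic path through scalarFunctionExprToProto. As a hedged sketch of a dedicated CometExpressionSerde implementation, the example below maps Spark's Sqrt to a scalar function assumed to be named "sqrt". The object name is hypothetical, and it assumes the QueryPlanSerde helpers (exprToProtoInternal, scalarFunctionExprToProto, optExprWithInfo) are reachable from a separate file, as the move of CometScalarFunction into its own file suggests.

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, Sqrt}

import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}

// Illustrative only: the generic CometScalarFunction above already covers this pattern.
object CometSqrtSketch extends CometExpressionSerde[Sqrt] {

  override def convert(expr: Sqrt, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
    // Convert the child expression; None means the child itself is unsupported.
    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
    // Map to a native scalar function call (assumed to be named "sqrt").
    val optExpr = scalarFunctionExprToProto("sqrt", childExpr)
    // On failure, tag the Spark expression with the reason it could not be converted.
    optExprWithInfo(optExpr, expr, expr.child)
  }
}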
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]