This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 0826772 feat: Support Broadcast HashJoin (#211)
0826772 is described below
commit 08267724ebc738d43d888cbe83a4e58acb17d6af
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Mar 26 11:17:47 2024 -0700
feat: Support Broadcast HashJoin (#211)
* feat: Support HashJoin
* Add comment
* Clean up test
* Fix join filter
* Fix clippy
* Use consistent function with sort merge join
* Add note about left semi and left anti joins
* feat: Support BroadcastHashJoin
* Move tests
* Remove unused import
* Add function to parse join parameters
* Remove duplicate code
* For review
---
.../org/apache/comet/CometArrowStreamWriter.java | 51 +++++++++++++++++
.../scala/org/apache/comet/vector/NativeUtil.scala | 18 +++---
.../execution/shuffle/ArrowReaderIterator.scala | 16 +++---
.../apache/comet/CometSparkSessionExtensions.scala | 35 +++++++++++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 18 ++++--
.../sql/comet/CometBroadcastExchangeExec.scala | 3 +-
.../org/apache/spark/sql/comet/operators.scala | 44 ++++++++++++++-
.../org/apache/comet/exec/CometJoinSuite.scala | 64 ++++++++++++++++++++++
8 files changed, 221 insertions(+), 28 deletions(-)
diff --git a/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
new file mode 100644
index 0000000..a492ce8
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+
+/**
+ * A custom `ArrowStreamWriter` that allows writing batches from different root to the same stream.
+ * Arrow `ArrowStreamWriter` cannot change the root after initialization.
+ */
+public class CometArrowStreamWriter extends ArrowStreamWriter {
+ public CometArrowStreamWriter(
+ VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel
out) {
+ super(root, provider, out);
+ }
+
+ public void writeMoreBatch(VectorSchemaRoot root) throws IOException {
+ VectorUnloader unloader =
+ new VectorUnloader(
+ root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE,
/*alignBuffers*/ true);
+
+ try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
+ writeRecordBatch(batch);
+ }
+ }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 1682295..cc726e3 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -20,6 +20,7 @@
package org.apache.comet.vector
import java.io.OutputStream
+import java.nio.channels.Channels
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.comet.CometArrowStreamWriter
+
class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new
CDataDictionaryProvider
@@ -46,29 +48,27 @@ class NativeUtil {
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream):
Long = {
- var schemaRoot: Option[VectorSchemaRoot] = None
- var writer: Option[ArrowStreamWriter] = None
+ var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0
batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
- val root = schemaRoot.getOrElse(new
VectorSchemaRoot(fieldVectors.asJava))
+ val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)
if (writer.isEmpty) {
- writer = Some(new ArrowStreamWriter(root, provider, out))
+ writer = Some(new CometArrowStreamWriter(root, provider,
Channels.newChannel(out)))
writer.get.start()
+ writer.get.writeBatch()
+ } else {
+ writer.get.writeMoreBatch(root)
}
- writer.get.writeBatch()
root.clear()
- schemaRoot = Some(root)
-
rowCount += batch.numRows()
}
writer.map(_.end())
- schemaRoot.map(_.close())
rowCount
}
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index e8dba93..304c3ce 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.comet.vector.StreamReader
+import org.apache.comet.vector._
class ArrowReaderIterator(channel: ReadableByteChannel) extends
Iterator[ColumnarBatch] {
@@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
return true
}
+ // Release the previous batch.
+ // If it is not released, when closing the reader, arrow library will complain about
+ // memory leak.
+ if (currentBatch != null) {
+ currentBatch.close()
+ }
+
batch = nextBatch()
if (batch.isEmpty) {
return false
@@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
val nextBatch = batch.get
- // Release the previous batch.
- // If it is not released, when closing the reader, arrow library will complain about
- // memory leak.
- if (currentBatch != null) {
- currentBatch.close()
- }
-
currentBatch = nextBatch
batch = None
currentBatch
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 1380ee9..fcbf42f 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec,
ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec,
SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -356,6 +356,27 @@ class CometSparkSessionExtensions
op
}
+ case op: BroadcastHashJoinExec
+ if isCometOperatorEnabled(conf, "broadcast_hash_join") &&
+ op.children.forall(isCometNative(_)) =>
+ val newOp = transform1(op)
+ newOp match {
+ case Some(nativeOp) =>
+ CometBroadcastHashJoinExec(
+ nativeOp,
+ op,
+ op.leftKeys,
+ op.rightKeys,
+ op.joinType,
+ op.condition,
+ op.buildSide,
+ op.left,
+ op.right,
+ SerializedPlan(None))
+ case None =>
+ op
+ }
+
case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
@@ -411,6 +432,16 @@ class CometSparkSessionExtensions
u
}
+ // For AQE broadcast stage on a Comet broadcast exchange
+ case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _)
=>
+ val newOp = transform1(s)
+ newOp match {
+ case Some(nativeOp) =>
+ CometSinkPlaceHolder(nativeOp, s, s)
+ case None =>
+ s
+ }
+
case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf,
"broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index bf2510b..b98c438 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec,
ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec,
SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin,
ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1915,7 +1915,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}
- case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf,
"hash_join") =>
+ case join: HashJoin =>
+ // `HashJoin` has only two implementations in Spark, but we check the type of the join to
+ // make sure we are handling the correct join type.
+ if (!(isCometOperatorEnabled(op.conf, "hash_join") &&
+ join.isInstanceOf[ShuffledHashJoinExec]) &&
+ !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") &&
+ join.isInstanceOf[BroadcastHashJoinExec])) {
+ return None
+ }
+
if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
@@ -2063,6 +2072,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _:
CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
+ case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case _: BroadcastExchangeExec => true
case _ => false
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f115b2a..24f9f32 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -258,8 +258,7 @@ class CometBatchRDD(
override def compute(split: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
-
- partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
+ partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
}
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index fb300a3..84734a1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator,
CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec,
ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan,
UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec {
plan match {
case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec
|
_: AQEShuffleReadExec | _: CometShuffleExchangeExec | _:
CometUnionExec |
- _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _:
ReusedExchangeExec =>
+ _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _:
ReusedExchangeExec |
+ _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
func(plan)
case _: CometPlan =>
// Other Comet operators, continue to traverse the tree.
@@ -622,7 +623,44 @@ case class CometHashJoinExec(
}
override def hashCode(): Int =
- Objects.hashCode(leftKeys, rightKeys, condition, left, right)
+ Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
+}
+
+case class CometBroadcastHashJoinExec(
+ override val nativeOp: Operator,
+ override val originalPlan: SparkPlan,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ buildSide: BuildSide,
+ override val left: SparkPlan,
+ override val right: SparkPlan,
+ override val serializedPlanOpt: SerializedPlan)
+ extends CometBinaryExec {
+ override def withNewChildrenInternal(newLeft: SparkPlan, newRight:
SparkPlan): SparkPlan =
+ this.copy(left = newLeft, right = newRight)
+
+ override def stringArgs: Iterator[Any] =
+ Iterator(leftKeys, rightKeys, joinType, condition, left, right)
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case other: CometBroadcastHashJoinExec =>
+ this.leftKeys == other.leftKeys &&
+ this.rightKeys == other.rightKeys &&
+ this.condition == other.condition &&
+ this.buildSide == other.buildSide &&
+ this.left == other.left &&
+ this.right == other.right &&
+ this.serializedPlanOpt == other.serializedPlanOpt
+ case _ =>
+ false
+ }
+ }
+
+ override def hashCode(): Int =
+ Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
}
case class CometSortMergeJoinExec(
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index a64ec87..6f479e3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,9 +23,11 @@ import org.scalactic.source.Position
import org.scalatest.Tag
import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometBroadcastHashJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometJoinSuite extends CometTestBase {
@@ -38,6 +40,68 @@ class CometJoinSuite extends CometTestBase {
}
}
+ test("Broadcast HashJoin without join filter") {
+ assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark
3.4+")
+ withSQLConf(
+ CometConf.COMET_BATCH_SIZE.key -> "100",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+ "spark.comet.exec.broadcast.enabled" -> "true",
+ "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+ withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+ // Inner join: build left
+ val df1 =
+ sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON
tbl_a._2 = tbl_b._1")
+ checkSparkAnswerAndOperator(
+ df1,
+ Seq(classOf[CometBroadcastExchangeExec],
classOf[CometBroadcastHashJoinExec]))
+
+ // Right join: build left
+ val df2 =
+ sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b
ON tbl_a._2 = tbl_b._1")
+ checkSparkAnswerAndOperator(
+ df2,
+ Seq(classOf[CometBroadcastExchangeExec],
classOf[CometBroadcastHashJoinExec]))
+ }
+ }
+ }
+ }
+
+ test("Broadcast HashJoin with join filter") {
+ assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark
3.4+")
+ withSQLConf(
+ CometConf.COMET_BATCH_SIZE.key -> "100",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+ "spark.comet.exec.broadcast.enabled" -> "true",
+ "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+ withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+ // Inner join: build left
+ val df1 =
+ sql(
+ "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+ "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(
+ df1,
+ Seq(classOf[CometBroadcastExchangeExec],
classOf[CometBroadcastHashJoinExec]))
+
+ // Right join: build left
+ val df2 =
+ sql(
+ "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+ "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+ checkSparkAnswerAndOperator(
+ df2,
+ Seq(classOf[CometBroadcastExchangeExec],
classOf[CometBroadcastHashJoinExec]))
+ }
+ }
+ }
+ }
+
test("HashJoin without join filter") {
withSQLConf(
SQLConf.PREFER_SORTMERGEJOIN.key -> "false",