This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 0826772  feat: Support Broadcast HashJoin (#211)
0826772 is described below

commit 08267724ebc738d43d888cbe83a4e58acb17d6af
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Mar 26 11:17:47 2024 -0700

    feat: Support Broadcast HashJoin (#211)
    
    * feat: Support HashJoin
    
    * Add comment
    
    * Clean up test
    
    * Fix join filter
    
    * Fix clippy
    
    * Use consistent function with sort merge join
    
    * Add note about left semi and left anti joins
    
    * feat: Support BroadcastHashJoin
    
    * Move tests
    
    * Remove unused import
    
    * Add function to parse join parameters
    
    * Remove duplicate code
    
    * For review
---
 .../org/apache/comet/CometArrowStreamWriter.java   | 51 +++++++++++++++++
 .../scala/org/apache/comet/vector/NativeUtil.scala | 18 +++---
 .../execution/shuffle/ArrowReaderIterator.scala    | 16 +++---
 .../apache/comet/CometSparkSessionExtensions.scala | 35 +++++++++++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 18 ++++--
 .../sql/comet/CometBroadcastExchangeExec.scala     |  3 +-
 .../org/apache/spark/sql/comet/operators.scala     | 44 ++++++++++++++-
 .../org/apache/comet/exec/CometJoinSuite.scala     | 64 ++++++++++++++++++++++
 8 files changed, 221 insertions(+), 28 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java 
b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
new file mode 100644
index 0000000..a492ce8
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+
+/**
+ * A custom `ArrowStreamWriter` that allows writing batches from different roots to the same
+ * stream. The stock Arrow `ArrowStreamWriter` cannot change its root after initialization.
+ */
+public class CometArrowStreamWriter extends ArrowStreamWriter {
+  /**
+   * Creates a writer bound to {@code root}; batches written with {@link #writeBatch()} come from
+   * this root, while {@link #writeMoreBatch(VectorSchemaRoot)} accepts any other root.
+   *
+   * @param root the initial root handed to {@code ArrowStreamWriter}
+   * @param provider dictionary provider handed to {@code ArrowStreamWriter}
+   * @param out channel the Arrow IPC stream is written to
+   */
+  public CometArrowStreamWriter(
+      VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
+    super(root, provider, out);
+  }
+
+  /**
+   * Unloads the current contents of {@code root} into an uncompressed record batch (null counts
+   * included, buffers aligned) and writes it to the stream; the record batch is closed afterwards.
+   *
+   * <p>NOTE(review): unlike {@code writeBatch()}, this does not start the stream or write
+   * dictionaries, and nothing checks that {@code root} has the same schema as the constructor's
+   * root. Callers are presumably expected to call {@code start()} first and pass roots with a
+   * matching schema -- confirm.
+   *
+   * @param root root whose current vectors are written as the next batch
+   * @throws IOException if writing to the underlying channel fails
+   */
+  public void writeMoreBatch(VectorSchemaRoot root) throws IOException {
+    VectorUnloader unloader =
+        new VectorUnloader(
+            root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true);
+
+    try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
+      writeRecordBatch(batch);
+    }
+  }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala 
b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 1682295..cc726e3 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -20,6 +20,7 @@
 package org.apache.comet.vector
 
 import java.io.OutputStream
+import java.nio.channels.Channels
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, 
ArrowSchema, CDataDictiona
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.dictionary.DictionaryProvider
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.spark.SparkException
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import org.apache.comet.CometArrowStreamWriter
+
 class NativeUtil {
   private val allocator = new RootAllocator(Long.MaxValue)
   private val dictionaryProvider: CDataDictionaryProvider = new 
CDataDictionaryProvider
@@ -46,29 +48,27 @@ class NativeUtil {
    *   the output stream
    */
   def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
-    var schemaRoot: Option[VectorSchemaRoot] = None
-    var writer: Option[ArrowStreamWriter] = None
+    // Created from the first batch; remains None when `batches` is empty.
+    var writer: Option[CometArrowStreamWriter] = None
     var rowCount = 0

     batches.foreach { batch =>
       val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
-      val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
+      // A fresh root wraps each batch's vectors; CometArrowStreamWriter can write batches from
+      // roots other than the one it was constructed with.
+      val root = new VectorSchemaRoot(fieldVectors.asJava)
       val provider = batchProviderOpt.getOrElse(dictionaryProvider)
+      // NOTE(review): `provider` is only passed to the writer for the first batch; for later
+      // batches it is computed but unused -- confirm that dictionaries cannot change per batch.

+      // First batch: create the writer around this root, start the stream and write the root
+      // the writer was built with. Later batches: write their own root explicitly.
       if (writer.isEmpty) {
-        writer = Some(new ArrowStreamWriter(root, provider, out))
+        writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
         writer.get.start()
+        writer.get.writeBatch()
+      } else {
+        writer.get.writeMoreBatch(root)
       }
-      writer.get.writeBatch()

+      // The root has been written; clear it before moving on to the next batch.
       root.clear()
-      schemaRoot = Some(root)
-
       rowCount += batch.numRows()
     }

+    // Finish the stream; does nothing when no batch was seen (writer is None).
     writer.map(_.end())
-    schemaRoot.map(_.close())

     rowCount
   }
diff --git 
a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index e8dba93..304c3ce 100644
--- 
a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel
 
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.vector.StreamReader
+import org.apache.comet.vector._
 
 class ArrowReaderIterator(channel: ReadableByteChannel) extends 
Iterator[ColumnarBatch] {
 
@@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) 
extends Iterator[Columna
       return true
     }
 
+    // Release the previous batch.
+    // If it is not released, when closing the reader, arrow library will 
complain about
+    // memory leak.
+    if (currentBatch != null) {
+      currentBatch.close()
+    }
+
     batch = nextBatch()
     if (batch.isEmpty) {
       return false
@@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) 
extends Iterator[Columna
 
     val nextBatch = batch.get
 
-    // Release the previous batch.
-    // If it is not released, when closing the reader, arrow library will 
complain about
-    // memory leak.
-    if (currentBatch != null) {
-      currentBatch.close()
-    }
-
     currentBatch = nextBatch
     batch = None
     currentBatch
diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 1380ee9..fcbf42f 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, 
ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -356,6 +356,27 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: BroadcastHashJoinExec
+            if isCometOperatorEnabled(conf, "broadcast_hash_join") &&
+              op.children.forall(isCometNative(_)) =>
+          val newOp = transform1(op)
+          newOp match {
+            case Some(nativeOp) =>
+              CometBroadcastHashJoinExec(
+                nativeOp,
+                op,
+                op.leftKeys,
+                op.rightKeys,
+                op.joinType,
+                op.condition,
+                op.buildSide,
+                op.left,
+                op.right,
+                SerializedPlan(None))
+            case None =>
+              op
+          }
+
         case op: SortMergeJoinExec
             if isCometOperatorEnabled(conf, "sort_merge_join") &&
               op.children.forall(isCometNative(_)) =>
@@ -411,6 +432,16 @@ class CometSparkSessionExtensions
               u
           }
 
+        // For AQE broadcast stage on a Comet broadcast exchange
+        case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) 
=>
+          val newOp = transform1(s)
+          newOp match {
+            case Some(nativeOp) =>
+              CometSinkPlaceHolder(nativeOp, s, s)
+            case None =>
+              s
+          }
+
         case b: BroadcastExchangeExec
             if isCometNative(b.child) && isCometOperatorEnabled(conf, 
"broadcastExchangeExec") &&
               isCometBroadCastEnabled(conf) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index bf2510b..b98c438 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometSinkPlaceHolder, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, 
ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, 
ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1915,7 +1915,16 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
           }
         }
 
-      case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, 
"hash_join") =>
+      case join: HashJoin =>
+        // `HashJoin` has only two implementations in Spark, but we check the 
type of the join to
+        // make sure we are handling the correct join type.
+        if (!(isCometOperatorEnabled(op.conf, "hash_join") &&
+            join.isInstanceOf[ShuffledHashJoinExec]) &&
+          !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") &&
+            join.isInstanceOf[BroadcastHashJoinExec])) {
+          return None
+        }
+
         if (join.buildSide == BuildRight) {
           // DataFusion HashJoin assumes build side is always left.
           // TODO: support BuildRight
@@ -2063,6 +2072,7 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
       case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
       case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: 
CometShuffleExchangeExec), _) => true
       case _: TakeOrderedAndProjectExec => true
+      case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
       case _: BroadcastExchangeExec => true
       case _ => false
     }
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f115b2a..24f9f32 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -258,8 +258,7 @@ class CometBatchRDD(
 
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
-
-    partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
+    // Converting to an iterator before flatMap makes decoding lazy: each entry is decoded only
+    // when the returned iterator reaches it, instead of decoding everything up front.
+    partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
   }
 }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index fb300a3..84734a1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, 
ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, 
UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec {
     plan match {
       case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec 
|
           _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: 
CometUnionExec |
-          _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: 
ReusedExchangeExec =>
+          _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: 
ReusedExchangeExec |
+          _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
         func(plan)
       case _: CometPlan =>
         // Other Comet operators, continue to traverse the tree.
@@ -622,7 +623,44 @@ case class CometHashJoinExec(
   }
 
   override def hashCode(): Int =
-    Objects.hashCode(leftKeys, rightKeys, condition, left, right)
+    // Hash over both key lists, the join condition, the build side and both children.
+    // NOTE(review): joinType is not part of the hash; check whether equals compares it.
+    Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
+}
+
+/**
+ * Comet operator that runs a Spark `BroadcastHashJoinExec` natively.
+ *
+ * @param nativeOp native operator describing this join
+ * @param originalPlan the Spark plan this operator replaces
+ * @param leftKeys equi-join keys from the left side
+ * @param rightKeys equi-join keys from the right side
+ * @param joinType join type of the original join
+ * @param condition optional extra (non-equi) join filter
+ * @param buildSide side whose rows are used to build the hash table
+ * @param left left child plan
+ * @param right right child plan
+ * @param serializedPlanOpt serialized native plan, if already available
+ */
+case class CometBroadcastHashJoinExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    buildSide: BuildSide,
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+  override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(leftKeys, rightKeys, joinType, condition, left, right)
+
+  // joinType must take part in equality: otherwise two joins that differ only in join type
+  // (e.g. INNER vs RIGHT OUTER over the same keys, condition and children) compare equal even
+  // though they produce different results.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometBroadcastHashJoinExec =>
+        this.leftKeys == other.leftKeys &&
+        this.rightKeys == other.rightKeys &&
+        this.joinType == other.joinType &&
+        this.condition == other.condition &&
+        this.buildSide == other.buildSide &&
+        this.left == other.left &&
+        this.right == other.right &&
+        this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  // Uses a subset of the fields compared by equals, so equal instances hash equally.
+  override def hashCode(): Int =
+    Objects.hashCode(leftKeys, rightKeys, joinType, condition, buildSide, left, right)
 }
 
 case class CometSortMergeJoinExec(
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index a64ec87..6f479e3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,9 +23,11 @@ import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometBroadcastHashJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 
 class CometJoinSuite extends CometTestBase {
 
@@ -38,6 +40,68 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  // Broadcast hash join with equality keys only. The -1 thresholds disable automatic
+  // broadcasting, so the broadcast comes from the BROADCAST(tbl_a) hint.
+  test("Broadcast HashJoin without join filter") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.comet.exec.broadcast.enabled" -> "true",
+      // NOTE(review): forceApplyShuffledHashJoin looks irrelevant to a hinted broadcast join.
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      // tbl_a._2 ranges over 0..4 while tbl_b._1 ranges over 0..9, so tbl_b rows with
+      // _1 in 5..9 have no match (exercised by the right join below).
+      withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          // Presumably checks the result against Spark and that the listed Comet operators
+          // appear in the plan (helper not shown here).
+          checkSparkAnswerAndOperator(
+            df1,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // Right join: build left
+          val df2 =
+            sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(
+            df2,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+        }
+      }
+    }
+  }
+
+  // Broadcast hash join whose ON clause adds a non-equi predicate (tbl_a._1 > tbl_b._2) on top
+  // of the equality key; that predicate is expected to become the join filter. As above,
+  // broadcasting comes from the BROADCAST(tbl_a) hint since the -1 thresholds disable it otherwise.
+  test("Broadcast HashJoin with join filter") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.comet.exec.broadcast.enabled" -> "true",
+      // NOTE(review): forceApplyShuffledHashJoin looks irrelevant to a hinted broadcast join.
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(
+            df1,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(
+            df2,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+        }
+      }
+    }
+  }
+
   test("HashJoin without join filter") {
     withSQLConf(
       SQLConf.PREFER_SORTMERGEJOIN.key -> "false",

Reply via email to