cloud-fan commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r469823249



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with non-unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)
+                  found = true
+                  return true
+                }
+              }
+              if (!found) {

Review comment:
       Let's add some comments here. When we reach this point, no match was 
found for this key, so we need to return one row with the build-side null row 
to match full outer join semantics.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to