agrawaldevesh commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r470091083



##########
File path: core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
##########
@@ -601,6 +657,14 @@ public boolean isDefined() {
       return isDefined;
     }
 
+    /**
+     * Returns index for key.
+     */
+    public int getKeyIndex() {

Review comment:
       See comment above about the possibility of eliminating this notion of 
keyIndex and sticking with pos.

##########
File path: core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
##########
@@ -428,6 +428,62 @@ public MapIterator destructiveIterator() {
     return new MapIterator(numValues, new Location(), true);
   }
 
+  /**
+   * Iterator for the entries of this map. This is to first iterate over key 
index array
+   * `longArray` then accessing values in `dataPages`. NOTE: this is different 
from `MapIterator`
+   * in the sense that key index is preserved here
+   * (See `UnsafeHashedRelation` for example of usage).
+   */
+  public final class MapIteratorWithKeyIndex implements Iterator<Location> {
+
+    private int keyIndex = 0;
+    private int numRecords;
+    private final Location loc;
+
+    private MapIteratorWithKeyIndex(int numRecords, Location loc) {
+      this.numRecords = numRecords;
+      this.loc = loc;
+    }
+
+    @Override
+    public boolean hasNext() {
+      return numRecords > 0;
+    }
+
+    @Override
+    public Location next() {
+      if (!loc.isDefined() || !loc.nextValue()) {
+        while (longArray.get(keyIndex * 2) == 0) {
+          keyIndex++;
+        }
+        loc.with(keyIndex, (int) longArray.get(keyIndex * 2 + 1), true);
+        keyIndex++;

Review comment:
       Should there be any bounds check done on `numRecords` to ensure that 
keyIndex won't wrap around ? Or is this too internal an iterator to care about 
this ? 
   
   Basically keyIndex can grow beyond the longArray.size() if numRecords is 
sufficiently big ? 

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
##########
@@ -116,7 +116,9 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
    *
    * - Shuffle hash join:
    *     Only supported for equi-joins, while the join keys do not need to be 
sortable.
-   *     Supported for all join types except full outer joins.
+   *     Supported for all join types.
+   *     Building hash map from table is a memory-intensive operation and it 
could cause OOM

Review comment:
       Should we add some commentary about what is meant by "hash map" here ? 
Is it the hash map you are using for storing the matched-bits or is it the 
build hash table ? (something true of all hash joins ?) 

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -66,6 +66,30 @@ private[execution] sealed trait HashedRelation extends 
KnownSizeEstimation {
     throw new UnsupportedOperationException
   }
 
+  /**
+   * Returns key index and matched rows.

Review comment:
       Same comment here. I cannot find the concept "key index" anywhere else 
in the Spark code. Is there another "commonly used name" for "key index / 
keyIndex" that we can use ? 
   
   If not, let's figure out where to clearly define it.
   

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)

Review comment:
       I actually have a similar comment about `valueIndex` (like `keyIndex` 
above). Can we reuse an existing term for it or define this semantic somewhere 
above clearly ? It's not used anywhere else in Spark, esp not in the join 
context. 

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)
+                  found = true
+                  return true
+                }
+              }
+              if (!found) {
+                joinRowWithBuild(buildNullRow)
+                found = true

Review comment:
       I didn't follow why found is again set to true here ? I think it can be 
left alone ?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))

Review comment:
       Same comment as below for the use of flatMap/Option here: Is this 
pattern used in the hot path elsewhere in Spark for RDD iterator computation ?

##########
File path: core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
##########
@@ -428,6 +428,62 @@ public MapIterator destructiveIterator() {
     return new MapIterator(numValues, new Location(), true);
   }
 
+  /**
+   * Iterator for the entries of this map. This is to first iterate over key 
index array
+   * `longArray` then accessing values in `dataPages`. NOTE: this is different 
from `MapIterator`
+   * in the sense that key index is preserved here
+   * (See `UnsafeHashedRelation` for example of usage).
+   */
+  public final class MapIteratorWithKeyIndex implements Iterator<Location> {
+
+    private int keyIndex = 0;

Review comment:
       I am wondering if we MUST introduce a new terminology `keyIndex` in this 
class ? Is `pos` equivalent to `keyIndex` ? I think it's better to stick to 
the existing concepts in the class unless it is quite a stretch.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -110,14 +138,39 @@ private[execution] object HashedRelation {
 
     if (!input.hasNext) {
       EmptyHashedRelation
-    } else if (key.length == 1 && key.head.dataType == LongType) {
+    } else if (key.length == 1 && key.head.dataType == LongType && 
!allowsNullKey) {
+      // NOTE: LongHashedRelation does not support NULL keys.
       LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
     } else {
-      UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware)
+      UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, 
allowsNullKey)
     }
   }
 }
 
+/**
+ * A wrapper for key index and value in InternalRow type.
+ * Designed to be instantiated once per thread and reused.
+ */
+private[execution] class ValueRowWithKeyIndex {
+  private var keyIndex: Int = _
+  private var value: InternalRow = _
+
+  /** Updates this ValueRowWithKeyIndex.  Returns itself. */
+  def updates(newKeyIndex: Int, newValue: InternalRow): ValueRowWithKeyIndex = 
{

Review comment:
       Should this be called `update` instead of `updates` ? (singular instead 
of plural)

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -179,6 +235,63 @@ private[joins] class UnsafeHashedRelation(
     }
   }
 
+  override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) 
= {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      (loc.getKeyIndex,
+        new Iterator[UnsafeRow] {
+          private var _hasNext = true
+          override def hasNext: Boolean = _hasNext
+          override def next(): UnsafeRow = {
+            resultRow.pointTo(loc.getValueBase, loc.getValueOffset, 
loc.getValueLength)
+            _hasNext = loc.nextValue()
+            resultRow
+          }
+        })
+    } else {
+      null
+    }
+  }
+
+  override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      resultRow.pointTo(loc.getValueBase, loc.getValueOffset, 
loc.getValueLength)
+      valueRowWithKeyIndex.updates(loc.getKeyIndex, resultRow)
+    } else {
+      null
+    }
+  }
+
+  override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = {
+    val iter = binaryMap.iteratorWithKeyIndex()
+
+    new Iterator[ValueRowWithKeyIndex] {
+      override def hasNext: Boolean = iter.hasNext
+
+      override def next(): ValueRowWithKeyIndex = {
+        if (!hasNext) {
+          throw new NoSuchElementException("End of the iterator")
+        }
+        val loc = iter.next()
+        resultRow.pointTo(loc.getValueBase, loc.getValueOffset, 
loc.getValueLength)

Review comment:
       Could it be problematic that resultRow is reused here as a part of 
embedding it in valueRowWithKeyIndex ? Ideally I think we allocate a new fresh 
copy of the row owned by the valueRow that is then updated here. 
   
   Currently, as written, modifying `resultRow` elsewhere, independently of 
the valueRow, might lead to issues.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)
+                  found = true
+                  return true
+                }
+              }
+              if (!found) {
+                joinRowWithBuild(buildNullRow)
+                found = true
+                return true
+              }
+              false
+            }
+            override def getRow: InternalRow = joinRow
+          }.toScala
+        } else {
+          Iterator.single(joinRowWithBuild(buildNullRow))
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    var prevKeyIndex = -1
+    var valueIndex = -1
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) {

Review comment:
       I understand what this logic does but I am finding it hard to read. 
Should it be more directly like this:
   
   ```
   val keyIndex = valueRowWithKeyIndex.getKeyIndex
   if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) {
     valueIndex = 0
   } else {
     valueIndex += 1
   }
   
   val isMatched = ....
   if (isMatched) {...
   } else {...
   }
   prevKeyIndex = keyIndex
   ```

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)
+                  found = true
+                  return true
+                }
+              }
+              if (!found) {
+                joinRowWithBuild(buildNullRow)
+                found = true
+                return true
+              }
+              false
+            }
+            override def getRow: InternalRow = joinRow
+          }.toScala
+        } else {
+          Iterator.single(joinRowWithBuild(buildNullRow))
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    var prevKeyIndex = -1
+    var valueIndex = -1
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) {
+          prevKeyIndex = keyIndex
+          valueIndex = -1
+        }
+        valueIndex += 1
+        val isMatched = isRowMatched(keyIndex, valueIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))

Review comment:
       Just curious, to educate myself: are we really creating an Option for 
each build row, or will this be elided? In other words, is there a perf 
penalty to using flatMap here?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {
+                  markRowMatched(keyIndex, valueIndex)
+                  found = true
+                  return true
+                }
+              }
+              if (!found) {
+                joinRowWithBuild(buildNullRow)
+                found = true
+                return true
+              }
+              false
+            }
+            override def getRow: InternalRow = joinRow
+          }.toScala
+        } else {
+          Iterator.single(joinRowWithBuild(buildNullRow))
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and

Review comment:
       Can this comment be shortened to "Process build side filtering out the 
matched rows" ? The concept of 'match' has already been introduced above to 
include the join condition.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val matched = hashedRelation.getWithKeyIndex(keys)
+        if (matched != null) {
+          val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex)
+
+          new RowIterator {
+            private var found = false
+            override def advanceNext(): Boolean = {
+              while (buildIter.hasNext) {
+                val (buildRow, valueIndex) = buildIter.next()
+                if (boundCondition(joinRowWithBuild(buildRow))) {

Review comment:
       Just checking: when the join does not have any condition, this 
`boundCondition` vacuously evaluates to true, right?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,210 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    val streamNullJoinRow = new JoinedRow
+    val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          streamNullJoinRow.withRight(streamNullRow)
+          streamNullJoinRow.withLeft _
+        case BuildRight =>
+          streamNullJoinRow.withLeft(streamNullRow)
+          streamNullJoinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out rows looked up and
+    // passed join condition already
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side be looked up.
+   *    A `HashSet[Long]` is used to track matched rows with
+   *    key index (Int) and value index (Int) together.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index and value index from `HashSet`.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]

Review comment:
       I found `OpenHashSet`, `PrimitiveKeyOpenHashMap` and other friends. I 
don't have context on why Spark does not use fastutil, but any of those have 
got to be better than HashSet[Long].
   
   You might also want to add a quick comment that `Long` here is (keyIndex << 
32 | valueIndex).
   
   In addition, would a more accurate renaming of valueIndex be 
valueIndexForSameKey ? (not saying we have to adopt this name but just trying 
to follow the notion)

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -314,7 +338,9 @@ private[joins] object UnsafeHashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,

Review comment:
       ^^^ So now isNullAware can work with allowsNullKey? I don't see an 
exclusive-or check below anymore.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1188,4 +1188,53 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
         classOf[BroadcastNestedLoopJoinExec]))
     }
   }
+
+  test("SPARK-32399: Full outer shuffled hash join") {
+    val inputDFs = Seq(
+      // Test unique join key
+      (spark.range(10).selectExpr("id as k1"),
+        spark.range(30).selectExpr("id as k2"),
+        $"k1" === $"k2"),
+      // Test non-unique join key
+      (spark.range(10).selectExpr("id % 5 as k1"),
+        spark.range(30).selectExpr("id % 5 as k2"),
+        $"k1" === $"k2"),
+      // Test string join key
+      (spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
+        spark.range(30).selectExpr("cast(id as string) as k2"),
+        $"k1" === $"k2"),
+      // Test build side at right
+      (spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
+        spark.range(10).selectExpr("cast(id as string) as k2"),
+        $"k1" === $"k2"),
+      // Test NULL join key
+      (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value 
as k1"),
+        spark.range(30).map(i => if (i % 4 == 0) i else 
null).selectExpr("value as k2"),
+        $"k1" === $"k2"),
+      // Test multiple join keys
+      (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
+        "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as 
long) as k3"),
+        spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr(
+          "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as 
long) as k6"),
+        $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
+    )
+    inputDFs.foreach { case (df1, df2, joinExprs) =>
+      withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",

Review comment:
       The code pointers are helpful ! I am wondering if you could add some 
"derivation" for the use of the magic constant 80 in the code comments ? 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to