GitHub user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5717#discussion_r36493358
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
 ---
    @@ -62,111 +100,250 @@ case class SortMergeJoin(
       }
     
       protected override def doExecute(): RDD[InternalRow] = {
    -    val leftResults = left.execute().map(_.copy())
    -    val rightResults = right.execute().map(_.copy())
    +    val streamResults = streamedPlan.execute().map(_.copy())
    +    val bufferResults = bufferedPlan.execute().map(_.copy())
     
    -    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
    +    streamResults.zipPartitions(bufferResults) ( (streamedIter, 
bufferedIter) => {
    +      // standard null rows
    +      val streamedNullRow = 
InternalRow.fromSeq(Seq.fill(streamedPlan.output.length)(null))
    +      val bufferedNullRow = 
InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null))
           new Iterator[InternalRow] {
             // An ordering that can be used to compare keys from both sides.
             private[this] val keyOrdering = 
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
             // Mutable per row objects.
             private[this] val joinRow = new JoinedRow
    -        private[this] var leftElement: InternalRow = _
    -        private[this] var rightElement: InternalRow = _
    -        private[this] var leftKey: InternalRow = _
    -        private[this] var rightKey: InternalRow = _
    -        private[this] var rightMatches: CompactBuffer[InternalRow] = _
    -        private[this] var rightPosition: Int = -1
    +        private[this] var streamedElement: InternalRow = _
    +        private[this] var bufferedElement: InternalRow = _
    +        private[this] var streamedKey: InternalRow = _
    +        private[this] var bufferedKey: InternalRow = _
    +        private[this] var bufferedMatches: CompactBuffer[InternalRow] = _
    +        private[this] var bufferedPosition: Int = -1
             private[this] var stop: Boolean = false
             private[this] var matchKey: InternalRow = _
    +        // when we do merge algorithm and find some not matched join key, 
there must be a side
    +        // that do not have a corresponding match. So we need to mark 
which side it is. True means
    +        // streamed side not have match, and False means the buffered 
side. Only set when needed.
    +        private[this] var continueStreamed: Boolean = _
    +        private[this] var streamNullGenerated: Boolean = false
    +        // Tracks if each element in bufferedMatches have a matched 
streamedElement.
    +        private[this] var bitSet: BitSet = _
    +        // marks if the found result has been fetched.
    +        private[this] var found: Boolean = false
    +        private[this] var bufferNullGenerated: Boolean = false
     
             // initialize iterator
             initialize()
     
    -        override final def hasNext: Boolean = nextMatchingPair()
    +        override final def hasNext: Boolean = {
    +          val matching = nextMatchingBlock()
    +          if (matching && !isBufferEmpty(bufferedMatches)) {
    +            // The buffer stores all rows that match key, but condition 
may not be matched.
    +            // If none of rows in the buffer match condition, we'll fetch 
next matching block.
    +            findNextInBuffer() || hasNext
    +          } else {
    +            matching
    +          }
    +        }
    +
    +        /**
    +         * Run down the current `bufferedMatches` to find rows that match 
conditions.
    +         * If `joinType` is not `Inner`, we will use `bufferNullGenerated` 
to mark if
    +         * we need to build a bufferedNullRow for result.
    +         * If `joinType` is `FullOuter`, we will use `streamNullGenerated` 
to mark if
    +         * a buffered element need to join with a streamedNullRow.
    +         * The method can be called multiple times since `found` serves as 
a guardian.
    +         */
    +        def findNextInBuffer(): Boolean = {
    +          while (!found && streamedElement != null
    +            && keyOrdering.compare(streamedKey, matchKey) == 0) {
    +            while (bufferedPosition < bufferedMatches.size && 
!boundCondition(
    +              joinRow(streamedElement, 
bufferedMatches(bufferedPosition)))) {
    +              bufferedPosition += 1
    +            }
    +            if (bufferedPosition == bufferedMatches.size) {
    +              if (joinType == Inner || bufferNullGenerated) {
    +                bufferNullGenerated = false
    +                bufferedPosition = 0
    +                fetchStreamed()
    +              } else {
    +                found = true
    +              }
    +            } else {
    +              // mark as true so we don't generate null row for streamed 
row.
    +              bufferNullGenerated = true
    +              bitSet.set(bufferedPosition)
    +              found = true
    +            }
    +          }
    +          if (!found) {
    +            if (joinType == FullOuter && !streamNullGenerated) {
    +              streamNullGenerated = true
    +            }
    +            if (streamNullGenerated) {
    +              while (bufferedPosition < bufferedMatches.size && 
bitSet.get(bufferedPosition)) {
    +                bufferedPosition += 1
    +              }
    +              if (bufferedPosition < bufferedMatches.size) {
    +                found = true
    +              }
    +            }
    +          }
    +          if (!found) {
    +            stop = false
    +            bufferedMatches = null
    +          }
    +          found
    +        }
     
             override final def next(): InternalRow = {
               if (hasNext) {
    -            // we are using the buffered right rows and run down left 
iterator
    -            val joinedRow = joinRow(leftElement, 
rightMatches(rightPosition))
    -            rightPosition += 1
    -            if (rightPosition >= rightMatches.size) {
    -              rightPosition = 0
    -              fetchLeft()
    -              if (leftElement == null || keyOrdering.compare(leftKey, 
matchKey) != 0) {
    -                stop = false
    -                rightMatches = null
    +            if (isBufferEmpty(bufferedMatches)) {
    +              // we just found a row with no join match and we are here to 
produce a row
    +              // with this row and a standard null row from the other side.
    +              if (continueStreamed) {
    +                val joinedRow = smartJoinRow(streamedElement, 
bufferedNullRow)
    +                fetchStreamed()
    +                joinedRow
    +              } else {
    +                val joinedRow = smartJoinRow(streamedNullRow, 
bufferedElement)
    +                fetchBuffered()
    +                joinedRow
                   }
    +            } else {
    +              // we are using the buffered right rows and run down left 
iterator
    +              val joinedRow = if (streamNullGenerated) {
    +                val ret = smartJoinRow(streamedNullRow, 
bufferedMatches(bufferedPosition))
    +                bufferedPosition += 1
    +                ret
    +              } else {
    +                if (bufferedPosition == bufferedMatches.size && 
!bufferNullGenerated) {
    +                  val ret = smartJoinRow(streamedElement, bufferedNullRow)
    +                  bufferNullGenerated = true
    +                  ret
    +                } else {
    +                  val ret = smartJoinRow(streamedElement, 
bufferedMatches(bufferedPosition))
    +                  bufferedPosition += 1
    +                  ret
    +                }
    +              }
    +              found = false
    +              joinedRow
                 }
    -            joinedRow
               } else {
                 // no more result
                 throw new NoSuchElementException
               }
             }
     
    -        private def fetchLeft() = {
    -          if (leftIter.hasNext) {
    -            leftElement = leftIter.next()
    -            leftKey = leftKeyGenerator(leftElement)
    +        private def smartJoinRow(streamedRow: InternalRow, bufferedRow: 
InternalRow): InternalRow =
    +          joinType match {
    +            case RightOuter => joinRow(bufferedRow, streamedRow)
    +            case _ => joinRow(streamedRow, bufferedRow)
    +          }
    +
    +        private def fetchStreamed(): Unit = {
    +          if (streamedIter.hasNext) {
    +            streamedElement = streamedIter.next()
    +            streamedKey = streamedKeyGenerator(streamedElement)
               } else {
    -            leftElement = null
    +            streamedElement = null
               }
             }
     
    -        private def fetchRight() = {
    -          if (rightIter.hasNext) {
    -            rightElement = rightIter.next()
    -            rightKey = rightKeyGenerator(rightElement)
    +        private def fetchBuffered(): Unit = {
    +          if (bufferedIter.hasNext) {
    +            bufferedElement = bufferedIter.next()
    +            bufferedKey = bufferedKeyGenerator(bufferedElement)
               } else {
    -            rightElement = null
    +            bufferedElement = null
               }
             }
     
             private def initialize() = {
    -          fetchLeft()
    -          fetchRight()
    +          fetchStreamed()
    +          fetchBuffered()
             }
     
             /**
    -         * Searches the right iterator for the next rows that have matches 
in left side, and store
    -         * them in a buffer.
    +         * Searches the right iterator for the next rows that have matches 
in left side (only check
    --- End diff --
    
    This may be clarified slightly in my own SMJ patch, #7904.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to