Github user JoshRosen commented on a diff in the pull request: https://github.com/apache/spark/pull/5717#discussion_r36493358 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala --- @@ -62,111 +100,250 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val streamResults = streamedPlan.execute().map(_.copy()) + val bufferResults = bufferedPlan.execute().map(_.copy()) - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { + // standard null rows + val streamedNullRow = InternalRow.fromSeq(Seq.fill(streamedPlan.output.length)(null)) + val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 + private[this] var streamedElement: InternalRow = _ + private[this] var bufferedElement: InternalRow = _ + private[this] var streamedKey: InternalRow = _ + private[this] var bufferedKey: InternalRow = _ + private[this] var bufferedMatches: CompactBuffer[InternalRow] = _ + private[this] var bufferedPosition: Int = -1 private[this] var stop: Boolean = false private[this] var matchKey: InternalRow = _ + // when we do merge algorithm and find some not matched join key, there must be a side + // that do not have a corresponding match. 
So we need to mark which side it is. True means + // streamed side not have match, and False means the buffered side. Only set when needed. + private[this] var continueStreamed: Boolean = _ + private[this] var streamNullGenerated: Boolean = false + // Tracks if each element in bufferedMatches have a matched streamedElement. + private[this] var bitSet: BitSet = _ + // marks if the found result has been fetched. + private[this] var found: Boolean = false + private[this] var bufferNullGenerated: Boolean = false // initialize iterator initialize() - override final def hasNext: Boolean = nextMatchingPair() + override final def hasNext: Boolean = { + val matching = nextMatchingBlock() + if (matching && !isBufferEmpty(bufferedMatches)) { + // The buffer stores all rows that match key, but condition may not be matched. + // If none of rows in the buffer match condition, we'll fetch next matching block. + findNextInBuffer() || hasNext + } else { + matching + } + } + + /** + * Run down the current `bufferedMatches` to find rows that match conditions. + * If `joinType` is not `Inner`, we will use `bufferNullGenerated` to mark if + * we need to build a bufferedNullRow for result. + * If `joinType` is `FullOuter`, we will use `streamNullGenerated` to mark if + * a buffered element need to join with a streamedNullRow. + * The method can be called multiple times since `found` serves as a guardian. + */ + def findNextInBuffer(): Boolean = { + while (!found && streamedElement != null + && keyOrdering.compare(streamedKey, matchKey) == 0) { + while (bufferedPosition < bufferedMatches.size && !boundCondition( + joinRow(streamedElement, bufferedMatches(bufferedPosition)))) { + bufferedPosition += 1 + } + if (bufferedPosition == bufferedMatches.size) { + if (joinType == Inner || bufferNullGenerated) { + bufferNullGenerated = false + bufferedPosition = 0 + fetchStreamed() + } else { + found = true + } + } else { + // mark as true so we don't generate null row for streamed row. 
+ bufferNullGenerated = true + bitSet.set(bufferedPosition) + found = true + } + } + if (!found) { + if (joinType == FullOuter && !streamNullGenerated) { + streamNullGenerated = true + } + if (streamNullGenerated) { + while (bufferedPosition < bufferedMatches.size && bitSet.get(bufferedPosition)) { + bufferedPosition += 1 + } + if (bufferedPosition < bufferedMatches.size) { + found = true + } + } + } + if (!found) { + stop = false + bufferedMatches = null + } + found + } override final def next(): InternalRow = { if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null + if (isBufferEmpty(bufferedMatches)) { + // we just found a row with no join match and we are here to produce a row + // with this row and a standard null row from the other side. 
+ if (continueStreamed) { + val joinedRow = smartJoinRow(streamedElement, bufferedNullRow) + fetchStreamed() + joinedRow + } else { + val joinedRow = smartJoinRow(streamedNullRow, bufferedElement) + fetchBuffered() + joinedRow } + } else { + // we are using the buffered right rows and run down left iterator + val joinedRow = if (streamNullGenerated) { + val ret = smartJoinRow(streamedNullRow, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + ret + } else { + if (bufferedPosition == bufferedMatches.size && !bufferNullGenerated) { + val ret = smartJoinRow(streamedElement, bufferedNullRow) + bufferNullGenerated = true + ret + } else { + val ret = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + ret + } + } + found = false + joinedRow } - joinedRow } else { // no more result throw new NoSuchElementException } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow = + joinType match { + case RightOuter => joinRow(bufferedRow, streamedRow) + case _ => joinRow(streamedRow, bufferedRow) + } + + private def fetchStreamed(): Unit = { + if (streamedIter.hasNext) { + streamedElement = streamedIter.next() + streamedKey = streamedKeyGenerator(streamedElement) } else { - leftElement = null + streamedElement = null } } - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + private def fetchBuffered(): Unit = { + if (bufferedIter.hasNext) { + bufferedElement = bufferedIter.next() + bufferedKey = bufferedKeyGenerator(bufferedElement) } else { - rightElement = null + bufferedElement = null } } private def initialize() = { - fetchLeft() - fetchRight() + fetchStreamed() + fetchBuffered() } /** - * Searches the right iterator for the next rows that have matches in left side, and store - 
* them in a buffer. + * Searches the right iterator for the next rows that have matches in left side (only check --- End diff -- This may be clarified slightly in my own SMJ patch, #7904.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org