Github user JoshRosen commented on a diff in the pull request: https://github.com/apache/spark/pull/7904#discussion_r36487371 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala --- @@ -56,117 +52,247 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + // TODO(josh): there is an existing bug here: this should also check whether unsafe mode + // is enabled. also, the default for self.codegenEnabled looks inconsistent to me. + codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(schema) + } + + // TODO(josh): this will need to change once we use an Unsafe row joiner + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - // Mutable per row objects. 
+ private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: CompactBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + leftIter, + rightIter + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow - } else { - // no more result - throw new NoSuchElementException - } - } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + override final def hasNext: Boolean = + (currentMatchIdx != -1 && currentMatchIdx < currentRightMatches.length) || fetchNext() + + private[this] def fetchNext(): Boolean = { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBuildMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + true } else { - leftElement = null + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + false } } - private def fetchRight() = { - if 
(rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } else { - rightElement = null + override def next(): InternalRow = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + fetchNext() } + val joinedRow = joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + joinedRow } + } + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * The streamed input is the left side of a left outer join or the right side of a right outer join. + * + * // todo(josh): scaladoc + * @param streamedKeyGenerator + * @param buildKeyGenerator + * @param keyOrdering + * @param streamedIter + * @param buildIter + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + buildKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: Iterator[InternalRow], + buildIter: Iterator[InternalRow]) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var buildRow: InternalRow = _ + private[this] var buildRowKey: InternalRow = _ + /** The join key for the rows buffered in `buildMatches`, or null if `buildMatches` is empty */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the build side of the join. This is null if there are no matches */ + private[this] var buildMatches: CompactBuffer[InternalRow] = _ + + // Initialization (note: do _not_ want to advance streamed here). + advanceBuild() + + // --- Public methods --------------------------------------------------------------------------- - private def initialize() = { - fetchLeft() - fetchRight() + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. 
If this returns true, then + * [[getStreamedRow]] and [[getBuildMatches]] can be called to produce the join results. + */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (buildRow == null) { + // The streamed row's join key does not match the current batch of build rows and there are no + // more rows to read from the build iterator, so there can be no more matches. + false + } else { + // Advance both the streamed and build iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, buildRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else if (buildRowKey.anyNull) { + advanceBuild() + } else { + comp = keyOrdering.compare(streamedRowKey, buildRowKey) + if (comp > 0) advanceBuild() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && buildRow != null && comp != 0) + if (streamedRow == null || buildRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. + false + } else { + // The streamed and build rows have matching join keys, so walk through the build iterator + // to buffer all matching rows. + assert(comp == 0) + bufferMatchingBuildRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. 
- */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the build input with matching + * keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [getStreamedRow and [[getBuildMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (advancedStreamed()) { + if (streamedRowKey.anyNull) { + // Since at least one join column is null, the streamed row has no matches. + matchJoinKey = null + buildMatches = null + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + buildMatches = null + if (buildRow != null) { + // The build iterator could still contain matching rows, so we'll need to walk through it + // until we either find matches or pass where they would be found. 
+ var comp = + if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) + while (comp > 0 && advanceBuild()) { + comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) + } + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingBuildRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input, then we always return true + true + } else { + // End of streamed input, hence no more results. + false } } + + def getStreamedRow: InternalRow = streamedRow + def getBuildMatches: CompactBuffer[InternalRow] = buildMatches + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advancedStreamed(): Boolean = { + if (streamedIter.hasNext) { + streamedRow = streamedIter.next() + streamedRowKey = streamedKeyGenerator(streamedRow) + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the build iterator and compute the new row's join key. + * @return true if the build iterator returned a row and false otherwise. + */ + private def advanceBuild(): Boolean = { + if (buildIter.hasNext) { + buildRow = buildIter.next() + buildRowKey = buildKeyGenerator(buildRow) + true + } else { + buildRow = null + buildRowKey = null + false + } + } + + /** + * Called when the streamed and build join keys match in order to buffer the matching build rows. 
+ */ + private def bufferMatchingBuildRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(buildRowKey != null) + assert(!buildRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, buildRowKey) == 0) + matchJoinKey = streamedRowKey.copy() + buildMatches = new CompactBuffer[InternalRow] --- End diff -- Good question; I think that there's a time / space trade-off at play here depending on how many of these buffers you expect to create. I haven't benchmarked this yet, but I'm fairly certain that re-creating the `CompactBuffer` here is going to have an allocation hit compared to just using an ArrayBuffer, especially given the costs of the extra branches when adding elements, etc.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org