Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/19327#discussion_r140905683

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala ---
@@ -216,22 +232,70 @@ case class StreamingSymmetricHashJoinExec(
     }

     // Filter the joined rows based on the given condition.
-    val outputFilterFunction =
-      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
-    val filteredOutputIter =
-      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
-        numOutputRows += 1
-        row
-      }
+    val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
+
+    val filteredInnerOutputIter = (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction)
+
+    val outputIter: Iterator[InternalRow] = joinType match {
+      case Inner =>
+        filteredInnerOutputIter
+      case LeftOuter =>
+        // We generate the outer join input by:
+        // * Getting an iterator over the rows that have aged out on the left side. These rows are
+        //   candidates for being null joined. Note that to avoid doing two passes, this iterator
+        //   removes the rows from the state manager as they're processed.
+        // * Checking whether the current row matches a key in the right side state. If it doesn't,
+        //   we know we can join with null, since there was never (including this batch) a match
+        //   within the watermark period. If it does, there must have been a match at some point, so
+        //   we know we can't join with null.
+        val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+        val removedRowIter = leftSideJoiner.removeOldState()
+        val outerOutputIter = removedRowIter
+          .filterNot(pair => rightSideJoiner.containsKey(pair.key))
+          .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+
+        filteredInnerOutputIter ++ outerOutputIter
+      case RightOuter =>
+        // See comments for left outer case.
+        val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+        val removedRowIter = rightSideJoiner.removeOldState()
+        val outerOutputIter = removedRowIter
+          .filterNot(pair => leftSideJoiner.containsKey(pair.key))
+          .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        filteredInnerOutputIter ++ outerOutputIter
+      case _ =>
+        throwBadJoinTypeException()
+        Iterator()
+    }
+
+    val outputIterWithMetrics = outputIter.map { row =>
+      numOutputRows += 1
+      row
+    }

     // Function to remove old state after all the input has been consumed and output generated
     def onOutputCompletion = {
       allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)

-      // Remove old state if needed
+      // TODO: how to get this for removals as part of outer join?
       allRemovalsTimeMs += timeTakenMs {
-        leftSideJoiner.removeOldState()
-        rightSideJoiner.removeOldState()
+        // Iterator which must be consumed after output completion before committing.
+        // For outer joins, we've removed old state from the appropriate side inline while we
+        // produced the null rows. So we need to finish cleaning the other side. For inner joins
--- End diff --

"appropriate side inline" does not make sense to me. Something like this would be better: "For inner joins, we have to remove unnecessary state rows from both sides if possible. For outer joins, we have already removed unnecessary state rows from the outer side (e.g., the left side for a left outer join) while generating the outer "null" outputs.
Now, we have to remove unnecessary state rows from the other side (e.g., the right side for a left outer join) if possible. In all cases, nothing needs to be output, hence the removal needs to be done greedily by immediately consuming the returned iterator."
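To make the suggested comment concrete, here is a minimal, self-contained Scala sketch of why the post-output cleanup has to consume the returned iterator greedily. It does not use the actual Spark classes: SideJoiner, cleanupAfterOutput, and the string-based state are hypothetical stand-ins for the PR's OneSideHashJoiner and its state store. The point it illustrates is that removal only happens as the iterator is advanced, and since none of these rows are emitted, the iterator must be drained immediately.

object GreedyStateCleanupSketch {

  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType

  // Hypothetical stand-in for a one-side state manager: removeOldState() returns a lazy
  // iterator that deletes a stored row each time it is advanced.
  final class SideJoiner(name: String, private var state: List[String]) {
    def removeOldState(): Iterator[String] = new Iterator[String] {
      def hasNext: Boolean = state.nonEmpty
      def next(): String = {
        val removed = state.head
        state = state.tail // the row leaves the state store only when the iterator advances
        println(s"$name: removed $removed")
        removed
      }
    }
  }

  def cleanupAfterOutput(
      joinType: JoinType,
      leftSideJoiner: SideJoiner,
      rightSideJoiner: SideJoiner): Unit = {
    // Inner joins still hold old state on both sides; for outer joins the outer side was
    // already cleaned while generating the "null" rows, so only the other side remains.
    val cleanupIter: Iterator[String] = joinType match {
      case Inner      => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
      case LeftOuter  => rightSideJoiner.removeOldState()
      case RightOuter => leftSideJoiner.removeOldState()
    }
    // Nothing is emitted from these rows, so drain the iterator greedily right away to make
    // sure the removals actually happen before the state stores are committed.
    while (cleanupIter.hasNext) cleanupIter.next()
  }

  def main(args: Array[String]): Unit = {
    val left = new SideJoiner("left", List("l1", "l2"))
    val right = new SideJoiner("right", List("r1"))
    cleanupAfterOutput(LeftOuter, left, right) // only the right side is drained here
  }
}

Running the sketch for a left outer join drains only the right-side state, mirroring the suggested comment: the left side was already emptied while the "null" rows were produced.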