Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r142344009

--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---
@@ -470,3 +475,283 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
     }
   }
 }
+
+class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+  import testImplicits._
+  import org.apache.spark.sql.functions._
+
+  before {
+    SparkSession.setActiveSession(spark)  // set this before forcibly initializing 'joinExec'
+    spark.streams.stateStoreCoordinator   // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+  }
+
+  private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = {
+    val input = MemoryStream[Int]
+
+    val df = input.toDF
+      .select(
+        'value as "key",
+        'value.cast("timestamp") as s"${prefix}Time",
+        ('value * multiplier) as s"${prefix}Value")
+      .withWatermark(s"${prefix}Time", "10 seconds")
+
+    (input, df)
+  }
+
+  private def setupWindowedJoin(joinType: String) = {
+    val (input1, df1) = setupStream("left", 2)
+    val (input2, df2) = setupStream("right", 3)
+
+    val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+
+    val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
+
+    val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    (input1, input2, joined)
+  }
+
+  test("left stream batch outer join") {
+    val stream = MemoryStream[Int]
+      .toDF()
+      .withColumn("timestamp", 'value.cast("timestamp"))
+      .withWatermark("timestamp", "1 second")
+    val joined =
+      stream.join(Seq(1).toDF(), Seq("value"), "left_outer")
+
+    // This test is here just to confirm that the validations below don't block this valid join.
+    // We don't need to check results, just that the join can run.
+    testStream(joined)()
+  }
+
+  test("left batch stream outer join") {
+    val stream = MemoryStream[Int]
+      .toDF()
+      .withColumn("timestamp", 'value.cast("timestamp"))
+      .withWatermark("timestamp", "1 second")
+    val joined =
+      Seq(1).toDF().join(stream, Seq("value"), "left_outer")
+
+    val thrown = intercept[AnalysisException] {
+      testStream(joined)()
+    }
+
+    assert(thrown.getMessage.contains(
+      "Left outer join with a streaming DataFrame/Dataset on the right and a static"))
+  }
+
+  test("right stream batch outer join") {
+    val stream = MemoryStream[Int]
+      .toDF()
+      .withColumn("timestamp", 'value.cast("timestamp"))
+      .withWatermark("timestamp", "1 second")
+    val joined =
+      stream.join(Seq(1).toDF(), Seq("value"), "right_outer")
+
+    val thrown = intercept[AnalysisException] {
+      testStream(joined)()
+    }
+
+    assert(thrown.getMessage.contains(
+      "Right outer join with a streaming DataFrame/Dataset on the left and a static"))
+  }
+
+  test("left outer join with no watermark") {
+    val joined =
+      MemoryStream[Int].toDF().join(MemoryStream[Int].toDF(), Seq("value"), "left_outer")
+
+    val thrown = intercept[AnalysisException] {
+      testStream(joined)()
+    }
+
+    assert(thrown.getMessage.contains(
+      "Stream-stream outer join between two streaming DataFrame/Datasets is not supported " +
+        "without a watermark"))
+  }
+
+  test("right outer join with no watermark") {
+    val joined =
+      MemoryStream[Int].toDF().join(MemoryStream[Int].toDF(), Seq("value"), "right_outer")
+
+    val thrown = intercept[AnalysisException] {
+      testStream(joined)()
+    }
+
+    assert(thrown.getMessage.contains(
+      "Stream-stream outer join between two streaming DataFrame/Datasets is not supported " +
+        "without a watermark"))
+  }
+
+  test("windowed left outer join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
+
+    testStream(joined)(
+      // Test the inner part of the join.
+      AddData(leftInput, 1, 2, 3, 4, 5),
+      AddData(rightInput, 3, 4, 5, 6, 7),
+      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+      // Old state isn't dropped until the batch *after* the watermark-advancing data arrives,
+      // so the outer-join null rows won't show up until the following batch.
+      AddData(leftInput, 21),
+      AddData(rightInput, 22),
+      CheckLastBatch(),
+      AddData(leftInput, 22),
+      CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null))
+    )
+  }
+
+  test("windowed right outer join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer")
+
+    testStream(joined)(
+      // Test the inner part of the join.
+      AddData(leftInput, 1, 2, 3, 4, 5),
+      AddData(rightInput, 3, 4, 5, 6, 7),
+      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+      // Old state isn't dropped until the batch *after* the watermark-advancing data arrives,
+      // so the outer-join null rows won't show up until the following batch.
+      AddData(leftInput, 21),
+      AddData(rightInput, 22),
+      CheckLastBatch(),
+      AddData(leftInput, 22),
+      CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21))
+    )
+  }
+
+  Seq(
+    ("left_outer", Row(3, null, 5, null)),
+    ("right_outer", Row(null, 2, null, 5))
+  ).foreach { case (joinType: String, outerResult) =>
+    test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") {
+      val leftInput = MemoryStream[(Int, Int)]
+      val rightInput = MemoryStream[(Int, Int)]
+
+      val df1 = leftInput.toDF.toDF("leftKey", "time")
+        .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+        .withWatermark("leftTime", "10 seconds")
+
+      val df2 = rightInput.toDF.toDF("rightKey", "time")
+        .select('rightKey, 'time.cast("timestamp") as "rightTime",
+          ('rightKey * 3) as "rightValue")
+        .withWatermark("rightTime", "10 seconds")
+
+      val joined =
+        df1.join(
+          df2,
+          expr("leftKey = rightKey AND " +
+            "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
+          joinType)
+          .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+
+      testStream(joined)(
+        AddData(leftInput, (1, 5), (3, 5)),
+        CheckAnswer(),
+        AddData(rightInput, (1, 10), (2, 5)),
+        CheckLastBatch((1, 1, 5, 10)),
+        AddData(rightInput, (1, 11)),
+        CheckLastBatch(),  // no match because the left-side event time is too low
+        assertNumStateRows(total = 5, updated = 1),
+
+        // Increase the event time watermark to 20s by adding data with time = 30s on both inputs.
+        AddData(leftInput, (1, 7), (1, 30)),
+        CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
+        assertNumStateRows(total = 7, updated = 2),
+        AddData(rightInput, (0, 30)),
+        CheckLastBatch(),
+        assertNumStateRows(total = 8, updated = 1),
+        AddData(rightInput, (0, 30)),
+        CheckLastBatch(outerResult))
+    }
+  }
+
+  // Even when the join keys themselves match, outer null rows must be generated whenever
+  // the rest of the join condition isn't satisfied.
+  test("outer join with non-key condition violated on left") {
--- End diff --

outer join -> left outer join
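
For readers following along, here is a minimal, self-contained sketch of the stream-stream
left outer join pattern these tests exercise. It is illustrative only: the rate source, the
object name, and the column names are placeholders rather than anything from the PR, and it
assumes a Spark build in which stream-stream outer joins (the feature under review here) are
available.

    // Sketch only: a watermarked stream-stream left outer join with a time-range condition.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.expr

    object StreamStreamLeftOuterJoinSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("stream-stream-left-outer-join-sketch")
          .getOrCreate()

        // Two watermarked streams. The watermarks bound how long unmatched left-side rows
        // are held in the state store before being emitted with nulls on the right.
        val left = spark.readStream.format("rate").load()
          .selectExpr("value AS leftKey", "timestamp AS leftTime")
          .withWatermark("leftTime", "10 seconds")

        val right = spark.readStream.format("rate").load()
          .selectExpr("value AS rightKey", "timestamp AS rightTime")
          .withWatermark("rightTime", "10 seconds")

        // Key equality plus a time-range condition, mirroring the
        // "with watermark range condition" tests in the diff above.
        val joined = left.join(
          right,
          expr("leftKey = rightKey AND " +
            "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
          "left_outer")

        val query = joined.writeStream.format("console").start()
        query.awaitTermination()
      }
    }

As the "no watermark" tests above assert, dropping the watermarks makes the analyzer reject
such a query, because without them the engine could never conclude that an unmatched
left-side row will stay unmatched and emit it with nulls.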