GitHub user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r142530435
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---
    @@ -470,3 +475,222 @@ class StreamingJoinSuite extends StreamTest with 
StateStoreMetricsTest with Befo
         }
       }
     }
    +
    +class StreamingOuterJoinSuite extends StreamTest with 
StateStoreMetricsTest with BeforeAndAfter {
    +
    +  import testImplicits._
    +  import org.apache.spark.sql.functions._
    +
    +  before {
    +    SparkSession.setActiveSession(spark) // set this before force 
initializing 'joinExec'
    +    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
    +  }
    +
    +  after {
    +    StateStore.stop()
    +  }
    +
    +  private def setupStream(prefix: String, multiplier: Int): 
(MemoryStream[Int], DataFrame) = {
    +    val input = MemoryStream[Int]
    +    val df = input.toDF
    +      .select(
    +        'value as "key",
    +        'value.cast("timestamp") as s"${prefix}Time",
    +        ('value * multiplier) as s"${prefix}Value")
    +      .withWatermark(s"${prefix}Time", "10 seconds")
    +
    +    return (input, df)
    +  }
    +
    +  private def setupWindowedJoin(joinType: String):
    +  (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
    +    val (input1, df1) = setupStream("left", 2)
    +    val (input2, df2) = setupStream("right", 3)
    +    val windowed1 = df1.select('key, window('leftTime, "10 second"), 
'leftValue)
    +    val windowed2 = df2.select('key, window('rightTime, "10 second"), 
'rightValue)
    +    val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
    +      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
    +
    +    (input1, input2, joined)
    +  }
    +
    +  test("windowed left outer join") {
    +    val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
    +
    +    testStream(joined)(
    +      // Test inner part of the join.
    +      AddData(leftInput, 1, 2, 3, 4, 5),
    +      AddData(rightInput, 3, 4, 5, 6, 7),
    +      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
    +      // Old state doesn't get dropped until the batch *after* it gets 
introduced, so the
    +      // nulls won't show up until the next batch after the watermark 
advances.
    +      AddData(leftInput, 21),
    +      AddData(rightInput, 22),
    +      CheckLastBatch(),
    +      assertNumStateRows(total = 12, updated = 2),
    +      AddData(leftInput, 22),
    +      CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 
4, null)),
    +      assertNumStateRows(total = 3, updated = 1)
    +    )
    +  }
    +
    +  test("windowed right outer join") {
    +    val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer")
    +
    +    testStream(joined)(
    +      // Test inner part of the join.
    +      AddData(leftInput, 1, 2, 3, 4, 5),
    +      AddData(rightInput, 3, 4, 5, 6, 7),
    +      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
    +      // Old state doesn't get dropped until the batch *after* it gets 
introduced, so the
    +      // nulls won't show up until the next batch after the watermark 
advances.
    +      AddData(leftInput, 21),
    +      AddData(rightInput, 22),
    +      CheckLastBatch(),
    +      assertNumStateRows(total = 12, updated = 2),
    +      AddData(leftInput, 22),
    +      CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, 
null, 21)),
    +      assertNumStateRows(total = 3, updated = 1)
    +    )
    +  }
    +
    +  Seq(
    +    ("left_outer", Row(3, null, 5, null)),
    +    ("right_outer", Row(null, 2, null, 5))
    +  ).foreach { case (joinType: String, outerResult) =>
    +    test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range 
condition") {
    +      import org.apache.spark.sql.functions._
    +
    +      val leftInput = MemoryStream[(Int, Int)]
    +      val rightInput = MemoryStream[(Int, Int)]
    +
    +      val df1 = leftInput.toDF.toDF("leftKey", "time")
    +        .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey 
* 2) as "leftValue")
    +        .withWatermark("leftTime", "10 seconds")
    +
    +      val df2 = rightInput.toDF.toDF("rightKey", "time")
    +        .select('rightKey, 'time.cast("timestamp") as "rightTime", 
('rightKey * 3) as "rightValue")
    +        .withWatermark("rightTime", "10 seconds")
    +
    +      val joined =
    +        df1.join(
    +          df2,
    +          expr("leftKey = rightKey AND " +
    +            "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime 
+ interval 5 seconds"),
    +          joinType)
    +          .select('leftKey, 'rightKey, 'leftTime.cast("int"), 
'rightTime.cast("int"))
    +      testStream(joined)(
    +        AddData(leftInput, (1, 5), (3, 5)),
    +        CheckAnswer(),
    +        AddData(rightInput, (1, 10), (2, 5)),
    +        CheckLastBatch((1, 1, 5, 10)),
    +        AddData(rightInput, (1, 11)),
    +        CheckLastBatch(), // no match as left time is too low
    +        assertNumStateRows(total = 5, updated = 1),
    +
    +        // Increase event time watermark to 20s by adding data with time = 
30s on both inputs
    +        AddData(leftInput, (1, 7), (1, 30)),
    +        CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
    +        assertNumStateRows(total = 7, updated = 2),
    +        AddData(rightInput, (0, 30)),
    +        CheckLastBatch(),
    +        assertNumStateRows(total = 8, updated = 1),
    +        AddData(rightInput, (0, 30)),
    +        CheckLastBatch(outerResult),
    +        assertNumStateRows(total = 3, updated = 1)
    +      )
    +    }
    +  }
    +
    +  // When the join condition isn't true, the outer null rows must be 
generated, even if the join
    +  // keys themselves have a match.
    +  test("left outer join with non-key condition violated on left") {
    +    val (leftInput, simpleLeftDf) = setupStream("left", 2)
    +    val (rightInput, simpleRightDf) = setupStream("right", 3)
    +
    +    val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 
'leftValue)
    +    val right = simpleRightDf.select('key, window('rightTime, "10 
second"), 'rightValue)
    +
    +    val joined = left.join(
    +        right,
    +        left("key") === right("key") && left("window") === right("window") 
&&
    +            'leftValue > 20 && 'rightValue < 200,
    +        "left_outer")
    +      .select(left("key"), left("window.end").cast("long"), 'leftValue, 
'rightValue)
    +
    +    testStream(joined)(
    +      // leftValue <= 20 should generate outer join rows even though it 
matches right keys
    +      AddData(leftInput, 1, 2, 3),
    +      AddData(rightInput, 1, 2, 3),
    +      CheckLastBatch(),
    +      AddData(leftInput, 30),
    +      AddData(rightInput, 31),
    +      CheckLastBatch(),
    +      assertNumStateRows(total = 8, updated = 2),
    +      AddData(rightInput, 32),
    --- End diff --
    
    In fact, then you don't need the next unit test. You can reduce the number of tests very easily.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to