Repository: spark Updated Branches: refs/heads/master 5f6943345 -> 3099c574c
http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 533e116..a6593b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -24,8 +24,9 @@ import scala.util.Random import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} @@ -35,7 +36,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -322,111 +323,6 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } - testQuietly("extract watermark from time condition") 
{ - val attributesToFindConstraintFor = Seq( - AttributeReference("leftTime", TimestampType)(), - AttributeReference("leftOther", IntegerType)()) - val metadataWithWatermark = new MetadataBuilder() - .putLong(EventTimeWatermark.delayKey, 1000) - .build() - val attributesWithWatermark = Seq( - AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), - AttributeReference("rightOther", IntegerType)()) - - def watermarkFrom( - conditionStr: String, - rightWatermark: Option[Long] = Some(10000)): Option[Long] = { - val conditionExpr = Some(conditionStr).map { str => - val plan = - Filter( - spark.sessionState.sqlParser.parseExpression(str), - LogicalRDD( - attributesToFindConstraintFor ++ attributesWithWatermark, - spark.sparkContext.emptyRDD)(spark)) - plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition - } - StreamingSymmetricHashJoinHelper.getStateValueWatermark( - AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), - conditionExpr, rightWatermark) - } - - // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, - // then cannot define constraint on leftTime. 
- assert(watermarkFrom("leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) - assert(watermarkFrom("leftTime < rightTime") === None) - assert(watermarkFrom("leftTime <= rightTime") === None) - assert(watermarkFrom("rightTime > leftTime") === None) - assert(watermarkFrom("rightTime >= leftTime") === None) - assert(watermarkFrom("rightTime < leftTime") === Some(10000)) - assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) - - // Test type conversions - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) - - // Test with timestamp type + calendar interval on either side of equation - // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. - assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) - assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) - assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) - assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) - assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") - === Some(12000)) - - // Test with casted long type + constants on either side of equation - // Note: long type and constants commute, so more combinations to test. 
- // -- Constants on the right - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") - === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") - === Some(10100)) - assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") - === Some(10200)) - // -- Constants on the left - assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) - assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) - assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") - === Some(7000)) - assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") - === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") - === Some(10100)) - // -- Constants on both sides, mixed types - assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") - === Some(13000)) - - // Test multiple conditions, should return minimum watermark - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === - Some(7000)) // first condition wins - assert(watermarkFrom( - "leftTime > rightTime - 
interval 3 second AND rightTime < leftTime + interval 4 seconds") === - Some(6000)) // second condition wins - - // Test invalid comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes - assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed - - // Test static comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) - } - test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -470,3 +366,189 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo } } } + +class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + import org.apache.spark.sql.functions._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 'value.cast("timestamp") as s"${prefix}Time", + ('value * multiplier) as s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + return (input, df) + } + + private def setupWindowedJoin(joinType: String): + (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 
second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + (input1, input2, joined) + } + + test("windowed left outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + test("windowed right outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. 
+ AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + Seq( + ("left_outer", Row(3, null, 5, null)), + ("right_outer", Row(null, 2, null, 5)) + ).foreach { case (joinType: String, outerResult) => + test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckLastBatch((1, 1, 5, 10)), + AddData(rightInput, (1, 11)), + CheckLastBatch(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 7), (1, 30)), + CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + assertNumStateRows(total = 7, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 1), + AddData(rightInput, (0, 30)), + CheckLastBatch(outerResult), + assertNumStateRows(total = 3, updated = 1) + ) + } + } + 
+ // When the join condition isn't true, the outer null rows must be generated, even if the join + // keys themselves have a match. + test("left outer join with non-key condition violated on left") { + val (leftInput, simpleLeftDf) = setupStream("left", 2) + val (rightInput, simpleRightDf) = setupStream("right", 3) + + val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) + val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + + val joined = left.join( + right, + left("key") === right("key") && left("window") === right("window") && + 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + // leftValue <= 10 should generate outer join rows even though it matches right keys + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 1, 2, 3), + CheckLastBatch(), + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 20), + CheckLastBatch( + Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows + AddData(leftInput, 40, 41), + AddData(rightInput, 40, 41), + CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), + AddData(leftInput, 70), + AddData(rightInput, 71), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 2), + AddData(rightInput, 70), + CheckLastBatch((70, 80, 140, 210)), + assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left + AddData(leftInput, 101, 102, 103), + AddData(rightInput, 101, 102, 103), + CheckLastBatch(), + AddData(leftInput, 1000), + AddData(rightInput, 1001), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 
2), + AddData(rightInput, 1000), + CheckLastBatch( + Row(1000, 1010, 2000, 3000), + Row(101, 110, 202, null), + Row(102, 110, 204, null), + Row(103, 110, 206, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } +} + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org