Repository: spark
Updated Branches:
  refs/heads/master a8a5cd24e -> f32a84250


http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
index 894786c..368c460 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 trait StateStoreMetricsTest extends StreamTest {
 
   def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
-    AssertOnQuery { q =>
+    AssertOnQuery(s"Check total state rows = $total, updated state rows = 
$updated") { q =>
       val progressWithData = q.recentProgress.filter(_.numInputRows > 
0).lastOption.get
       assert(
         progressWithData.stateOperators.map(_.numRowsTotal) === total,

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
new file mode 100644
index 0000000..533e116
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
+import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+
+class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+  before {
+    SparkSession.setActiveSession(spark)  // set this before force initializing 'joinExec'
+    spark.streams.stateStoreCoordinator   // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+  }
+
+  import testImplicits._
+
+  test("stream stream inner join on non-time column") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue")
+    val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue")
+    val joined = df1.join(df2, "key")
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckAnswer(),
+      AddData(input2, 1, 10),       // 1 arrived on input1 first, then input2, should join
+      CheckLastBatch((1, 2, 3)),
+      AddData(input1, 10),          // 10 arrived on input2 first, then input1, should join
+      CheckLastBatch((10, 20, 30)),
+      AddData(input2, 1),           // another 1 in input2 should join with 1 from input1
+      CheckLastBatch((1, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3)
+      CheckLastBatch((1, 2, 3), (1, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 100),
+      AddData(input2, 100),
+      CheckLastBatch((100, 200, 300))
+    )
+  }
+
+  test("stream stream inner join on windows - without watermark") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp", 
('value * 2) as "leftValue")
+      .select('key, window('timestamp, "10 second"), 'leftValue)
+
+    val df2 = input2.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp",
+        ('value * 3) as "rightValue")
+      .select('key, window('timestamp, "10 second"), 'rightValue)
+
+    val joined = df1.join(df2, Seq("key", "window"))
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckLastBatch(),
+      AddData(input2, 1),
+      CheckLastBatch((1, 10, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 25),
+      CheckLastBatch(),
+      StopStream,
+      StartStream(),
+      AddData(input2, 25),
+      CheckLastBatch((25, 30, 50, 75)),
+      AddData(input1, 1),
+      CheckLastBatch((1, 10, 2, 3)),      // State for 1 still around as there is no watermark
+      StopStream,
+      StartStream(),
+      AddData(input1, 5),
+      CheckLastBatch(),
+      AddData(input2, 5),
+      CheckLastBatch((5, 10, 10, 15))     // No filter by any watermark
+    )
+  }
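+
+  // Without a watermark on either input, none of the join state above is ever cleaned up:
+  // every windowed row stays buffered indefinitely, which is why key 1 still finds its
+  // match after multiple restarts. The next test adds a watermark to bound that state.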
+
+  test("stream stream inner join on windows - with watermark") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp", 
('value * 2) as "leftValue")
+      .withWatermark("timestamp", "10 seconds")
+      .select('key, window('timestamp, "10 second"), 'leftValue)
+
+    val df2 = input2.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp",
+        ('value * 3) as "rightValue")
+      .select('key, window('timestamp, "10 second"), 'rightValue)
+
+    val joined = df1.join(df2, Seq("key", "window"))
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckAnswer(),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(input2, 1),
+      CheckLastBatch((1, 10, 2, 3)),
+      assertNumStateRows(total = 2, updated = 1),
+      StopStream,
+      StartStream(),
+
+      AddData(input1, 25),
+      CheckLastBatch(), // since there is only 1 watermark operator, the watermark = 25 - 10 = 15
+      assertNumStateRows(total = 3, updated = 1),
+
+      AddData(input2, 25),
+      CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10]
+      assertNumStateRows(total = 2, updated = 1),
+      StopStream,
+      StartStream(),
+
+      AddData(input2, 1),
+      CheckLastBatch(),       // Should not join as < 15 removed
+      assertNumStateRows(total = 2, updated = 0),  // row not added as 1 < state key watermark = 15
+
+      AddData(input1, 5),
+      CheckLastBatch(),       // Should not join or add to state as < 15 got filtered by watermark
+      assertNumStateRows(total = 2, updated = 0)
+    )
+  }
+
+  test("stream stream inner join with time range - with watermark - one side 
condition") {
+    import org.apache.spark.sql.functions._
+
+    val leftInput = MemoryStream[(Int, Int)]
+    val rightInput = MemoryStream[(Int, Int)]
+
+    val df1 = leftInput.toDF.toDF("leftKey", "time")
+      .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+      .withWatermark("leftTime", "10 seconds")
+
+    val df2 = rightInput.toDF.toDF("rightKey", "time")
+      .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+      .withWatermark("rightTime", "10 seconds")
+
+    val joined =
+      df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - 
interval 5 seconds"))
+        .select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+
+    testStream(joined)(
+      AddData(leftInput, (1, 5)),
+      CheckAnswer(),
+      AddData(rightInput, (1, 11)),
+      CheckLastBatch((1, 5, 11)),
+      AddData(rightInput, (1, 10)),
+      CheckLastBatch(), // no match as neither leftTime 5 nor 10 is less than rightTime 10 - 5
+      assertNumStateRows(total = 3, updated = 1),
+
+      // Increase event time watermark to 20s by adding data with time = 30s on both inputs
+      AddData(leftInput, (1, 3), (1, 30)),
+      CheckLastBatch((1, 3, 10), (1, 3, 11)),
+      assertNumStateRows(total = 5, updated = 2),
+      AddData(rightInput, (0, 30)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 6, updated = 1),
+
+      // event time watermark:    max event time - 10   ==>   30 - 10 = 20
+      // right side state constraint:    20 < leftTime < rightTime - 5   ==>   rightTime > 25
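+      // (Derivation: future left rows have leftTime > 20, and a right row matches only if
+      // leftTime < rightTime - 5; both hold only when rightTime - 5 > 20, i.e. rightTime > 25,
+      // so right state with rightTime <= 25 can never match again and is safe to drop.)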
+
+      // Run another batch with event time = 25 to clear right state where rightTime <= 25
+      AddData(rightInput, (0, 30)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 5, updated = 1),  // removed (1, 11) and (1, 10), added (0, 30)
+
+      // New data to right input should match with left side (1, 3) and (1, 5), as left state should
+      // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and
+      // state rows with rightTime <= 25 should be removed from state.
+      // (1, 20) ==> filtered by event time watermark = 20
+      // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state
+      //             as state watermark = 25
+      // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state
+      AddData(rightInput, (1, 20), (1, 21), (1, 28)),
+      CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)),
+      assertNumStateRows(total = 6, updated = 1),
+
+      // New data to left input with leftTime <= 20 should be filtered due to event time watermark
+      AddData(leftInput, (1, 20), (1, 21)),
+      CheckLastBatch((1, 21, 28)),
+      assertNumStateRows(total = 7, updated = 1)
+    )
+  }
+
+  test("stream stream inner join with time range - with watermark - two side 
conditions") {
+    import org.apache.spark.sql.functions._
+
+    val leftInput = MemoryStream[(Int, Int)]
+    val rightInput = MemoryStream[(Int, Int)]
+
+    val df1 = leftInput.toDF.toDF("leftKey", "time")
+      .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+      .withWatermark("leftTime", "20 seconds")
+
+    val df2 = rightInput.toDF.toDF("rightKey", "time")
+      .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+      .withWatermark("rightTime", "30 seconds")
+
+    val condition = expr(
+      "leftKey = rightKey AND " +
+        "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + 
interval 5 seconds")
+
+    // This translates to leftTime <= rightTime + 5 seconds AND leftTime >= rightTime - 10 seconds
+    // So given leftTime, rightTime has to be BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds
+    //
+    //  =============== * ======================== * ============================== * ==> leftTime
+    //                  |                          |                                |
+    //     |<---- 5s -->|<------ 10s ------>|      |<------ 10s ------>|<---- 5s -->|
+    //     |                                |                          |
+    //  == * ============================== * =========>============== * ===============> rightTime
+    //
+    // E.g.
+    //      if rightTime = 60, then it matches only leftTime = [50, 65]
+    //      if leftTime = 20, then it matches only rightTime = [15, 30]
+    //
+    // State value predicates
+    //   left side:
+    //     values allowed:  leftTime >= rightTime - 10s   ==>   leftTime > eventTimeWatermark - 10
+    //     drop state where leftTime < eventTimeWatermark - 10
+    //   right side:
+    //     values allowed:  rightTime >= leftTime - 5s   ==>   rightTime > eventTimeWatermark - 5
+    //     drop state where rightTime < eventTimeWatermark - 5
+
+    val joined =
+      df1.join(df2, condition).select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+
+    testStream(joined)(
+      // If leftTime = 20, then it match only with rightTime = [15, 30]
+      AddData(leftInput, (1, 20)),
+      CheckAnswer(),
+      AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)),
+      CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)),
+      assertNumStateRows(total = 7, updated = 6),
+
+      // If rightTime = 60, then it matches only leftTime = [50, 65]
+      AddData(rightInput, (1, 60)),
+      CheckLastBatch(),                // matches with nothing on the left
+      AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)),
+      CheckLastBatch((1, 50, 60), (1, 65, 60)),
+      assertNumStateRows(total = 12, updated = 4),
+
+      // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30
+      // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=)
+      //    Should drop < 20 from left, i.e., none
+      // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=)
+      //    Should drop < 25 from the right, i.e., 14 and 15
+      AddData(leftInput, (1, 30), (1, 31)),     // 30 should not be processed or added to state
+      CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)),
+      assertNumStateRows(total = 11, updated = 1),  // 12 - 2 removed + 1 added
+
+      // Advance the watermark
+      AddData(rightInput, (1, 80)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 12, updated = 1),
+
+      // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46
+      // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=)
+      //    Should drop < 36 from left, i.e., 20, 31 (30 was not added)
+      // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=)
+      //    Should drop < 41 from the right, i.e., 25, 26, 30, 31
+      AddData(rightInput, (1, 50)),
+      CheckLastBatch((1, 49, 50), (1, 50, 50)),
+      assertNumStateRows(total = 7, updated = 1)  // 12 - 6 removed + 1 added
+    )
+  }
+
+  testQuietly("stream stream inner join without equality predicate") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as 
"leftValue")
+    val df2 = input2.toDF.select('value as "rightKey", ('value * 3) as 
"rightValue")
+    val joined = df1.join(df2, expr("leftKey < rightKey"))
+    val e = intercept[Exception] {
+      val q = joined.writeStream.format("memory").queryName("test").start()
+      input1.addData(1)
+      q.awaitTermination(10000)
+    }
+    assert(e.toString.contains("Stream stream joins without equality predicate 
is not supported"))
+  }
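+
+  // The symmetric hash join keys its buffered state by the equi-join attributes, so a pure
+  // range predicate such as "leftKey < rightKey" is rejected when the query starts.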
+
+  testQuietly("extract watermark from time condition") {
+    val attributesToFindConstraintFor = Seq(
+      AttributeReference("leftTime", TimestampType)(),
+      AttributeReference("leftOther", IntegerType)())
+    val metadataWithWatermark = new MetadataBuilder()
+      .putLong(EventTimeWatermark.delayKey, 1000)
+      .build()
+    val attributesWithWatermark = Seq(
+      AttributeReference("rightTime", TimestampType, metadata = 
metadataWithWatermark)(),
+      AttributeReference("rightOther", IntegerType)())
+
+    def watermarkFrom(
+        conditionStr: String,
+        rightWatermark: Option[Long] = Some(10000)): Option[Long] = {
+      val conditionExpr = Some(conditionStr).map { str =>
+        val plan =
+          Filter(
+            spark.sessionState.sqlParser.parseExpression(str),
+            LogicalRDD(
+              attributesToFindConstraintFor ++ attributesWithWatermark,
+              spark.sparkContext.emptyRDD)(spark))
+        plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition
+      }
+      StreamingSymmetricHashJoinHelper.getStateValueWatermark(
+        AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark),
+        conditionExpr, rightWatermark)
+    }
+
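+    // All watermark values here are in milliseconds: rightWatermark defaults to Some(10000).
+    // E.g. "leftTime > rightTime" bounds any future match by leftTime > 10000 and yields
+    // Some(10000), while the non-strict "leftTime >= rightTime" yields Some(9999).
+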
+    // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark,
+    // then cannot define constraint on leftTime.
+    assert(watermarkFrom("leftTime > rightTime") === Some(10000))
+    assert(watermarkFrom("leftTime >= rightTime") === Some(9999))
+    assert(watermarkFrom("leftTime < rightTime") === None)
+    assert(watermarkFrom("leftTime <= rightTime") === None)
+    assert(watermarkFrom("rightTime > leftTime") === None)
+    assert(watermarkFrom("rightTime >= leftTime") === None)
+    assert(watermarkFrom("rightTime < leftTime") === Some(10000))
+    assert(watermarkFrom("rightTime <= leftTime") === Some(9999))
+
+    // Test type conversions
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") 
=== None)
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS 
DOUBLE)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS 
FLOAT)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS 
STRING)") === None)
+
+    // Test with timestamp type + calendar interval on either side of equation
+    // Note: TimestampType and calendar interval don't commute, so fewer valid combinations to test.
+    assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000))
+    assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000))
+    assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000))
+    assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000))
+    assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second")
+      === Some(12000))
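+    // (Each interval shifts the 10000 ms bound: rightTime + interval 1 second gives
+    // 10000 + 1000 = 11000, while leftTime + interval 2 seconds lowers it to 10000 - 2000 = 8000.)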
+
+    // Test with casted long type + constants on either side of equation
+    // Note: long type and constants commute, so more combinations to test.
+    // -- Constants on the right
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)")
+      === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1")
+      === Some(10100))
+    assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2")
+      === Some(10200))
+    // -- Constants on the left
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000))
+    assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000))
+    assert(watermarkFrom("CAST((leftTime  + interval 3 second) AS LONG) > CAST(rightTime AS LONG)")
+      === Some(7000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0")
+      === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0")
+      === Some(10100))
+    // -- Constants on both sides, mixed types
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1")
+      === Some(13000))
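+    // (After CAST to LONG the values are in seconds, so a constant 2 means 2 s = 2000 ms,
+    // e.g. 10000 + 2000 = 12000, and a fractional 0.5 shifts the bound by 500 ms.)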
+
+    // Test multiple conditions, should return minimum watermark
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") ===
+      Some(7000))  // min(7000, 8000): first condition wins
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") ===
+      Some(6000))  // min(7000, 6000): second condition wins
+
+    // Test invalid comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None)      
// non-time attributes
+    assert(watermarkFrom("leftOther > rightOther") === None)                  
// non-time attributes
+    assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") 
=== Some(10000))
+    assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None)  
// non-time attributes
+    assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) 
// month not allowed
+
+    // Test static comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000))
+  }
+
+  test("locality preferences of StateStoreAwareZippedRDD") {
+    import StreamingSymmetricHashJoinHelper._
+
+    withTempDir { tempDir =>
+      val queryId = UUID.randomUUID
+      val opId = 0
+      val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
+      val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
+
+      implicit val sqlContext = spark.sqlContext
+      val coordinatorRef = sqlContext.streams.stateStoreCoordinator
+      val numPartitions = 5
+      val storeNames = Seq("name1", "name2")
+
+      val partitionAndStoreNameToLocation = {
+        for (partIndex <- 0 until numPartitions; storeName <- storeNames) yield {
+          (partIndex, storeName) -> s"host-$partIndex-$storeName"
+        }
+      }.toMap
+      partitionAndStoreNameToLocation.foreach { case ((partIndex, storeName), hostName) =>
+        val providerId = StateStoreProviderId(stateInfo, partIndex, storeName)
+        coordinatorRef.reportActiveInstance(providerId, hostName, s"exec-$hostName")
+        require(
+          coordinatorRef.getLocation(providerId) ===
+            Some(ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString))
+      }
+
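+      // Every (partition, store name) pair is now registered with the coordinator, so the
+      // zipped RDD below should prefer, for each partition, the executors hosting its stores.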
+      val rdd1 = spark.sparkContext.makeRDD(1 to 10, numPartitions)
+      val rdd2 = spark.sparkContext.makeRDD((1 to 10).map(_.toString), numPartitions)
+      val rdd = rdd1.stateStoreAwareZipPartitions(rdd2, stateInfo, storeNames, coordinatorRef) {
+        (left, right) => left.zip(right)
+      }
+      require(rdd.partitions.length === numPartitions)
+      for (partIndex <- 0 until numPartitions) {
+        val expectedLocations = storeNames.map { storeName =>
+          val hostName = partitionAndStoreNameToLocation((partIndex, storeName))
+          ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString
+        }.toSet
+        assert(rdd.preferredLocations(rdd.partitions(partIndex)).toSet === expectedLocations)
+      }
+    }
+  }
+}

