Repository: spark
Updated Branches:
  refs/heads/master 0b7d4966c -> 0fc4aaa71


http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 85db051..6be94eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with 
BeforeAndAfterAll {
 
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   }
 
   test("versioning and immutability") {
-    quietly {
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
-        val path = Utils.createDirectory(tempDir, 
Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
-        val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      val sqlContext = new SQLContext(sc)
+      val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
+      val rdd1 =
+        makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+            sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+            increment)
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+      // Generate next version of stores
+      val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(increment)
+      assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+      // Make sure the previous RDD still has the same data.
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    }
+  }
 
-        // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
-        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+  test("recovering from files") {
+    val opId = 0
+    val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+    def makeStoreRDD(
+        sc: SparkContext,
+        seq: Seq[String],
+        storeVersion: Int): RDD[(String, Int)] = {
+      implicit val sqlContext = new SQLContext(sc)
+      makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion, keySchema, 
valueSchema)(increment)
+    }
 
-        // Make sure the previous RDD still has the same data.
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    // Generate RDDs and state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      for (i <- 1 to 20) {
+        require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" 
-> i))
       }
     }
+
+    // With a new context, try using the earlier state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+    }
   }
 
-  test("recovering from files") {
-    quietly {
-      val opId = 0
+  test("usage with iterators - only gets and only puts") {
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      implicit val sqlContext = new SQLContext(sc)
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
 
-      def makeStoreRDD(
-          sc: SparkContext,
-          seq: Seq[String],
-          storeVersion: Int): RDD[(String, Int)] = {
-        implicit val sqlContext = new SQLContext(sc)
-        makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion, keySchema, valueSchema)
+      // Returns an iterator of the incremented value made into the store
+      def iteratorOfPuts(store: StateStore, iter: Iterator[String]): 
Iterator[(String, Int)] = {
+        val resIterator = iter.map { s =>
+          val key = stringToRow(s)
+          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val newValue = oldValue + 1
+          store.put(key, intToRow(newValue))
+          (s, newValue)
+        }
+        CompletionIterator[(String, Int), Iterator[(String, 
Int)]](resIterator, {
+          store.commit()
+        })
       }
 
-      // Generate RDDs and state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        for (i <- 1 to 20) {
-          require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === 
Set("a" -> i))
+      def iteratorOfGets(
+          store: StateStore,
+          iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+        iter.map { s =>
+          val key = stringToRow(s)
+          val value = store.get(key).map(rowToInt)
+          (s, value)
         }
       }
 
-      // With a new context, try using the earlier state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 
21))
-      }
+      val rddOfGets1 = makeRDD(sc, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(iteratorOfGets)
+      assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" 
-> None))
+
+      val rddOfPuts = makeRDD(sc, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(iteratorOfPuts)
+      assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+      val rddOfGets2 = makeRDD(sc, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(iteratorOfGets)
+      assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> 
Some(1), "c" -> None))
     }
   }
 
@@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
           coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
-        val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   test("distributed test") {
     quietly {
       withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 
1024]"))) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
+        implicit val sqlContext = new SQLContext(sc)
         val path = Utils.createDirectory(tempDir, 
Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
         val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
-      store.update(
-        stringToRow(s), oldRow => {
-          val oldValue = oldRow.map(rowToInt).getOrElse(0)
-          intToRow(oldValue + 1)
-        })
+      val key = stringToRow(s)
+      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      store.put(key, intToRow(oldValue + 1))
     }
     store.commit()
     store.iterator().map(rowsToStringInt)

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f..0e5936d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     StateStore.stop()
   }
 
-  test("update, remove, commit, and all data iterator") {
+  test("get, put, remove, commit, and all data iterator") {
     val provider = newStoreProvider()
 
     // Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     // Verify state after updating
-    update(store, "a", 1)
+    put(store, "a", 1)
     intercept[IllegalStateException] {
       store.iterator()
     }
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(provider.latestIterator().isEmpty)
 
     // Make updates, commit and then verify state
-    update(store, "b", 2)
-    update(store, "aa", 3)
+    put(store, "b", 2)
+    put(store, "aa", 3)
     remove(store, _.startsWith("a"))
     assert(store.commit() === 1)
 
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val reloadedProvider = new HDFSBackedStateStoreProvider(
       store.id, keySchema, valueSchema, StateStoreConf.empty, new 
Configuration)
     val reloadedStore = reloadedProvider.getStore(1)
-    update(reloadedStore, "c", 4)
+    put(reloadedStore, "c", 4)
     assert(reloadedStore.commit() === 2)
     assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
     assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("updates iterator with all combos of updates and removes") {
     val provider = newStoreProvider()
     var currentVersion: Int = 0
+
     def withStore(body: StateStore => Unit): Unit = {
       val store = provider.getStore(currentVersion)
       body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // New data should be seen in updates as value added, even if they had 
multiple updates
     withStore { store =>
-      update(store, "a", 1)
-      update(store, "aa", 1)
-      update(store, "aa", 2)
+      put(store, "a", 1)
+      put(store, "aa", 1)
+      put(store, "aa", 2)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 
2)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Multiple updates to same key should be collapsed in the updates as a 
single value update
     // Keys that have not been updated should not appear in the updates
     withStore { store =>
-      update(store, "a", 4)
-      update(store, "a", 6)
+      put(store, "a", 4)
+      put(store, "a", 6)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Keys added, updated and finally removed before commit should not appear 
in updates
     withStore { store =>
-      update(store, "b", 4)     // Added, finally removed
-      update(store, "bb", 5)    // Added, updated, finally removed
-      update(store, "bb", 6)
+      put(store, "b", 4)     // Added, finally removed
+      put(store, "bb", 5)    // Added, updated, finally removed
+      put(store, "bb", 6)
       remove(store, _.startsWith("b"))
       store.commit()
       assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Removed, but re-added data should be seen in updates as a value update
     withStore { store =>
       remove(store, _.startsWith("a"))
-      update(store, "a", 10)
+      put(store, "a", 10)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 10), 
Removed("aa")))
       assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("cancel") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     store.commit()
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
     // cancelUpdates should not change the data in the files
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
-    store1.cancel()
+    put(store1, "b", 1)
+    store1.abort()
     assert(getDataFromFiles(provider) === Set("a" -> 1))
   }
 
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Prepare some data in the stoer
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     assert(store.commit() === 1)
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Update store version with some data
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
+    put(store1, "b", 1)
     assert(store1.commit() === 2)
     assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
 
     // Overwrite the version with other data
     val store2 = provider.getStore(1)
-    update(store2, "c", 1)
+    put(store2, "c", 1)
     assert(store2.commit() === 2)
     assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     def updateVersionTo(targetVersion: Int): Unit = {
       for (i <- currentVersion + 1 to targetVersion) {
         val store = provider.getStore(currentVersion)
-        update(store, "a", i)
+        put(store, "a", i)
         store.commit()
         currentVersion += 1
       }
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val provider = newStoreProvider(minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       // Increase version of the store
       val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, 
storeConf, hadoopConf)
       assert(store0.version === 0)
-      update(store0, "a", 1)
+      put(store0, "a", 1)
       store0.commit()
 
       assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, 
hadoopConf).version == 1)
@@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
       val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, 
storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      update(store1, "a", 2)
+      put(store1, "a", 2)
       assert(store1.commit() === 2)
       assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
     }
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           for (i <- 1 to 20) {
             val store = StateStore.get(
               storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
-            update(store, "a", i)
+            put(store, "a", i)
             store.commit()
           }
           eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     store.remove(row => condition(rowToString(row)))
   }
 
-  private def update(store: StateStore, key: String, value: Int): Unit = {
-    store.update(stringToRow(key), _ => intToRow(value))
+  private def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
+  }
+
+  private def get(store: StateStore, key: String): Option[Int] = {
+    store.get(stringToRow(key)).map(rowToInt)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000..b63ce89
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+  var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("simple count") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 3),
+      CheckLastBatch((3, 1)),
+      AddData(inputData, 3, 2),
+      CheckLastBatch((3, 2), (2, 1)),
+      StopStream,
+      StartStream,
+      AddData(inputData, 3, 2, 1),
+      CheckLastBatch((3, 3), (2, 2), (1, 1)),
+      // By default we run in new tuple mode.
+      AddData(inputData, 4, 4, 4, 4),
+      CheckLastBatch((4, 4))
+    )
+  }
+
+  test("multiple keys") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value", $"value" + 1)
+        .agg(count("*"))
+        .as[(Int, Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 1), (2, 3, 1)),
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 2), (2, 3, 2))
+    )
+  }
+
+  test("multiple aggregations") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*") as 'count)
+        .groupBy($"value" % 2)
+        .agg(sum($"count"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2, 3, 4),
+      CheckLastBatch((0, 2), (1, 2)),
+      AddData(inputData, 1, 3, 5),
+      CheckLastBatch((1, 5))
+    )
+  }
+
+  testQuietly("midbatch failure") {
+    val inputData = MemoryStream[Int]
+    FailureSinglton.firstTime = true
+    val aggregated =
+      inputData.toDS()
+          .map { i =>
+            if (i == 4 && FailureSinglton.firstTime) {
+              FailureSinglton.firstTime = false
+              sys.error("injected failure")
+            }
+
+            i
+          }
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+    testStream(aggregated)(
+      StartStream,
+      AddData(inputData, 1, 2, 3, 4),
+      ExpectFailure[SparkException](),
+      StartStream,
+      CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+    )
+  }
+
+  test("typed aggregators") {
+    def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+      new SumOf(f).toColumn
+
+    val inputData = MemoryStream[(String, Int)]
+    val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2))
+
+    testStream(aggregated)(
+      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 2bdb428..ff40c36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
   /**
    * Planner that takes into account Hive-specific strategies.
    */
-  override lazy val planner: SparkPlanner = {
-    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with 
HiveStrategies {
+  override def planner: SparkPlanner = {
+    new SparkPlanner(ctx.sparkContext, conf, 
experimentalMethods.extraStrategies)
+      with HiveStrategies {
       override val hiveContext = ctx
 
       override def strategies: Seq[Strategy] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to