Repository: spark
Updated Branches:
  refs/heads/master 5a3533e77 -> 00c310133


[SPARK-15593][SQL] Add DataFrameWriter.foreach to allow the user consuming data 
in ContinuousQuery

## What changes were proposed in this pull request?

* Add DataFrameWriter.foreach to allow the user to consume data in 
ContinuousQuery
  * ForeachWriter is the interface for the user to consume partitions of data
* Add a type parameter T to DataFrameWriter

Usage
```Scala
val ds = spark.read....stream().as[String]
ds.....write
         .queryName(...)
        .option("checkpointLocation", ...)
        .foreach(new ForeachWriter[String] {
          def open(partitionId: Long, version: Long): Boolean = {
             // prepare some resources for a partition
             // check `version` if possible and return `false` if this is a 
duplicated data to skip the data processing.
          }

          override def process(value: String): Unit = {
              // process data
          }

          def close(errorOrNull: Throwable): Unit = {
             // release resources for a partition
             // check `errorOrNull` and handle the error if necessary.
          }
        })
```

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #13342 from zsxwing/foreach.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/00c31013
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/00c31013
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/00c31013

Branch: refs/heads/master
Commit: 00c310133df4f3893dd90d801168c2ab9841b102
Parents: 5a3533e
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Fri Jun 10 00:11:46 2016 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Fri Jun 10 00:11:46 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameWriter.scala  | 150 ++++++++++++++-----
 .../scala/org/apache/spark/sql/Dataset.scala    |   2 +-
 .../org/apache/spark/sql/ForeachWriter.scala    | 105 +++++++++++++
 .../sql/execution/streaming/ForeachSink.scala   |  53 +++++++
 .../execution/streaming/ForeachSinkSuite.scala  | 141 +++++++++++++++++
 .../spark/sql/sources/BucketedReadSuite.scala   |   4 +-
 6 files changed, 413 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1dd8818..32e2fdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
 import org.apache.spark.sql.execution.datasources.{BucketSpec, 
CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, 
StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, 
ProcessingTime, Trigger}
 import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
  *
  * @since 1.4.0
  */
-final class DataFrameWriter private[sql](df: DataFrame) {
+final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
+
+  private val df = ds.toDF()
 
   /**
    * Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: SaveMode): DataFrameWriter = {
+  def mode(saveMode: SaveMode): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: String): DataFrameWriter = {
+  def mode(saveMode: String): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: OutputMode): DataFrameWriter = {
+  def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode
     this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: String): DataFrameWriter = {
+  def outputMode(outputMode: String): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode.toLowerCase match {
       case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def trigger(trigger: Trigger): DataFrameWriter = {
+  def trigger(trigger: Trigger): DataFrameWriter[T] = {
     assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def format(source: String): DataFrameWriter = {
+  def format(source: String): DataFrameWriter[T] = {
     this.source = source
     this
   }
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def option(key: String, value: String): DataFrameWriter = {
+  def option(key: String, value: String): DataFrameWriter[T] = {
     this.extraOptions += (key -> value)
     this
   }
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Boolean): DataFrameWriter = option(key, 
value.toString)
+  def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, 
value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Long): DataFrameWriter = option(key, 
value.toString)
+  def option(key: String, value: Long): DataFrameWriter[T] = option(key, 
value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Double): DataFrameWriter = option(key, 
value.toString)
+  def option(key: String, value: Double): DataFrameWriter[T] = option(key, 
value.toString)
 
   /**
    * (Scala-specific) Adds output options for the underlying data source.
    *
    * @since 1.4.0
    */
-  def options(options: scala.collection.Map[String, String]): DataFrameWriter 
= {
+  def options(options: scala.collection.Map[String, String]): 
DataFrameWriter[T] = {
     this.extraOptions ++= options
     this
   }
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def options(options: java.util.Map[String, String]): DataFrameWriter = {
+  def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
     this.options(options.asScala)
     this
   }
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def partitionBy(colNames: String*): DataFrameWriter = {
+  def partitionBy(colNames: String*): DataFrameWriter[T] = {
     this.partitioningColumns = Option(colNames)
     this
   }
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def bucketBy(numBuckets: Int, colName: String, colNames: String*): 
DataFrameWriter = {
+  def bucketBy(numBuckets: Int, colName: String, colNames: String*): 
DataFrameWriter[T] = {
     this.numBuckets = Option(numBuckets)
     this.bucketColumnNames = Option(colName +: colNames)
     this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+  def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
     this.sortColumnNames = Option(colName +: colNames)
     this
   }
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def queryName(queryName: String): DataFrameWriter = {
+  def queryName(queryName: String): DataFrameWriter[T] = {
     assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       val queryName =
         extraOptions.getOrElse(
           "queryName", throw new AnalysisException("queryName must be 
specified for memory sink"))
-      val checkpointLocation = extraOptions.get("checkpointLocation").map { 
userSpecified =>
-        new Path(userSpecified).toUri.toString
-      }.orElse {
-        val checkpointConfig: Option[String] =
-          df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
-
-        checkpointConfig.map { location =>
-          new Path(location, queryName).toUri.toString
-        }
-      }.getOrElse {
+      val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = 
false).getOrElse {
         Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
       }
 
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
           className = source,
           options = extraOptions.toMap,
           partitionColumns = normalizedParCols.getOrElse(Nil))
-
       val queryName = extraOptions.getOrElse("queryName", 
StreamExecution.nextName)
-      val checkpointLocation = extraOptions.get("checkpointLocation")
-        .orElse {
-          df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
-            new Path(l, queryName).toUri.toString
-          }
-        }.getOrElse {
-          throw new AnalysisException("checkpointLocation must be specified 
either " +
-            "through option() or SQLConf")
-        }
-
       df.sparkSession.sessionState.continuousQueryManager.startQuery(
         queryName,
-        checkpointLocation,
+        getCheckpointLocation(queryName, failIfNotSet = true).get,
         df,
         dataSource.createSink(outputMode),
         outputMode,
@@ -401,6 +383,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
 
   /**
+   * :: Experimental ::
+   * Starts the execution of the streaming query, which will continually send 
results to the given
+   * [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be 
used to send the data
+   * generated by the [[DataFrame]]/[[Dataset]] to an external system. The 
returned
+   * [[ContinuousQuery]] object can be used to interact with the stream.
+   *
+   * Scala example:
+   * {{{
+   *   datasetOfString.write.foreach(new ForeachWriter[String] {
+   *
+   *     def open(partitionId: Long, version: Long): Boolean = {
+   *       // open connection
+   *     }
+   *
+   *     def process(record: String) = {
+   *       // write string to connection
+   *     }
+   *
+   *     def close(errorOrNull: Throwable): Unit = {
+   *       // close the connection
+   *     }
+   *   })
+   * }}}
+   *
+   * Java example:
+   * {{{
+   *  datasetOfString.write().foreach(new ForeachWriter<String>() {
+   *
+   *    @Override
+   *    public boolean open(long partitionId, long version) {
+   *      // open connection
+   *    }
+   *
+   *    @Override
+   *    public void process(String value) {
+   *      // write string to connection
+   *    }
+   *
+   *    @Override
+   *    public void close(Throwable errorOrNull) {
+   *      // close the connection
+   *    }
+   *  });
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
+    assertNotBucketed("foreach")
+    assertStreaming(
+      "foreach() can only be called on streaming Datasets/DataFrames.")
+
+    val queryName = extraOptions.getOrElse("queryName", 
StreamExecution.nextName)
+    val sink = new 
ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
+    df.sparkSession.sessionState.continuousQueryManager.startQuery(
+      queryName,
+      getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
+        Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
+      },
+      df,
+      sink,
+      outputMode,
+      trigger)
+  }
+
+  /**
+   * Returns the checkpointLocation for a query. If `failIfNotSet` is `true` 
but the checkpoint
+   * location is not set, [[AnalysisException]] will be thrown. If 
`failIfNotSet` is `false`, `None`
+   * will be returned if the checkpoint location is not set.
+   */
+  private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): 
Option[String] = {
+    val checkpointLocation = extraOptions.get("checkpointLocation").map { 
userSpecified =>
+      new Path(userSpecified).toUri.toString
+    }.orElse {
+      df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
+        new Path(location, queryName).toUri.toString
+      }
+    }
+    if (failIfNotSet && checkpointLocation.isEmpty) {
+      throw new AnalysisException("checkpointLocation must be specified either 
" +
+        """through option("checkpointLocation", ...) or """ +
+        s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", 
...)""")
+    }
+    checkpointLocation
+  }
+
+  /**
    * Inserts the content of the [[DataFrame]] to the specified table. It 
requires that
    * the schema of the [[DataFrame]] is the same as the schema of the table.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 162524a..16bbf30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def write: DataFrameWriter = new DataFrameWriter(toDF())
+  def write: DataFrameWriter[T] = new DataFrameWriter[T](this)
 
   /**
    * Returns the content of the Dataset as a Dataset of JSON strings.

http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
new file mode 100644
index 0000000..09f0742
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.streaming.ContinuousQuery
+
+/**
+ * :: Experimental ::
+ * A class to consume data generated by a [[ContinuousQuery]]. Typically this 
is used to send the
+ * generated data to external systems. Each partition will use a new 
deserialized instance, so you
+ * usually should do all the initialization (e.g. opening a connection or 
initiating a transaction)
+ * in the `open` method.
+ *
+ * Scala example:
+ * {{{
+ *   datasetOfString.write.foreach(new ForeachWriter[String] {
+ *
+ *     def open(partitionId: Long, version: Long): Boolean = {
+ *       // open connection
+ *     }
+ *
+ *     def process(record: String) = {
+ *       // write string to connection
+ *     }
+ *
+ *     def close(errorOrNull: Throwable): Unit = {
+ *       // close the connection
+ *     }
+ *   })
+ * }}}
+ *
+ * Java example:
+ * {{{
+ *  datasetOfString.write().foreach(new ForeachWriter<String>() {
+ *
+ *    @Override
+ *    public boolean open(long partitionId, long version) {
+ *      // open connection
+ *    }
+ *
+ *    @Override
+ *    public void process(String value) {
+ *      // write string to connection
+ *    }
+ *
+ *    @Override
+ *    public void close(Throwable errorOrNull) {
+ *      // close the connection
+ *    }
+ *  });
+ * }}}
+ * @since 2.0.0
+ */
+@Experimental
+abstract class ForeachWriter[T] extends Serializable {
+
+  /**
+   * Called when starting to process one partition of new data in the 
executor. The `version` is
+   * for data deduplication when there are failures. When recovering from a 
failure, some data may
+   * be generated multiple times but they will always have the same version.
+   *
+   * If this method finds using the `partitionId` and `version` that this 
partition has already been
+   * processed, it can return `false` to skip the further data processing. 
However, `close` still
+   * will be called for cleaning up resources.
+   *
+   * @param partitionId the partition id.
+   * @param version a unique id for data deduplication.
+   * @return `true` if the corresponding partition and version id should be 
processed. `false`
+   *         indicates the partition should be skipped.
+   */
+  def open(partitionId: Long, version: Long): Boolean
+
+  /**
+   * Called to process the data in the executor side. This method will be 
called only when `open`
+   * returns `true`.
+   */
+  def process(value: T): Unit
+
+  /**
+   * Called when stopping to process one partition of new data in the executor 
side. This is
+   * guaranteed to be called whether `open` returns `true` or `false`. However,
+   * `close` won't be called in the following cases:
+   *  - JVM crashes without throwing a `Throwable`
+   *  - `open` throws a `Throwable`.
+   *
+   * @param errorOrNull the error thrown during processing data or null if 
there was no error.
+   */
+  def close(errorOrNull: Throwable): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
new file mode 100644
index 0000000..14b9b1c
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+
+/**
+ * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the 
contract defined by
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with 
Serializable {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    data.as[T].foreachPartition { iter =>
+      if (writer.open(TaskContext.getPartitionId(), batchId)) {
+        var isFailed = false
+        try {
+          while (iter.hasNext) {
+            writer.process(iter.next())
+          }
+        } catch {
+          case e: Throwable =>
+            isFailed = true
+            writer.close(e)
+        }
+        if (!isFailed) {
+          writer.close(null)
+        }
+      } else {
+        writer.close(null)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
new file mode 100644
index 0000000..e1fb3b9
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ForeachSinkSuite extends StreamTest with SharedSQLContext with 
BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  test("foreach") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(2).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter())
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val expectedEventsForPartition0 = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      val expectedEventsForPartition1 = Seq(
+        ForeachSinkSuite.Open(partition = 1, version = 0),
+        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 4),
+        ForeachSinkSuite.Close(None)
+      )
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 2)
+      assert {
+        allEvents === Seq(expectedEventsForPartition0, 
expectedEventsForPartition1) ||
+          allEvents === Seq(expectedEventsForPartition1, 
expectedEventsForPartition0)
+      }
+      query.stop()
+    }
+  }
+
+  test("foreach with error") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(1).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter() {
+          override def process(value: Int): Unit = {
+            super.process(value)
+            throw new RuntimeException("error")
+          }
+        })
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version 
= 0))
+      assert(allEvents(0)(1) ===  ForeachSinkSuite.Process(value = 1))
+      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
+      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
+      assert(errorEvent.error.get.getMessage === "error")
+      query.stop()
+    }
+  }
+}
+
+/** A global object to collect events in the executor */
+object ForeachSinkSuite {
+
+  trait Event
+
+  case class Open(partition: Long, version: Long) extends Event
+
+  case class Process[T](value: T) extends Event
+
+  case class Close(error: Option[Throwable]) extends Event
+
+  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
+
+  def addEvents(events: Seq[Event]): Unit = {
+    _allEvents.add(events)
+  }
+
+  def allEvents(): Seq[Seq[Event]] = {
+    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
+  }
+
+  def clear(): Unit = {
+    _allEvents.clear()
+  }
+}
+
+/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
+class TestForeachWriter extends ForeachWriter[Int] {
+  ForeachSinkSuite.clear()
+
+  private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
+
+  override def open(partitionId: Long, version: Long): Boolean = {
+    events += ForeachSinkSuite.Open(partition = partitionId, version = version)
+    true
+  }
+
+  override def process(value: Int): Unit = {
+    events += ForeachSinkSuite.Process(value)
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    events += ForeachSinkSuite.Close(error = Option(errorOrNull))
+    ForeachSinkSuite.addEvents(events)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/00c31013/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index bab0092..fc01ff3 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils 
with TestHiveSinglet
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): 
DataFrameWriter = {
+      def withBucket(
+          writer: DataFrameWriter[Row],
+          bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
         bucketSpec.map { spec =>
           writer.bucketBy(
             spec.numBuckets,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to