Repository: spark
Updated Branches:
  refs/heads/master 6e36d8d56 -> 13268a58f


[SPARK-22649][PYTHON][SQL] Adding localCheckpoint to Dataset API

## What changes were proposed in this pull request?

This change adds local checkpoint support to Datasets and the corresponding binding in the 
Python DataFrame API.

When reliability requirements can be relaxed in favor of performance, as in cases where a few 
quick transformations are followed by a reliable save, localCheckpoint() fits very well.
Furthermore, reliable checkpoints currently still incur double computation (see #9428).
In general, it also makes the API more complete.
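As an illustration of that pattern, here is a minimal sketch assuming an existing `spark` 
session (the column names and output path are purely illustrative): truncate a grown plan with 
the new `localCheckpoint()`, apply a few more quick transformations, then persist the result 
with a reliable save.

```python
from pyspark.sql import functions as F

# Assumed: an existing SparkSession bound to `spark`.
df = spark.range(1000).withColumn("doubled", F.col("id") * 2)

# Truncate the lineage locally; data is kept on the executors via the
# caching subsystem, so this is fast but not fault-tolerant.
df_cp = df.localCheckpoint()  # eager=True by default

# Quick follow-up transformations on the truncated plan...
result = df_cp.filter(F.col("doubled") % 4 == 0)

# ...followed by a reliable save to durable storage (path is illustrative).
result.write.mode("overwrite").parquet("/tmp/local_checkpoint_demo")
```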

## How was this patch tested?

A quick use case from the Python side:

```python
>>> from time import sleep
>>> from pyspark.sql import types as T
>>> from pyspark.sql import functions as F

>>> def f(x):
...     sleep(1)
...     return x*2

>>> df1 = spark.range(30, numPartitions=6)
>>> df2 = df1.select(F.udf(f, T.LongType())("id"))

>>> %time _ = df2.collect()
CPU times: user 7.79 ms, sys: 5.84 ms, total: 13.6 ms
Wall time: 12.2 s

>>> %time df3 = df2.localCheckpoint()
CPU times: user 2.38 ms, sys: 2.3 ms, total: 4.68 ms
Wall time: 10.3 s

>>> %time _ = df3.collect()
CPU times: user 5.09 ms, sys: 410 µs, total: 5.5 ms
Wall time: 148 ms

>>> sc.setCheckpointDir(".")
>>> %time df3 = df2.checkpoint()
CPU times: user 4.04 ms, sys: 1.63 ms, total: 5.67 ms
Wall time: 20.3 s
```
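The Python binding also exposes the `eager` flag from the Scala API. A small sketch of the lazy 
form, continuing the session above (nothing is materialized until the first action):

```python
>>> df3 = df2.localCheckpoint(eager=False)  # plan is marked for checkpointing; nothing runs yet
>>> df3.count()                             # the first action materializes the local checkpoint
30
>>> df3.count()                             # later actions reuse the checkpointed data
30
```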

Author: Fernando Pereira <fernando.pere...@epfl.ch>

Closes #19805 from ferdonline/feature_dataset_localCheckpoint.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/13268a58
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/13268a58
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/13268a58

Branch: refs/heads/master
Commit: 13268a58f8f67bf994f0ad5076419774c45daeeb
Parents: 6e36d8d
Author: Fernando Pereira <fernando.pere...@epfl.ch>
Authored: Tue Dec 19 20:47:12 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 19 20:47:12 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 |  14 +++
 .../scala/org/apache/spark/sql/Dataset.scala    |  49 ++++++++-
 .../org/apache/spark/sql/DatasetSuite.scala     | 107 +++++++++++--------
 3 files changed, 121 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/13268a58/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9864dc9..75395a7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -368,6 +368,20 @@ class DataFrame(object):
         jdf = self._jdf.checkpoint(eager)
         return DataFrame(jdf, self.sql_ctx)
 
+    @since(2.3)
+    def localCheckpoint(self, eager=True):
+        """Returns a locally checkpointed version of this Dataset. 
Checkpointing can be used to
+        truncate the logical plan of this DataFrame, which is especially 
useful in iterative
+        algorithms where the plan may grow exponentially. Local checkpoints 
are stored in the
+        executors using the caching subsystem and therefore they are not 
reliable.
+
+        :param eager: Whether to checkpoint this DataFrame immediately
+
+        .. note:: Experimental
+        """
+        jdf = self._jdf.localCheckpoint(eager)
+        return DataFrame(jdf, self.sql_ctx)
+
     @since(2.1)
     def withWatermark(self, eventTime, delayThreshold):
         """Defines an event time watermark for this :class:`DataFrame`. A 
watermark tracks a point

http://git-wip-us.apache.org/repos/asf/spark/blob/13268a58/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c34cf0a..ef00562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -527,7 +527,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+  def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
 
   /**
    * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -540,9 +540,52 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(eager: Boolean): Dataset[T] = {
+  def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
+
+  /**
+   * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be
+   * used to truncate the logical plan of this Dataset, which is especially useful in iterative
+   * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+   * storage and despite potentially faster they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+  /**
+   * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to truncate
+   * the logical plan of this Dataset, which is especially useful in iterative algorithms where the
+   * plan may grow exponentially. Local checkpoints are written to executor storage and despite
+   * potentially faster they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
+    eager = eager,
+    reliableCheckpoint = false
+  )
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager Whether to checkpoint this dataframe immediately
+   * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
+   *                           checkpoint directory. If false creates a local checkpoint using
+   *                           the caching subsystem
+   */
+  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
     val internalRdd = queryExecution.toRdd.map(_.copy())
-    internalRdd.checkpoint()
+    if (reliableCheckpoint) {
+      internalRdd.checkpoint()
+    } else {
+      internalRdd.localCheckpoint()
+    }
 
     if (eager) {
       internalRdd.count()

http://git-wip-us.apache.org/repos/asf/spark/blob/13268a58/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b02db77..bd1e7ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1156,67 +1156,82 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   Seq(true, false).foreach { eager =>
-    def testCheckpointing(testName: String)(f: => Unit): Unit = {
-      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
-        withTempDir { dir =>
-          val originalCheckpointDir = spark.sparkContext.checkpointDir
-
-          try {
-            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+    Seq(true, false).foreach { reliable =>
+      def testCheckpointing(testName: String)(f: => Unit): Unit = {
+        test(s"Dataset.checkpoint() - $testName (eager = $eager, reliable = 
$reliable)") {
+          if (reliable) {
+            withTempDir { dir =>
+              val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+              try {
+                spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+                f
+              } finally {
+                // Since the original checkpointDir can be None, we need
+                // to set the variable directly.
+                spark.sparkContext.checkpointDir = originalCheckpointDir
+              }
+            }
+          } else {
+            // Local checkpoints dont require checkpoint_dir
             f
-          } finally {
-            // Since the original checkpointDir can be None, we need
-            // to set the variable directly.
-            spark.sparkContext.checkpointDir = originalCheckpointDir
           }
         }
       }
-    }
 
-    testCheckpointing("basic") {
-      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("basic") {
+        val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val logicalRDD = cp.logicalPlan match {
-        case plan: LogicalRDD => plan
-        case _ =>
-          val treeString = cp.logicalPlan.treeString(verbose = true)
-          fail(s"Expecting a LogicalRDD, but got\n$treeString")
-      }
+        val logicalRDD = cp.logicalPlan match {
+          case plan: LogicalRDD => plan
+          case _ =>
+            val treeString = cp.logicalPlan.treeString(verbose = true)
+            fail(s"Expecting a LogicalRDD, but got\n$treeString")
+        }
 
-      val dsPhysicalPlan = ds.queryExecution.executedPlan
-      val cpPhysicalPlan = cp.queryExecution.executedPlan
+        val dsPhysicalPlan = ds.queryExecution.executedPlan
+        val cpPhysicalPlan = cp.queryExecution.executedPlan
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          logicalRDD.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          logicalRDD.outputOrdering
+        }
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          cpPhysicalPlan.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          cpPhysicalPlan.outputOrdering
+        }
 
-      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+        // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
 
-      // Reads back from checkpointed data and check again.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
-    }
+        // Reads back from checkpointed data and check again.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+      }
 
-    testCheckpointing("should preserve partitioning information") {
-      val ds = spark.range(10).repartition('id % 2)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("should preserve partitioning information") {
+        val ds = spark.range(10).repartition('id % 2)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val agg = cp.groupBy('id % 2).agg(count('id))
+        val agg = cp.groupBy('id % 2).agg(count('id))
 
-      agg.queryExecution.executedPlan.collectFirst {
-        case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
-        case BroadcastExchangeExec(_, _: RDDScanExec) =>
-      }.foreach { _ =>
-        fail(
-          "No Exchange should be inserted above RDDScanExec since the 
checkpointed Dataset " +
-            "preserves partitioning information:\n\n" + agg.queryExecution
-        )
-      }
+        agg.queryExecution.executedPlan.collectFirst {
+          case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
+          case BroadcastExchangeExec(_, _: RDDScanExec) =>
+        }.foreach { _ =>
+          fail(
+            "No Exchange should be inserted above RDDScanExec since the 
checkpointed Dataset " +
+              "preserves partitioning information:\n\n" + agg.queryExecution
+          )
+        }
 
-      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+        checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+      }
     }
   }
 

