Repository: spark
Updated Branches:
  refs/heads/master ce2b056d3 -> 067afb4e9


[SPARK-10699] [ML] Support checkpointInterval can be disabled

Currently users can set ```checkpointInterval``` to specify how often the cache 
should be checkpointed, but there is no way to turn checkpointing off. This PR 
lets users disable checkpointing by setting ```checkpointInterval = -1```.
We also document the GBT ```cacheNodeIds``` param so that users can understand 
the checkpointing behavior more clearly.
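
For illustration, a minimal usage sketch (assuming the standard ml setters; the 
```ratings``` DataFrame below is hypothetical):

```scala
import org.apache.spark.ml.recommendation.ALS

// checkpointInterval = -1 disables checkpointing entirely;
// any value >= 1 checkpoints the cache every that many iterations.
val als = new ALS()
  .setMaxIter(20)
  .setCheckpointInterval(-1)  // turn checkpointing off
// val model = als.fit(ratings)  // `ratings` is a hypothetical DataFrame of (user, item, rating)
```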

Author: Yanbo Liang <yblia...@gmail.com>

Closes #8820 from yanboliang/spark-10699.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/067afb4e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/067afb4e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/067afb4e

Branch: refs/heads/master
Commit: 067afb4e9bb227f159bcbc2aafafce9693303ea9
Parents: ce2b056
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Sep 23 16:41:42 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Sep 23 16:41:42 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/DecisionTreeClassifier.scala       | 1 -
 .../org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 6 +++---
 .../scala/org/apache/spark/ml/param/shared/sharedParams.scala  | 4 ++--
 .../main/scala/org/apache/spark/ml/recommendation/ALS.scala    | 2 +-
 .../main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 2 +-
 mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 4 ++--
 6 files changed, 9 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index a6f6d46..b0157f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.HasCheckpointInterval
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}

http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8049d51..8cb6b54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,9 +56,9 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + 
\"__output\"")),
-      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 
10 means that " +
-        "the cache will get checkpointed every 10 iterations.",
-        isValid = "ParamValidators.gtEq(1)"),
+      ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or 
" +
+        "disable checkpoint (-1). E.g. 10 means that the cache will get 
checkpointed " +
+        "every 10 iterations", isValid = "(interval: Int) => interval == -1 || 
interval >= 1"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", 
Some("true")),
       ParamDesc[String]("handleInvalid", "how to handle invalid entries. 
Options are skip (which " +
         "will filter out rows with bad values), or error (which will throw an 
errror). More " +

http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index aff47fc..e362521 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params {
 private[ml] trait HasCheckpointInterval extends Params {
 
   /**
-   * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations..
+   * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.
    * @group param
    */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1))
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1)
 
   /** @group getParam */
   final def getCheckpointInterval: Int = $(checkpointInterval)

http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 7db8ad8..9a56a75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -561,7 +561,7 @@ object ALS extends Logging {
     var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
     var previousCheckpointFile: Option[String] = None
     val shouldCheckpoint: Int => Boolean = (iter) =>
-      sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+      sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
     val deletePreviousCheckpointFile: () => Unit = () =>
       previousCheckpointFile.foreach { file =>
         try {

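Why the extra ```checkpointInterval != -1``` guard matters: on the JVM ```iter % -1 == 0``` 
for every ```iter```, so without the guard a setting of -1 would checkpoint on every 
iteration rather than never. A standalone sketch of the guard (hypothetical parameter 
names, not the ALS internals):

```scala
// `checkpointDirDefined` stands in for sc.checkpointDir.isDefined in the real code.
def shouldCheckpoint(checkpointDirDefined: Boolean, checkpointInterval: Int)(iter: Int): Boolean =
  checkpointDirDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)

assert(!shouldCheckpoint(checkpointDirDefined = true, checkpointInterval = -1)(10))  // disabled
assert(shouldCheckpoint(checkpointDirDefined = true, checkpointInterval = 5)(10))    // every 5 iterations
assert(!shouldCheckpoint(checkpointDirDefined = false, checkpointInterval = 5)(10))  // no checkpoint dir set
```
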
http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index c5ad8df..1ee0113 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -122,7 +122,7 @@ private[spark] class NodeIdCache(
     rddUpdateCount += 1
 
     // Handle checkpointing if the directory is not None.
-    if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
+    if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % 
checkpointInterval) == 0) {
       // Let's see if we can delete previous checkpoints.
       var canDelete = true
       while (checkpointQueue.size > 1 && canDelete) {

http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 42e74ce..281ba6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.tree
 
-import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -87,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI
   /**
    * If false, the algorithm will pass trees to executors to match instances with nodes.
    * If true, the algorithm will cache node IDs for each instance.
-   * Caching can speed up training of deeper trees.
+   * Caching can speed up training of deeper trees. Users can set how often should the
+   * cache be checkpointed or disable it by setting checkpointInterval.
    * (default = false)
    * @group expertParam
    */

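Tying the ```cacheNodeIds``` doc above to the new behavior, a hedged usage sketch for a 
tree ensemble (standard ml setters assumed; the ```trainingData``` DataFrame is hypothetical):

```scala
import org.apache.spark.ml.classification.GBTClassifier

// Cache node IDs per instance to speed up training of deeper trees; with
// checkpointInterval = -1 the cached node-ID RDD is only cached, never checkpointed.
val gbt = new GBTClassifier()
  .setMaxDepth(8)
  .setCacheNodeIds(true)
  .setCheckpointInterval(-1)  // disable periodic checkpointing of the node ID cache
// val model = gbt.fit(trainingData)  // `trainingData` is a hypothetical labeled DataFrame
```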

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
