Repository: spark Updated Branches: refs/heads/master ce2b056d3 -> 067afb4e9
[SPARK-10699] [ML] Support checkpointInterval can be disabled Currently users can set ```checkpointInterval``` to specify how often the cache should be checkpointed. But users also need the ability to disable it. This PR lets users disable checkpointing by setting ```checkpointInterval = -1```. We also add documentation for the GBT ```cacheNodeIds``` param so that users can understand checkpointing more clearly. Author: Yanbo Liang <yblia...@gmail.com> Closes #8820 from yanboliang/spark-10699. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/067afb4e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/067afb4e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/067afb4e Branch: refs/heads/master Commit: 067afb4e9bb227f159bcbc2aafafce9693303ea9 Parents: ce2b056 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Sep 23 16:41:42 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Sep 23 16:41:42 2015 -0700 ---------------------------------------------------------------------- .../spark/ml/classification/DecisionTreeClassifier.scala | 1 - .../org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 6 +++--- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- .../main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 4 ++-- 6 files changed, 9 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index a6f6d46..b0157f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.HasCheckpointInterval import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8049d51..8cb6b54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -56,9 +56,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + - "the cache will get checkpointed every 10 iterations.", - isValid = "ParamValidators.gtEq(1)"), + ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + + "disable checkpoint (-1). E.g. 
10 means that the cache will get checkpointed " + + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an errror). More " + http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index aff47fc..e362521 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params { private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 
10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 7db8ad8..9a56a75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -561,7 +561,7 @@ object ALS extends Logging { var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => - sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) val deletePreviousCheckpointFile: () => Unit = () => previousCheckpointFile.foreach { file => try { http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index c5ad8df..1ee0113 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -122,7 +122,7 @@ private[spark] class NodeIdCache( rddUpdateCount += 1 // Handle checkpointing if the directory is not None. 
- if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) { + if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) { // Let's see if we can delete previous checkpoints. var canDelete = true while (checkpointQueue.size > 1 && canDelete) { http://git-wip-us.apache.org/repos/asf/spark/blob/067afb4e/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 42e74ce..281ba6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -87,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** * If false, the algorithm will pass trees to executors to match instances with nodes. * If true, the algorithm will cache node IDs for each instance. - * Caching can speed up training of deeper trees. + * Caching can speed up training of deeper trees. Users can set how often should the + * cache be checkpointed or disable it by setting checkpointInterval. * (default = false) * @group expertParam */ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org