Repository: spark Updated Branches: refs/heads/branch-2.0 b3ebecbb7 -> 240c42b28
[SPARK-16500][ML][MLLIB][OPTIMIZER] add LBFGS convergence warning for all used place in MLLib ## What changes were proposed in this pull request? Add a warning for the following cases when LBFGS training does not actually converge: 1) LogisticRegression 2) AFTSurvivalRegression 3) LBFGS algorithm wrapper in mllib package ## How was this patch tested? N/A Author: WeichenXu <weichenxu...@outlook.com> Closes #14157 from WeichenXu123/add_lbfgs_convergence_warning_for_all_used_place. (cherry picked from commit 252d4f27f23b547777892bcea25a2cea62d8cbab) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/240c42b2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/240c42b2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/240c42b2 Branch: refs/heads/branch-2.0 Commit: 240c42b284b3f4bd302984fa51513c249f6d7648 Parents: b3ebecb Author: WeichenXu <weichenxu...@outlook.com> Authored: Thu Jul 14 09:11:04 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Thu Jul 14 09:11:14 2016 +0100 ---------------------------------------------------------------------- .../apache/spark/ml/classification/LogisticRegression.scala | 5 +++++ .../org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 5 +++++ .../main/scala/org/apache/spark/mllib/optimization/LBFGS.scala | 6 ++++++ 3 files changed, 16 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e157bde..4bab801 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -424,6 +424,11 @@ class LogisticRegression @Since("1.2.0") ( throw new SparkException(msg) } + if (!state.actuallyConverged) { + logWarning("LogisticRegression training fininshed but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + /* The coefficients are trained in the scaled space; we're converting them back to the original space. http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 7c51845..366448f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -245,6 +245,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S throw new SparkException(msg) } + if (!state.actuallyConverged) { + logWarning("AFTSurvivalRegression training fininshed but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + state.x.toArray.clone() } http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ec6ffe6..c61b2db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -212,6 +212,12 @@ object LBFGS extends Logging { state = states.next() } lossHistory += state.value + + if (!state.actuallyConverged) { + logWarning("LBFGS training fininshed but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + val weights = Vectors.fromBreeze(state.x) val lossHistoryArray = lossHistory.result() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org