Repository: spark
Updated Branches:
  refs/heads/branch-2.0 b3ebecbb7 -> 240c42b28


[SPARK-16500][ML][MLLIB][OPTIMIZER] add LBFGS convergence warning for all used places in MLlib

## What changes were proposed in this pull request?

Add a warning for the following cases where LBFGS training does not actually converge:

1) LogisticRegression
2) AFTSurvivalRegression
3) LBFGS algorithm wrapper in mllib package
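
All three changes use the same pattern, which can be sketched on its own. The following minimal Scala sketch is hypothetical and does not use Spark or Breeze. `OptState`, `StopReason`, `minimize`, and `warnIfNotConverged` are illustrative names only, plain fixed-step gradient descent stands in for LBFGS, and `println` stands in for Spark's `logWarning`. The idea it illustrates is that reaching the iteration cap ends the loop just as a satisfied convergence test does, so the caller has to check why the optimizer stopped (as the diffs below do with `state.actuallyConverged` and `state.convergedReason`).

```scala
// Hypothetical illustration only: not Spark or Breeze code.
object ConvergenceCheck {
  sealed trait StopReason { def reason: String }
  case object GradientConverged extends StopReason { val reason = "gradient norm below tolerance" }
  case object MaxIterationsReached extends StopReason { val reason = "max iterations reached" }

  final case class OptState(x: Double, value: Double, iter: Int, stopReason: StopReason) {
    // True only when the convergence test, not the iteration budget, ended the run.
    def actuallyConverged: Boolean = stopReason == GradientConverged
  }

  // Minimizes a 1-D function with fixed-step gradient descent (a stand-in for LBFGS).
  def minimize(f: Double => Double, grad: Double => Double, x0: Double,
               stepSize: Double, tol: Double, maxIter: Int): OptState = {
    var x = x0
    var g = grad(x)
    var iter = 0
    while (math.abs(g) >= tol && iter < maxIter) {
      x -= stepSize * g
      g = grad(x)
      iter += 1
    }
    val reason = if (math.abs(g) < tol) GradientConverged else MaxIterationsReached
    OptState(x, f(x), iter, reason)
  }

  // The check added after each training loop; returns the warning text if one was emitted.
  def warnIfNotConverged(name: String, state: OptState): Option[String] =
    if (!state.actuallyConverged) {
      val msg = s"$name training finished but the result is not converged because: ${state.stopReason.reason}"
      println(s"WARN $msg")
      Some(msg)
    } else None
}
```

With f(x) = (x - 3)^2, a step size of 0.1, and a tolerance of 1e-6, a budget of 1000 iterations converges silently, while a budget of 5 stops early and triggers the warning.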

## How was this patch tested?

N/A

Author: WeichenXu <weichenxu...@outlook.com>

Closes #14157 from WeichenXu123/add_lbfgs_convergence_warning_for_all_used_place.

(cherry picked from commit 252d4f27f23b547777892bcea25a2cea62d8cbab)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/240c42b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/240c42b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/240c42b2

Branch: refs/heads/branch-2.0
Commit: 240c42b284b3f4bd302984fa51513c249f6d7648
Parents: b3ebecb
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Thu Jul 14 09:11:04 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Jul 14 09:11:14 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/ml/classification/LogisticRegression.scala    | 5 +++++
 .../org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 5 +++++
 .../main/scala/org/apache/spark/mllib/optimization/LBFGS.scala | 6 ++++++
 3 files changed, 16 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index e157bde..4bab801 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -424,6 +424,11 @@ class LogisticRegression @Since("1.2.0") (
           throw new SparkException(msg)
         }
 
+        if (!state.actuallyConverged) {
+          logWarning("LogisticRegression training finished but the result " +
+            s"is not converged because: ${state.convergedReason.get.reason}")
+        }
+
         /*
           The coefficients are trained in the scaled space; we're converting them back to
            the original space.

http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 7c51845..366448f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -245,6 +245,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
         throw new SparkException(msg)
       }
 
+      if (!state.actuallyConverged) {
+        logWarning("AFTSurvivalRegression training finished but the result " +
+          s"is not converged because: ${state.convergedReason.get.reason}")
+      }
+
       state.x.toArray.clone()
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/240c42b2/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index ec6ffe6..c61b2db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -212,6 +212,12 @@ object LBFGS extends Logging {
       state = states.next()
     }
     lossHistory += state.value
+
+    if (!state.actuallyConverged) {
+      logWarning("LBFGS training finished but the result " +
+        s"is not converged because: ${state.convergedReason.get.reason}")
+    }
+
     val weights = Vectors.fromBreeze(state.x)
 
     val lossHistoryArray = lossHistory.result()

