zhengruifeng commented on a change in pull request #31985:
URL: https://github.com/apache/spark/pull/31985#discussion_r602997003



##########
File path: mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
##########
@@ -1863,21 +1899,125 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
       0.0, 0.0, 0.0, 0.09064661,
       -0.1144333, 0.3204703, -0.1621061, -0.2308192,
       0.0, -0.4832131, 0.0, 0.0), isTransposed = true)
-    val interceptsRStd = Vectors.dense(-0.72638218, -0.01737265, 0.74375484)
+    val interceptsRStd = Vectors.dense(-0.69265374, -0.2260274, 0.9186811)
     val coefficientsR = new DenseMatrix(3, 4, Array(
       0.0, 0.0, 0.01641412, 0.03570376,
       -0.05110822, 0.0, -0.21595670, -0.16162836,
       0.0, 0.0, 0.0, 0.0), isTransposed = true)
     val interceptsR = Vectors.dense(-0.44707756, 0.75180900, -0.3047314)
 
-    assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.05)
-    assert(model1.interceptVector ~== interceptsRStd relTol 0.1)
+    assert(model1.coefficientMatrix ~== coefficientsRStd absTol 1e-3)
+    assert(model1.interceptVector ~== interceptsRStd relTol 1e-3)
     assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps)
-    assert(model2.coefficientMatrix ~== coefficientsR absTol 0.02)
-    assert(model2.interceptVector ~== interceptsR relTol 0.1)
+    assert(model2.coefficientMatrix ~== coefficientsR absTol 1e-3)
+    assert(model2.interceptVector ~== interceptsR relTol 1e-3)
     assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
   }
 
+  test("SPARK-34860: multinomial logistic regression with intercept, with small var") {
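The hunk above tightens the assertions of the existing test from `absTol 0.05` / `absTol 0.02` and `relTol 0.1` down to `1e-3`. As a rough illustration of why absolute and relative tolerances behave differently on small components, here is a simplified plain-Scala sketch. `ToleranceSketch`, `withinAbsTol` and `withinRelTol` are hypothetical names, and scaling the relative error by the smaller magnitude is an assumption, not a copy of Spark's `TestingUtils` (check that file for the exact rule and its handling of values near zero). The base values come from `interceptsRStd` in the diff; the perturbations are made up for illustration.

```scala
// Simplified, hypothetical analogues of the `absTol` / `relTol` checks used with `~==`.
// Assumption: the relative check scales by the smaller magnitude of the two values.
object ToleranceSketch {

  /** Absolute check: |x - y| < eps. */
  def withinAbsTol(x: Double, y: Double, eps: Double): Boolean =
    math.abs(x - y) < eps

  /** Relative check: |x - y| < eps * min(|x|, |y|), exact equality always passes. */
  def withinRelTol(x: Double, y: Double, eps: Double): Boolean =
    x == y || math.abs(x - y) < eps * math.min(math.abs(x), math.abs(y))

  def main(args: Array[String]): Unit = {
    // A component of order 1 with a hypothetical error of 3e-4.
    val big = 0.9186811
    println(withinAbsTol(big, big + 3e-4, 1e-3)) // absolute error 3e-4 < 1e-3
    println(withinRelTol(big, big + 3e-4, 1e-3)) // relative error ~3.3e-4 < 1e-3

    // A component close to zero: the same relTol rejects a ten times smaller error.
    val small = -0.01737265
    println(withinAbsTol(small, small + 3e-5, 1e-3)) // absolute error 3e-5 < 1e-3
    println(withinRelTol(small, small + 3e-5, 1e-3)) // relative error ~1.7e-3 > 1e-3
  }
}
```

This is why a `relTol 1e-3` on an intercept vector whose entries sum to zero (and can therefore be small) is a considerably stricter requirement than the old `relTol 0.1`.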

Review comment:
   master does not pass this newly added test suite. To compare the results, I added these debug prints:
   
   ```
       // scalastyle:off println
       println("R")
       println(interceptsR)
       println(coefficientsR)
   
       println()
       println("model1")
       println(model1.interceptVector)
       println(model1.coefficientMatrix)
   
       println()
       println("model2")
       println(model2.interceptVector)
       println(model2.coefficientMatrix)
   
       println()
       println("R2")
       println(interceptsR2)
       println(coefficientsR2)
   
       println()
       println("model3")
       println(model3.interceptVector)
       println(model3.coefficientMatrix)
       // scalastyle:on println
   ```
   
   this PR:
   ```
   R
   [2.91748298,-17.510746,14.59326301]
   0.21755977  0.01647541   0.16507778  -0.1401668
   -0.24436    0.7564655    -0.2955698  1.3262009
   0.02680026  -0.77294095  0.13049206  -1.18603411
   model1
   [2.933958199942738,-17.543164024163175,14.609205824220437]
   0.21812136899052606   0.015486127035160564  0.16560717317181253  -0.14189621394905397
   -0.2454895541210769   0.7584152697648037    -0.2966285999752721  1.3296192946128171
   0.027368185130550855  -0.7739013967999642   0.13102142680345957  -1.187723080663763
   model2
   [2.933958199942738,-17.543164024163175,14.609205824220437]
   0.21812136899052606   0.015486127035160564  0.16560717317181253  -0.14189621394905397
   -0.2454895541210769   0.7584152697648037    -0.2966285999752721  1.3296192946128171
   0.027368185130550855  -0.7739013967999642   0.13102142680345957  -1.187723080663763
   
   R2
   [1.751626027,-3.9297124987,2.178086472]
   0.019970169   0.079611293   0.003959452   0.110024399
   -4.788494E-4  0.0010097453  -5.832701E-4  0.0
   -0.01936999   -0.080851149  -0.003319687  -0.112435972
   model3
   [1.7516587309368687,-3.9297178332916585,2.1780591023547897]
   0.019968543900064605    0.07960456424549685    0.0039592584764418055   0.11002491382872195
   -4.7805989516075794E-4  0.0010124410611496804  -5.830912612961964E-4   0.0
   -0.01936890596857533    -0.08084716280475213   -0.0033195486718121834  -0.1124344396230352
   ```
   
   master:
   ```
   R
   [2.91748298,-17.510746,14.59326301]
   0.21755977  0.01647541   0.16507778  -0.1401668
   -0.24436    0.7564655    -0.2955698  1.3262009
   0.02680026  -0.77294095  0.13049206  -1.18603411
   model1
   [3.2289115796175536,-3.8874667667006286,0.6585551870830749]
   0.21614280080869921   0.010853354751576538  0.16526956599746928  -0.16826299113708829
   -0.24226138413980347  0.766137782321547     -0.2961105375461299  -0.01353727702893284
   0.02611858333110428   -0.7769911370731234   0.13084097154866067  0.18180026816602116
   model2
   [3.2289115795385817,-3.8874667667014213,0.65855518716284]
   0.216142800347921     0.01085335149421333  0.1652695665789533   -0.16826299025797364
   -0.24226138429694594  0.7661377826486023   -0.2961105377075671  -0.013537276769415511
   0.026118583949024932  -0.7769911341428156  0.13084097112861381  0.18180026702738916
   
   R2
   [1.751626027,-3.9297124987,2.178086472]
   0.019970169   0.079611293   0.003959452   0.110024399
   -4.788494E-4  0.0010097453  -5.832701E-4  0.0
   -0.01936999   -0.080851149  -0.003319687  -0.112435972
   model3
   [3.2372615840468177,-3.887368655600576,0.6501070715537581]
   0.019839138236381705    0.0794011365650966    0.0039488069038510374   -0.028039490169532715
   -4.7766310601985774E-4  0.001015028177410154  -5.837314941386912E-4   0.003248397753479403
   -0.019272077893351464   -0.0806696209827808   -0.0033080309178124827  0.04470991419683469
   ```
   
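   We can see that the new implementation generates solutions much closer to GLMNET's, most visibly in the intercepts and the last coefficient column, which on master deviate substantially from the R reference values.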
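To put a number on "much closer", the intercepts printed above can be compared element-wise against the R reference. The following is a minimal plain-Scala sketch with no Spark dependency; `CompareToGlmnet` and `maxAbsDiff` are hypothetical names, and the arrays are copied verbatim from the console output (model2 matches model1 exactly on this PR and almost exactly on master, so it is omitted).

```scala
// Hypothetical helper: compare fitted intercepts against the GLMNET (R) reference
// values printed above, using the largest element-wise absolute difference.
object CompareToGlmnet {

  /** Largest element-wise |a(i) - b(i)| for two equal-length arrays. */
  def maxAbsDiff(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "arrays must have the same length")
    a.zip(b).map { case (x, y) => math.abs(x - y) }.max
  }

  // First case: R reference vs model1.
  val rIntercepts      = Array(2.91748298, -17.510746, 14.59326301)
  val prIntercepts     = Array(2.933958199942738, -17.543164024163175, 14.609205824220437)
  val masterIntercepts = Array(3.2289115796175536, -3.8874667667006286, 0.6585551870830749)

  // Second case: R2 reference vs model3.
  val r2Intercepts           = Array(1.751626027, -3.9297124987, 2.178086472)
  val prModel3Intercepts     = Array(1.7516587309368687, -3.9297178332916585, 2.1780591023547897)
  val masterModel3Intercepts = Array(3.2372615840468177, -3.887368655600576, 0.6501070715537581)

  def main(args: Array[String]): Unit = {
    println(f"model1 vs R : this PR ${maxAbsDiff(prIntercepts, rIntercepts)}%.2e, " +
      f"master ${maxAbsDiff(masterIntercepts, rIntercepts)}%.2e")
    println(f"model3 vs R2: this PR ${maxAbsDiff(prModel3Intercepts, r2Intercepts)}%.2e, " +
      f"master ${maxAbsDiff(masterModel3Intercepts, r2Intercepts)}%.2e")
  }
}
```

On these numbers the largest intercept deviation is roughly 3e-2 with this PR versus about 14 on master for model1, and roughly 3e-5 versus about 1.5 for model3. The maximum absolute difference is used here only because it is simple to read; the test suite itself asserts with `~==` and `absTol`/`relTol`.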




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


