Repository: spark
Updated Branches:
  refs/heads/master 1fad55968 -> 8e491af52


[SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain 
compatibility with the model stored before 2.0

## What changes were proposed in this pull request?
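Revert the change to `NaiveBayesModel`'s load path so that it stays compatible
with models stored before 2.0. The loader once again passes the `pi` and `theta`
columns through `MLUtils.convertVectorColumnsToML` and
`MLUtils.convertMatrixColumnsToML` before reading them.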

## How was this patch tested?
local build

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #15313 from zhengruifeng/revert_save_load.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e491af5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e491af5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e491af5

Branch: refs/heads/master
Commit: 8e491af52930886cbe0c54e7d67add3796ddb15f
Parents: 1fad559
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Fri Sep 30 08:18:48 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Fri Sep 30 08:18:48 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/NaiveBayes.scala  | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8e491af5/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 0d652aa..6775745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, 
ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.Dataset
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.DoubleType
 
@@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] 
{
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath).select("pi", 
"theta").head()
-      val pi = data.getAs[Vector](0)
-      val theta = data.getAs[Matrix](1)
+      val data = sparkSession.read.parquet(dataPath)
+      val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
+      val Row(pi: Vector, theta: Matrix) = 
MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+        .select("pi", "theta")
+        .head()
       val model = new NaiveBayesModel(metadata.uid, pi, theta)
 
       DefaultParamsReader.getAndSetParams(model, metadata)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to