Repository: spark Updated Branches: refs/heads/branch-2.0 c8a7c2305 -> 1d274455c
[SPARK-16241][ML] model loading backward compatibility for ml NaiveBayes ## What changes were proposed in this pull request? model loading backward compatibility for ml NaiveBayes ## How was this patch tested? existing ut and manual test for loading models saved by Spark 1.6. Author: zlpmichelle <zlpmiche...@gmail.com> Closes #13940 from zlpmichelle/naivebayes. (cherry picked from commit b30a2dc7c50bfb70bd2b57be70530a9a9fa94a7a) Signed-off-by: Yanbo Liang <yblia...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d274455 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d274455 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d274455 Branch: refs/heads/branch-2.0 Commit: 1d274455cfa45bc63aee6ecf8dbb1f170ee16af2 Parents: c8a7c23 Author: zlpmichelle <zlpmiche...@gmail.com> Authored: Thu Jun 30 00:50:14 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Jun 30 00:50:35 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/NaiveBayes.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1d274455/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 7c34031..c99ae30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -28,8 +28,9 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} /** * Params for Naive Bayes Classifiers. @@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org