Repository: spark Updated Branches: refs/heads/branch-1.4 bc397753c -> d4914647a
Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None" This reverts commit fcd013cf70e7890aa25a8fe3cb6c8b36bf0e1f04. Author: Yin Huai <yh...@databricks.com> Closes #10632 from yhuai/pythonStyle. (cherry picked from commit e5cde7ab11a43334fa01b1bb8904da5c0774bc62) Signed-off-by: Yin Huai <yh...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4914647 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4914647 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4914647 Branch: refs/heads/branch-1.4 Commit: d4914647ae5f351009ce02cc2bac52e4a8002694 Parents: bc39775 Author: Yin Huai <yh...@databricks.com> Authored: Wed Jan 6 22:03:31 2016 -0800 Committer: Yin Huai <yh...@databricks.com> Committed: Wed Jan 6 22:05:14 2016 -0800 ---------------------------------------------------------------------- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/tests.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d4914647/python/pyspark/mllib/clustering.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 23e25c6..0f93782 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -255,7 +255,7 @@ class GaussianMixture(object): if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = list(initialModel.weights) + initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, http://git-wip-us.apache.org/repos/asf/spark/blob/d4914647/python/pyspark/mllib/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index a9aca3f..d883f6f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -310,18 +310,6 @@ class ListTests(MLlibTestCase): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEquals(round(c1, 7), round(c2, 7)) - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org