Repository: spark

Updated Branches:
  refs/heads/master 713e4f44e -> da936fbb7
[SPARK-10779] [PYSPARK] [MLLIB] Set initialModel for KMeans model in PySpark (spark.mllib)

Provide initialModel param for pyspark.mllib.clustering.KMeans

Author: Evan Chen <ch...@us.ibm.com>

Closes #8967 from evanyc15/SPARK-10779-pyspark-mllib.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/da936fbb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/da936fbb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/da936fbb

Branch: refs/heads/master
Commit: da936fbb74b852d5c98286ce92522dc3efd6ad6c
Parents: 713e4f4
Author: Evan Chen <ch...@us.ibm.com>
Authored: Wed Oct 7 15:04:53 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Oct 7 15:04:53 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  4 +++-
 python/pyspark/mllib/clustering.py              | 17 +++++++++++++++--
 2 files changed, 18 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/da936fbb/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 69ce7f5..21e5593 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -336,7 +336,8 @@ private[python] class PythonMLLibAPI extends Serializable {
       initializationMode: String,
       seed: java.lang.Long,
       initializationSteps: Int,
-      epsilon: Double): KMeansModel = {
+      epsilon: Double,
+      initialModel: java.util.ArrayList[Vector]): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
@@ -346,6 +347,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       .setEpsilon(epsilon)
 
     if (seed != null) kMeansAlg.setSeed(seed)
+    if (!initialModel.isEmpty()) kMeansAlg.setInitialModel(new KMeansModel(initialModel))
 
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))

http://git-wip-us.apache.org/repos/asf/spark/blob/da936fbb/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 900ade2..6964a45 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -90,6 +90,12 @@ class KMeansModel(Saveable, Loader):
     ...     rmtree(path)
     ... except OSError:
     ...     pass
+
+    >>> data = array([-383.1,-382.9, 28.7,31.2, 366.2,367.3]).reshape(3, 2)
+    >>> model = KMeans.train(sc.parallelize(data), 3, maxIterations=0,
+    ...     initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)]))
+    >>> model.clusterCenters
+    [array([-1000., -1000.]), array([ 5.,  5.]), array([ 1000.,  1000.])]
     """
 
     def __init__(self, centers):
@@ -144,10 +150,17 @@ class KMeans(object):
 
     @classmethod
     def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||",
-              seed=None, initializationSteps=5, epsilon=1e-4):
+              seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None):
         """Train a k-means clustering model."""
+        clusterInitialModel = []
+        if initialModel is not None:
+            if not isinstance(initialModel, KMeansModel):
+                raise Exception("initialModel is of "+str(type(initialModel))+". It needs "
+                                "to be of <type 'KMeansModel'>")
+            clusterInitialModel = [_convert_to_vector(c) for c in initialModel.clusterCenters]
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode, seed, initializationSteps, epsilon)
+                              runs, initializationMode, seed, initializationSteps, epsilon,
+                              clusterInitialModel)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
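----------------------------------------------------------------------
Usage sketch (not part of the commit): one way the new initialModel
parameter could be exercised end-to-end from PySpark, assuming a plain
local SparkContext; the sample points and seed centers simply mirror
the doctest added above and are otherwise arbitrary.

# Minimal sketch of the initialModel parameter added by this commit.
# Assumes a local SparkContext; data values mirror the doctest in
# clustering.py and are illustrative only.
from numpy import array
from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans, KMeansModel

sc = SparkContext("local", "kmeans-initial-model-sketch")

# Three well-separated 2-D points.
data = sc.parallelize([
    array([-383.1, -382.9]),
    array([28.7, 31.2]),
    array([366.2, 367.3]),
])

# Seed k-means with explicit starting centers instead of random or
# k-means|| initialization; with maxIterations=0 the returned centers
# are exactly the seeds that were supplied.
initial = KMeansModel([(-1000.0, -1000.0), (5.0, 5.0), (1000.0, 1000.0)])
model = KMeans.train(data, 3, maxIterations=0, initialModel=initial)

print(model.clusterCenters)
# [array([-1000., -1000.]), array([ 5.,  5.]), array([ 1000.,  1000.])]

sc.stop()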