Github user hhbyyh commented on a diff in the pull request: https://github.com/apache/spark/pull/21248#discussion_r189466147 --- Diff: examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala --- @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.log4j.{Level, Logger} + +// $example on$ +import org.apache.spark.ml.clustering.PowerIterationClustering +// $example off$ +import org.apache.spark.sql.{DataFrame, Row, SparkSession} + + + /** + * An example demonstrating power iteration clustering. + * Run with + * {{{ + * bin/run-example ml.PowerIterationClusteringExample + * }}} + */ + +object PowerIterationClusteringExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + Logger.getRootLogger.setLevel(Level.WARN) + + // $example on$ + + // Generates data. + val radius1 = 1.0 + val numPoints1 = 5 + val radius2 = 4.0 + val numPoints2 = 20 + + val dataset = generatePICData(spark, radius1, radius2, numPoints1, numPoints2) + + // Trains a PIC model. 
+ val model = new PowerIterationClustering(). + setK(2). + setInitMode("degree"). + setMaxIter(20) + + val prediction = model.transform(dataset).select("id", "prediction") + + // Shows the result. + // println("Cluster Assignment: ") + val result = prediction.collect().map { + row => (row(1), row(0)) + }.groupBy(_._1).mapValues(_.map(_._2)) + + result.foreach { + case (cluster, points) => println(s"$cluster -> [${points.mkString(",")}]") + } --- End diff -- This can be achieved with the DataFrame API, using groupBy followed by collect_set.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org