Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/21081#discussion_r181841503 --- Diff: mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala --- @@ -194,6 +195,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + val featuresColName = "array_model_features" + + val arrayUDF = udf { (features: Vector) => + features.toArray + } + val newdataset = dataset.withColumn(featuresColName, arrayUDF(col("features")) ) + + val kmeans = new KMeans() + .setFeaturesCol(featuresColName) + + assert(kmeans.getK === 2) + assert(kmeans.getFeaturesCol === featuresColName) + assert(kmeans.getPredictionCol === "prediction") + assert(kmeans.getMaxIter === 20) + assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) + assert(kmeans.getInitSteps === 2) + assert(kmeans.getTol === 1e-4) + assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) + val model = kmeans.setMaxIter(1).fit(newdataset) + + MLTestingUtils.checkCopyAndUids(kmeans, model) --- End diff -- You don't need this test here
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org