Github user mengxr commented on a diff in the pull request: https://github.com/apache/spark/pull/21195#discussion_r185984894 --- Diff: mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala --- @@ -323,4 +324,44 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + val featuresColNameD = "array_double_features" + val featuresColNameF = "array_float_features" + val doubleUDF = udf { (features: Vector) => + val featureArray = Array.fill[Double](features.size)(0.0) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val floatUDF = udf { (features: Vector) => + val featureArray = Array.fill[Float](features.size)(0.0f) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) + .drop("features") + val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) + .drop("features") + assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) + assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) + + val ldaD = new LDA().setK(k).setOptimizer("online") + .setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) + val ldaF = new LDA().setK(k).setOptimizer("online"). 
+ setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) + val modelD = ldaD.fit(newdatasetD) + val modelF = ldaF.fit(newdatasetF) + + // logLikelihood, logPerplexity + val llD = modelD.logLikelihood(newdatasetD) + val llF = modelF.logLikelihood(newdatasetF) + // assert(llD == llF) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + val lpD = modelD.logPerplexity(newdatasetD) + val lpF = modelF.logPerplexity(newdatasetF) + // assert(lpD == lpF) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) --- End diff -- ditto
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org