This is an automated email from the ASF dual-hosted git repository. holden pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 71183b2 [SPARK-24489][ML] Check for invalid input type of weight data in ml.PowerIterationClustering 71183b2 is described below commit 71183b283343a99c6fa99a41268dae412598067f Author: Shahid <shahidk...@gmail.com> AuthorDate: Mon Jan 7 09:15:50 2019 -0800 [SPARK-24489][ML] Check for invalid input type of weight data in ml.PowerIterationClustering ## What changes were proposed in this pull request? The following test case results in the failure shown below, because ml.PIC currently does not check the data type of the weight column. ``` test("invalid input types for weight") { val invalidWeightData = spark.createDataFrame(Seq( (0L, 1L, "a"), (2L, 3L, "b") )).toDF("src", "dst", "weight") val pic = new PowerIterationClustering() .setWeightCol("weight") val result = pic.assignClusters(invalidWeightData) } ``` ``` Job aborted due to stage failure: Task 0 in stage 8077.0 failed 1 times, most recent failure: Lost task 0.0 in stage 8077.0 (TID 882, localhost, executor driver): scala.MatchError: [0,1,null] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema) at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178) at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178) at scala.collection.Iterator$$anon$11.next(Iterator.scala:409) at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434) at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440) at scala.collection.Iterator$class.foreach(Iterator.scala:893) at scala.collection.AbstractIterator.foreach(Iterator.scala:1336) at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:107) at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:847) ``` This PR adds a numeric-type check for the weight column. 
## How was this patch tested? UT added Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21509 from shahidki31/testCasePic. Authored-by: Shahid <shahidk...@gmail.com> Signed-off-by: Holden Karau <hol...@pigscanfly.ca> --- .../spark/ml/clustering/PowerIterationClustering.scala | 1 + .../ml/clustering/PowerIterationClusteringSuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index d9a330f..149e99d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -166,6 +166,7 @@ class PowerIterationClustering private[clustering] ( val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { lit(1.0) } else { + SchemaUtils.checkNumericType(dataset.schema, $(weightCol)) col($(weightCol)).cast(DoubleType) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 55b460f..0ba3ffa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -145,6 +145,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(msg.contains("Similarity must be nonnegative")) } + test("check for invalid input types of weight") { + val invalidWeightData = spark.createDataFrame(Seq( + (0L, 1L, "a"), + (2L, 3L, "b") + )).toDF("src", "dst", "weight") + + val msg = intercept[IllegalArgumentException] { + new PowerIterationClustering() + .setWeightCol("weight") + .assignClusters(invalidWeightData) + }.getMessage + assert(msg.contains("requirement failed: Column weight must be of type numeric" + + " but was actually of type string.")) + } + test("test default weight") { val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org