This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 71183b2  [SPARK-24489][ML] Check for invalid input type of weight data in ml.PowerIterationClustering
71183b2 is described below

commit 71183b283343a99c6fa99a41268dae412598067f
Author: Shahid <shahidk...@gmail.com>
AuthorDate: Mon Jan 7 09:15:50 2019 -0800

    [SPARK-24489][ML] Check for invalid input type of weight data in ml.PowerIterationClustering
    
    ## What changes were proposed in this pull request?
    The test case below results in the following failure, because ml.PIC currently does not check the data type of the weight column.
     ```
     test("invalid input types for weight") {
        val invalidWeightData = spark.createDataFrame(Seq(
          (0L, 1L, "a"),
          (2L, 3L, "b")
        )).toDF("src", "dst", "weight")
    
        val pic = new PowerIterationClustering()
          .setWeightCol("weight")
    
        val result = pic.assignClusters(invalidWeightData)
      }
    ```
    ```
    Job aborted due to stage failure: Task 0 in stage 8077.0 failed 1 times, most recent failure: Lost task 0.0 in stage 8077.0 (TID 882, localhost, executor driver): scala.MatchError: [0,1,null] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
        at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178)
        at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178)
        at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
        at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
        at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
        at scala.collection.Iterator$class.foreach(Iterator.scala:893)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
        at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:107)
        at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:105)
        at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:847)
    ```
    This PR adds a numeric type check for the weight column.
    ## How was this patch tested?
    Added a unit test.
    
    Please review http://spark.apache.org/contributing.html before opening a pull request.
    
    Closes #21509 from shahidki31/testCasePic.
    
    Authored-by: Shahid <shahidk...@gmail.com>
    Signed-off-by: Holden Karau <hol...@pigscanfly.ca>
---
 .../spark/ml/clustering/PowerIterationClustering.scala    |  1 +
 .../ml/clustering/PowerIterationClusteringSuite.scala     | 15 +++++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index d9a330f..149e99d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -166,6 +166,7 @@ class PowerIterationClustering private[clustering] (
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
       lit(1.0)
     } else {
+      SchemaUtils.checkNumericType(dataset.schema, $(weightCol))
       col($(weightCol)).cast(DoubleType)
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 55b460f..0ba3ffa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -145,6 +145,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite
     assert(msg.contains("Similarity must be nonnegative"))
   }
 
+  test("check for invalid input types of weight") {
+    val invalidWeightData = spark.createDataFrame(Seq(
+      (0L, 1L, "a"),
+      (2L, 3L, "b")
+    )).toDF("src", "dst", "weight")
+
+    val msg = intercept[IllegalArgumentException] {
+      new PowerIterationClustering()
+        .setWeightCol("weight")
+        .assignClusters(invalidWeightData)
+    }.getMessage
+    assert(msg.contains("requirement failed: Column weight must be of type numeric" +
+      " but was actually of type string."))
+  }
+
   test("test default weight") {
     val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to