GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22270#discussion_r213945889
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala ---
    @@ -652,65 +653,66 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
       test("input type validation") {
         val spark = this.spark
         import spark.implicits._
    -
    -    // check that ALS can handle all numeric types for rating column
    -    // and user/item columns (when the user/item ids are within Int range)
    -    val als = new ALS().setMaxIter(1).setRank(1)
    -    Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
    -      case (colName, sqlType) =>
    -        checkNumericTypesALS(als, spark, colName, sqlType) {
    -          (ex, act) =>
    -            ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1)
    -        } { (ex, act, df, enc) =>
    -          val expected = ex.transform(df).selectExpr("prediction")
    -            .first().getFloat(0)
    -          testTransformerByGlobalCheckFunc(df, act, "prediction") {
    -            case rows: Seq[Row] =>
    -              expected ~== rows.head.getFloat(0) absTol 1e-6
    -          }(enc)
    -        }
    -    }
    -    // check user/item ids falling outside of Int range
    -    val big = Int.MaxValue.toLong + 1
    -    val small = Int.MinValue.toDouble - 1
    -    val df = Seq(
    -      (0, 0L, 0d, 1, 1L, 1d, 3.0),
    -      (0, big, small, 0, big, small, 2.0),
    -      (1, 1L, 1d, 0, 0L, 0d, 5.0)
    -    ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
    -    val msg = "either out of Integer range or contained a fractional part"
    -    withClue("fit should fail when ids exceed integer range. ") {
    -      assert(intercept[SparkException] {
    -        als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
    -      }.getCause.getMessage.contains(msg))
    -      assert(intercept[SparkException] {
    -        als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
    -      }.getCause.getMessage.contains(msg))
    -      assert(intercept[SparkException] {
    -        als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
    -      }.getCause.getMessage.contains(msg))
    -      assert(intercept[SparkException] {
    -        als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
    -      }.getCause.getMessage.contains(msg))
    -    }
    -    withClue("transform should fail when ids exceed integer range. ") {
    -      val model = als.fit(df)
    -      def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
    +    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
    --- End diff ---
    
    Sure. :)
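    
    For context, the new wrapper pins `spark.sql.optimizer.excludedRules` to an empty value for the duration of the block, so the test sees every optimizer rule enabled regardless of the session's defaults. A minimal sketch of the pattern (assuming a suite that mixes in Spark's `SQLTestUtils`, which provides `withSQLConf` and restores the previous value afterwards):
    
        import org.apache.spark.sql.internal.SQLConf
    
        // Sketch only: withSQLConf sets the given config entries, runs the body,
        // and restores the prior values afterwards, even if the body throws.
        withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
          // test body that should run with no optimizer rules excluded
        }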


---
