Repository: spark
Updated Branches:
  refs/heads/master 8d6ef895e -> 625cfe09e


[SPARK-19733][ML] Removed unnecessary casts and refactored checked casts in ALS.

## What changes were proposed in this pull request?

The original ALS implementation unnecessarily cast the user and item ids to 
Double, because the protected checkedCast() UDF only accepted a Double. I removed 
these casts and refactored the method to accept Any and efficiently handle all 
permitted numeric values. It now rejects values that are out of Integer range, 
have a fractional part, or are not numeric.

## How was this patch tested?

I tested it by running the unit tests and by manually validating the result of 
checkedCast for various legal and illegal values. These cases are also covered 
by a new "CheckedCast" test in ALSSuite.

Author: Vasilis Vryniotis <bbrinio...@datumbox.com>

Closes #17059 from datumbox/als_casting_fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/625cfe09
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/625cfe09
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/625cfe09

Branch: refs/heads/master
Commit: 625cfe09e673bfcb95e361ce19b534cf0a3c782c
Parents: 8d6ef89
Author: Vasilis Vryniotis <bbrinio...@datumbox.com>
Authored: Thu Mar 2 12:37:42 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu Mar 2 12:37:42 2017 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 31 +++++---
 .../spark/ml/recommendation/ALSSuite.scala      | 84 +++++++++++++++++---
 2 files changed, 95 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/625cfe09/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 04273a4..799e881 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -80,14 +80,24 @@ private[recommendation] trait ALSModelParams extends Params 
with HasPredictionCo
 
   /**
    * Attempts to safely cast a user/item id to an Int. Throws an exception if 
the value is
-   * out of integer range.
+   * out of integer range or contains a fractional part.
    */
-  protected val checkedCast = udf { (n: Double) =>
-    if (n > Int.MaxValue || n < Int.MinValue) {
-      throw new IllegalArgumentException(s"ALS only supports values in Integer 
range for columns " +
-        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
-    } else {
-      n.toInt
+  protected[recommendation] val checkedCast = udf { (n: Any) =>
+    n match {
+      case v: Int => v // Avoid unnecessary casting
+      case v: Number =>
+        val intV = v.intValue
+        // Checks if number within Int range and has no fractional part.
+        if (v.doubleValue == intV) {
+          intV
+        } else {
+          throw new IllegalArgumentException(s"ALS only supports values in 
Integer range " +
+            s"and without fractional part for columns ${$(userCol)} and 
${$(itemCol)}. " +
+            s"Value $n was either out of Integer range or contained a 
fractional part that " +
+            s"could not be converted.")
+        }
+      case _ => throw new IllegalArgumentException(s"ALS only supports values 
in Integer range " +
+        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not 
numeric.")
     }
   }
 
@@ -288,9 +298,9 @@ class ALSModel private[ml] (
     }
     val predictions = dataset
       .join(userFactors,
-        checkedCast(dataset($(userCol)).cast(DoubleType)) === 
userFactors("id"), "left")
+        checkedCast(dataset($(userCol))) === userFactors("id"), "left")
       .join(itemFactors,
-        checkedCast(dataset($(itemCol)).cast(DoubleType)) === 
itemFactors("id"), "left")
+        checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), 
itemFactors("features")).as($(predictionCol)))
     getColdStartStrategy match {
@@ -491,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends 
Estimator[ALSModel]
 
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else 
lit(1.0f)
     val ratings = dataset
-      .select(checkedCast(col($(userCol)).cast(DoubleType)),
-        checkedCast(col($(itemCol)).cast(DoubleType)), r)
+      .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r)
       .rdd
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))

http://git-wip-us.apache.org/repos/asf/spark/blob/625cfe09/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c9e7b50..c8228dd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.types.{FloatType, IntegerType}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -205,6 +206,70 @@ class ALSSuite
     assert(decompressed.toSet === expected)
   }
 
+  test("CheckedCast") {
+    val checkedCast = new ALS().checkedCast
+    val df = spark.range(1)
+
+    withClue("Valid Integer Ids") {
+      df.select(checkedCast(lit(123))).collect()
+    }
+
+    withClue("Valid Long Ids") {
+      df.select(checkedCast(lit(1231L))).collect()
+    }
+
+    withClue("Valid Decimal Ids") {
+      df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect()
+    }
+
+    withClue("Valid Double Ids") {
+      df.select(checkedCast(lit(123.0))).collect()
+    }
+
+    val msg = "either out of Integer range or contained a fractional part"
+    withClue("Invalid Long: out of range") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit(1231000000000L))).collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    withClue("Invalid Decimal: out of range") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 
2)))).collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    withClue("Invalid Decimal: fractional part") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    withClue("Invalid Double: out of range") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit(1231000000000.0))).collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    withClue("Invalid Double: fractional part") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit(123.1))).collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    withClue("Invalid Type") {
+      val e: SparkException = intercept[SparkException] {
+        df.select(checkedCast(lit("123.1"))).collect()
+      }
+      assert(e.getMessage.contains("was not numeric"))
+    }
+  }
+
   /**
    * Generates an explicit feedback dataset for testing ALS.
    * @param numUsers number of users
@@ -510,34 +575,35 @@ class ALSSuite
       (0, big, small, 0, big, small, 2.0),
       (1, 1L, 1d, 0, 0L, 0d, 5.0)
     ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", 
"rating")
+    val msg = "either out of Integer range or contained a fractional part"
     withClue("fit should fail when ids exceed integer range. ") {
       assert(intercept[SparkException] {
         als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
-      }.getCause.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains(msg))
       assert(intercept[SparkException] {
         als.fit(df.select(df("user_small").as("user"), df("item"), 
df("rating")))
-      }.getCause.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains(msg))
       assert(intercept[SparkException] {
         als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
-      }.getCause.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains(msg))
       assert(intercept[SparkException] {
         als.fit(df.select(df("item_small").as("item"), df("user"), 
df("rating")))
-      }.getCause.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains(msg))
     }
     withClue("transform should fail when ids exceed integer range. ") {
       val model = als.fit(df)
       assert(intercept[SparkException] {
         model.transform(df.select(df("user_big").as("user"), df("item"))).first
-      }.getMessage.contains("was out of Integer range"))
+      }.getMessage.contains(msg))
       assert(intercept[SparkException] {
         model.transform(df.select(df("user_small").as("user"), 
df("item"))).first
-      }.getMessage.contains("was out of Integer range"))
+      }.getMessage.contains(msg))
       assert(intercept[SparkException] {
         model.transform(df.select(df("item_big").as("item"), df("user"))).first
-      }.getMessage.contains("was out of Integer range"))
+      }.getMessage.contains(msg))
       assert(intercept[SparkException] {
         model.transform(df.select(df("item_small").as("item"), 
df("user"))).first
-      }.getMessage.contains("was out of Integer range"))
+      }.getMessage.contains(msg))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to