Repository: spark
Updated Branches:
  refs/heads/branch-2.0 d005f76e6 -> 0da8bce0e


[SPARK-14891][ML] Add schema validation for ALS

This PR adds schema validation to `ml`'s ALS and ALSModel. Previously, no schema
validation was performed, because `transformSchema` was never called in `ALS.fit`
or `ALSModel.transform`. As a result, if users passed in Long (or Float, etc.)
ids, they were silently cast to Int with no warning or error.

With this PR, ALS supports all numeric types for the `user`, `item`, and
`rating` columns. As before, the rating column is cast to `Float` and the user
and item columns are cast to `Int`; however, for user/item, the cast now throws
an error if a value is outside the integer range. (Values within the range are
converted with `toInt`, so any fractional part is truncated.) Behavior for the
rating column is unchanged: ratings are still cast to `Float` with no range check.

## How was this patch tested?
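New test case `input type validation` in `ALSSuite`, using the new
`checkNumericTypesALS` and `genRatingsDFWithNumericCols` helpers in
`MLTestingUtils`.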

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #12762 from MLnick/SPARK-14891-als-validate-schema.

(cherry picked from commit e8b79afa024123f9d4ceaf0a1043a7e37d913a8d)
Signed-off-by: Nick Pentreath <ni...@za.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0da8bce0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0da8bce0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0da8bce0

Branch: refs/heads/branch-2.0
Commit: 0da8bce0e3fcf6a7f40b5e23e57ce45795926432
Parents: d005f76
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Wed May 18 21:13:12 2016 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Wed May 18 21:13:29 2016 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 55 +++++++++++++-----
 .../spark/ml/recommendation/ALSSuite.scala      | 61 ++++++++++++++++++--
 .../apache/spark/ml/util/MLTestingUtils.scala   | 45 +++++++++++++++
 python/pyspark/ml/recommendation.py             |  8 +--
 4 files changed, 147 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0da8bce0/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944..f257382 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
   */
 private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
   /**
-   * Param for the column name for user ids.
+   * Param for the column name for user ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "user"
    * @group param
    */
-  val userCol = new Param[String](this, "userCol", "column name for user ids")
+  val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getUserCol: String = $(userCol)
 
   /**
-   * Param for the column name for item ids.
+   * Param for the column name for item ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "item"
    * @group param
    */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getItemCol: String = $(itemCol)
+
+  /**
+   * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+   * out of integer range.
+   */
+  protected val checkedCast = udf { (n: Double) =>
+    if (n > Int.MaxValue || n < Int.MinValue) {
+      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+    } else {
+      n.toInt
+    }
+  }
 }
 
 /**
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
-    val ratingType = schema($(ratingCol)).dataType
-    require(ratingType == FloatType || ratingType == DoubleType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
+    // rating will be cast to Float
+    SchemaUtils.checkNumericType(schema, $(ratingCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
@@ -232,6 +252,7 @@ class ALSModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@ class ALSModel private[ml] (
       }
     }
     dataset
-      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
-      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+      .join(userFactors,
+        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+      .join(itemFactors,
+        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ALSModel = {
+    transformSchema(dataset.schema)
     import dataset.sparkSession.implicits._
+
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
-      .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+      .select(checkedCast(col($(userCol)).cast(DoubleType)),
+        checkedCast(col($(itemCol)).cast(DoubleType)), r)
       .rdd
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))

http://git-wip-us.apache.org/repos/asf/spark/blob/0da8bce0/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bbfc415..59b5edc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.{FloatType, IntegerType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -205,7 +206,6 @@ class ALSSuite
 
   /**
    * Generates an explicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -246,7 +246,6 @@ class ALSSuite
 
   /**
    * Generates an implicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -265,7 +264,6 @@ class ALSSuite
 
   /**
    * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
-   *
    * @param size number of users/items
    * @param rank number of features
    * @param random random number generator
@@ -284,7 +282,6 @@ class ALSSuite
 
   /**
    * Test ALS using the given training/test splits and parameters.
-   *
    * @param training training dataset
    * @param test test dataset
    * @param rank rank of the matrix factorization
@@ -486,6 +483,62 @@ class ALSSuite
     assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
     assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
   }
+
+  test("input type validation") {
+    val spark = this.spark
+    import spark.implicits._
+
+    // check that ALS can handle all numeric types for rating column
+    // and user/item columns (when the user/item ids are within Int range)
+    val als = new ALS().setMaxIter(1).setRank(1)
+    Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
+      case (colName, sqlType) =>
+        MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
+          (ex, act) =>
+            ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
+        } { (ex, act, _) =>
+          ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
+            act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+        }
+    }
+    // check user/item ids falling outside of Int range
+    val big = Int.MaxValue.toLong + 1
+    val small = Int.MinValue.toDouble - 1
+    val df = Seq(
+      (0, 0L, 0d, 1, 1L, 1d, 3.0),
+      (0, big, small, 0, big, small, 2.0),
+      (1, 1L, 1d, 0, 0L, 0d, 5.0)
+    ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+    withClue("fit should fail when ids exceed integer range. ") {
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+    }
+    withClue("transform should fail when ids exceed integer range. ") {
+      val model = als.fit(df)
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_big").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_small").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_big").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_small").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/0da8bce0/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 6aae625..80b9769 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions._
@@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }
 
+  def checkNumericTypesALS(
+      estimator: ALS,
+      spark: SparkSession,
+      column: String,
+      baseType: NumericType)
+      (check: (ALSModel, ALSModel) => Unit)
+      (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
+    val dfs = genRatingsDFWithNumericCols(spark, column)
+    val expected = estimator.fit(dfs(baseType))
+    val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
+    actuals.foreach { case (_, actual) => check(expected, actual) }
+    actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
+
+    val baseDF = dfs(baseType)
+    val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+    val cols = Seq(col(column).cast(StringType)) ++ others
+    val strDF = baseDF.select(cols: _*)
+    val thrown = intercept[IllegalArgumentException] {
+      estimator.fit(strDF)
+    }
+    assert(thrown.getMessage.contains(
+      s"$column must be of type NumericType but was actually of type StringType"))
+  }
+
   def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
     val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
     val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite {
       }.toMap
   }
 
+  def genRatingsDFWithNumericCols(
+      spark: SparkSession,
+      column: String): Map[NumericType, DataFrame] = {
+    val df = spark.createDataFrame(Seq(
+      (0, 10, 1.0),
+      (1, 20, 2.0),
+      (2, 30, 3.0),
+      (3, 40, 4.0),
+      (4, 50, 5.0)
+    )).toDF("user", "item", "rating")
+
+    val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+    val types: Seq[NumericType] =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types.map { t =>
+      val cols = Seq(col(column).cast(t)) ++ others
+      t -> df.select(cols: _*)
+    }.toMap
+  }
+
   def genEvaluatorDFWithNumericLabelCol(
       spark: SparkSession,
       labelColName: String = "label",

http://git-wip-us.apache.org/repos/asf/spark/blob/0da8bce0/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index d7cb658..86c00d9 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                           typeConverter=TypeConverters.toBoolean)
     alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
                   typeConverter=TypeConverters.toFloat)
-    userCol = Param(Params._dummy(), "userCol", "column name for user ids",
-                    typeConverter=TypeConverters.toString)
-    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids",
-                    typeConverter=TypeConverters.toString)
+    userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
+    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
     ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
                       typeConverter=TypeConverters.toString)
     nonnegative = Param(Params._dummy(), "nonnegative",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to