Repository: spark Updated Branches: refs/heads/master 9b8eca65d -> b40546651
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy

This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation.

The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.

The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).

## How was this patch tested?

New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #12896 from MLnick/SPARK-14489-als-nan.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4054665 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4054665 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4054665 Branch: refs/heads/master Commit: b405466513bcc02cadf1477b6b682ace95d81658 Parents: 9b8eca6 Author: Nick Pentreath <ni...@za.ibm.com> Authored: Tue Feb 28 16:17:35 2017 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Tue Feb 28 16:17:35 2017 +0200 ---------------------------------------------------------------------- .../apache/spark/ml/recommendation/ALS.scala | 44 ++++++++++++++++- .../spark/ml/recommendation/ALSSuite.scala | 51 +++++++++++++++++++- python/pyspark/ml/recommendation.py | 30 ++++++++++-- 3 files changed, 116 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 97c8655..af00762 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo n.toInt } } + + /** + * Param for strategy for dealing with unknown or new users/items at prediction time. + * This may be useful in cross-validation or production scenarios, for handling user/item ids + * the model has not seen in the training data. + * Supported values: + * - "nan": predicted value for unknown ids will be NaN. + * - "drop": rows in the input DataFrame containing unknown ids will be dropped from + * the output DataFrame containing predictions. 
+ * Default: "nan". + * @group expertParam + */ + val coldStartStrategy = new Param[String](this, "coldStartStrategy", + "strategy for dealing with unknown or new users/items at prediction time. This may be " + + "useful in cross-validation or production scenarios, for handling user/item ids the model " + + "has not seen in the training data. Supported values: " + + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", + (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + + /** @group expertGetParam */ + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase } /** @@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, - intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK") + intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK", + coldStartStrategy -> "nan") /** * Validates and transforms the input schema. 
@@ -248,6 +270,10 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) @@ -260,13 +286,19 @@ class ALSModel private[ml] ( Float.NaN } } - dataset + val predictions = dataset .join(userFactors, checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") .join(itemFactors, checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + getColdStartStrategy match { + case ALSModel.Drop => + predictions.na.drop("all", Seq($(predictionCol))) + case ALSModel.NaN => + predictions + } } @Since("1.3.0") @@ -290,6 +322,10 @@ class ALSModel private[ml] ( @Since("1.6.0") object ALSModel extends MLReadable[ALSModel] { + private val NaN = "nan" + private val Drop = "drop" + private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop) + @Since("1.6.0") override def read: MLReader[ALSModel] = new ALSModelReader @@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.0.0") def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. 
* http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b923bac..c9e7b50 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -498,8 +498,8 @@ class ALSSuite (ex, act) => ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~== - act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6 + ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== + act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 } } // check user/item ids falling outside of Int range @@ -547,6 +547,53 @@ class ALSSuite ALS.train(ratings) } } + + test("ALS cold start user/item prediction strategy") { + val spark = this.spark + import spark.implicits._ + import org.apache.spark.sql.functions._ + + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val data = ratings.toDF + val knownUser = data.select(max("user")).as[Int].first() + val unknownUser = knownUser + 10 + val knownItem = data.select(max("item")).as[Int].first() + val unknownItem = knownItem + 20 + val test = Seq( + (unknownUser, unknownItem), + (knownUser, unknownItem), + (unknownUser, knownItem), + (knownUser, knownItem) + ).toDF("user", "item") + + val als = new ALS().setMaxIter(1).setRank(1) + // default is 'nan' + val defaultModel = als.fit(data) + val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() + assert(defaultPredictions.length == 4) + 
assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) + assert(!defaultPredictions.last.isNaN) + + // check 'drop' strategy should filter out rows with unknown users/items + val dropPredictions = defaultModel + .setColdStartStrategy("drop") + .transform(test) + .select("prediction").as[Float].collect() + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + } + + test("case insensitive cold start param value") { + val spark = this.spark + import spark.implicits._ + val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1) + val data = ratings.toDF + val model = new ALS().fit(data) + Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => + model.setColdStartStrategy(s).transform(data) + } + } } class ALSCleanerSuite extends SparkFunSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/python/pyspark/ml/recommendation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e28d38b..43f82da 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString) + coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + + "unknown or new users/items at prediction time. This may be useful " + + "in cross-validation or production scenarios, for handling " + + "user/item ids the model has not seen in the training data. 
" + + "Supported values: 'nan', 'drop'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) @@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", 
nonnegative=False, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") Sets params for ALS. """ kwargs = self.setParams._input_kwargs @@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ return self.getOrDefault(self.finalStorageLevel) + @since("2.2.0") + def setColdStartStrategy(self, value): + """ + Sets the value of :py:attr:`coldStartStrategy`. + """ + return self._set(coldStartStrategy=value) + + @since("2.2.0") + def getColdStartStrategy(self): + """ + Gets the value of coldStartStrategy or its default value. + """ + return self.getOrDefault(self.coldStartStrategy) + class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org