Repository: spark Updated Branches: refs/heads/master 369a148e5 -> 70f9d7f71
[SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe ## What changes were proposed in this pull request? This is a simple implementation of RecommendForAllUsers & RecommendForAllItems for the Dataframe version of ALS. It uses Dataframe operations (not a wrapper on the RDD implementation). Haven't benchmarked against a wrapper, but unit test examples do work. ## How was this patch tested? Unit tests ``` $ build/sbt > mllib/testOnly *ALSSuite -- -z "recommendFor" > mllib/testOnly ``` Author: Your Name <y...@example.com> Author: sueann <sue...@databricks.com> Closes #17090 from sueann/SPARK-19535. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70f9d7f7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70f9d7f7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70f9d7f7 Branch: refs/heads/master Commit: 70f9d7f71c63d2b1fdfed75cb7a59285c272a62b Parents: 369a148 Author: Sue Ann Hong <sue...@databricks.com> Authored: Sun Mar 5 16:49:31 2017 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Sun Mar 5 16:49:31 2017 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/recommendation/ALS.scala | 79 ++++++++++++++-- .../ml/recommendation/TopByKeyAggregator.scala | 60 +++++++++++++ .../spark/ml/recommendation/ALSSuite.scala | 94 ++++++++++++++++++++ .../TopByKeyAggregatorSuite.scala | 73 +++++++++++++++ 4 files changed, 297 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 799e881..60dd736 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,8 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -284,18 +285,20 @@ class ALSModel private[ml] ( @Since("2.2.0") def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => + if (featuresA != null && featuresB != null) { + // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for + // potential optimization. + blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + } else { + Float.NaN + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. - val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => - if (userFeatures != null && itemFeatures != null) { - blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) - } else { - Float.NaN - } - } val predictions = dataset .join(userFactors, checkedCast(dataset($(userCol))) === userFactors("id"), "left") @@ -327,6 +330,64 @@ class ALSModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new ALSModel.ALSModelWriter(this) + + /** + * Returns top `numItems` items recommended for each user, for all users. + * @param numItems max number of recommendations for each user + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllUsers(numItems: Int): DataFrame = { + recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) + } + + /** + * Returns top `numUsers` users recommended for each item, for all items. + * @param numUsers max number of recommendations for each item + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllItems(numUsers: Int): DataFrame = { + recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Makes recommendations for all users (or items). + * @param srcFactors src factors for which to generate recommendations + * @param dstFactors dst factors used to make recommendations + * @param srcOutputColumn name of the column for the source ID in the output DataFrame + * @param dstOutputColumn name of the column for the destination ID in the output DataFrame + * @param num max number of recommendations for each record + * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are + * stored as an array of (dstOutputColumn: Int, rating: Float) Rows. + */ + private def recommendForAll( + srcFactors: DataFrame, + dstFactors: DataFrame, + srcOutputColumn: String, + dstOutputColumn: String, + num: Int): DataFrame = { + import srcFactors.sparkSession.implicits._ + + val ratings = srcFactors.crossJoin(dstFactors) + .select( + srcFactors("id"), + dstFactors("id"), + predict(srcFactors("features"), dstFactors("features"))) + // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) + val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) + .toDF("id", "recommendations") + + val arrayType = ArrayType( + new StructType() + .add(dstOutputColumn, IntegerType) + .add("rating", FloatType) + ) + recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + } } @Since("1.6.0") http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala new file mode 100644 index 0000000..517179c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.{Encoder, Encoders} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds + * the top `num` K2 items based on the given Ordering. + */ +private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag] + (num: Int, ord: Ordering[(K2, V)]) + extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] { + + override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord) + + override def reduce( + q: BoundedPriorityQueue[(K2, V)], + a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = { + q += {(a._2, a._3)} + } + + override def merge( + q1: BoundedPriorityQueue[(K2, V)], + q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = { + q1 ++= q2 + } + + override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = { + r.toArray.sorted(ord.reverse) + } + + override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = { + Encoders.kryo[BoundedPriorityQueue[(K2, V)]] + } + + override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]() +} http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index c8228dd..e494ea8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.WrappedArray import scala.collection.JavaConverters._ import scala.language.existentials @@ -660,6 +661,99 @@ class ALSSuite model.setColdStartStrategy(s).transform(data) } } + + private def getALSModel = { + val spark = this.spark + import spark.implicits._ + + val userFactors = Seq( + (0, Array(6.0f, 4.0f)), + (1, Array(3.0f, 4.0f)), + (2, Array(3.0f, 6.0f)) + ).toDF("id", "features") + val itemFactors = Seq( + (3, Array(5.0f, 6.0f)), + (4, Array(6.0f, 2.0f)), + (5, Array(3.0f, 6.0f)), + (6, Array(4.0f, 1.0f)) + ).toDF("id", "features") + val als = new ALS().setRank(2) + new ALSModel(als.uid, als.getRank, userFactors, itemFactors) + .setUserCol("user") + .setItemCol("item") + } + + test("recommendForAllUsers with k < num_items") { + val topItems = getALSModel.recommendForAllUsers(2) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f), (5, 33f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllUsers with k = num_items") { + val topItems = getALSModel.recommendForAllUsers(4) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllItems with k < num_users") { + val topUsers = getALSModel.recommendForAllItems(2) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f)), + 4 -> Array((0, 44f), (2, 30f)), + 5 -> Array((2, 45f), (0, 42f)), + 6 -> Array((0, 28f), (2, 18f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + test("recommendForAllItems with k = num_users") { + val topUsers = getALSModel.recommendForAllItems(3) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f), (1, 39f)), + 4 -> Array((0, 44f), (2, 30f), (1, 26f)), + 5 -> Array((2, 45f), (0, 42f), (1, 33f)), + 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + private def checkRecommendations( + topK: DataFrame, + expected: Map[Int, Array[(Int, Float)]], + dstColName: String): Unit = { + val spark = this.spark + import spark.implicits._ + + assert(topK.columns.contains("recommendations")) + topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) => + assert(recs === expected(id)) + } + topK.collect().foreach { row => + val recs = row.getAs[WrappedArray[Row]]("recommendations") + assert(recs(0).fieldIndex(dstColName) == 0) + assert(recs(0).fieldIndex("rating") == 1) + } + } } class ALSCleanerSuite extends SparkFunSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala new file mode 100644 index 0000000..5e763a8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + + +class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2)) + Seq( + (0, 3, 54f), + (0, 4, 44f), + (0, 5, 42f), + (0, 6, 28f), + (1, 3, 39f), + (2, 3, 51f), + (2, 5, 45f), + (2, 6, 18f) + ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn) + } + + test("topByKey with k < #items") { + val topK = getTopK(2) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkTopK(topK, expected) + } + + test("topByKey with k > #items") { + val topK = getTopK(5) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f), (6, 18f)) + ) + checkTopK(topK, expected) + } + + private def checkTopK( + topK: Dataset[(Int, Array[(Int, Float)])], + expected: Map[Int, Array[(Int, Float)]]): Unit = { + topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org