This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b249cb8af35 [SPARK-46538][ML] Fix the ambiguous column reference issue in `ALSModel.transform` b249cb8af35 is described below commit b249cb8af35588583a63785fdf9b683955fb7ce1 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Dec 29 09:27:22 2023 +0800 [SPARK-46538][ML] Fix the ambiguous column reference issue in `ALSModel.transform` ### What changes were proposed in this pull request? The column references in `ALSModel.transform` may be ambiguous in some cases. ### Why are the changes needed? To fix a bug. Before this fix, the test fails with: ``` JVM stacktrace: org.apache.spark.sql.AnalysisException: [MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION] Resolved attribute(s) "features", "features" missing from "user", "item", "id", "features", "id", "features" in operator !Project [user#60, item#63, UDF(features#50, features#54) AS prediction#94]. Attribute(s) with the same name appear in the operation: "features", "features". Please check if the right attribute(s) are used. SQLSTATE: XX000; ``` and ``` pyspark.errors.exceptions.captured.AnalysisException: Column features#50, features#46 are ambiguous. It's probably because you joined several Datasets together, and some of these Datasets are the same. This column points to one of the Datasets but Spark is unable to figure out which one. Please alias the Datasets with different names via `Dataset.as` before joining them, and specify the column using qualified name, e.g. `df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also se [...] JVM stacktrace: org.apache.spark.sql.AnalysisException: Column features#50, features#46 are ambiguous. It's probably because you joined several Datasets together, and some of these Datasets are the same. This column points to one of the Datasets but Spark is unable to figure out which one.
Please alias the Datasets with different names via `Dataset.as` before joining them, and specify the column using qualified name, e.g. `df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also set spark.sql.an [...] ``` ### Does this PR introduce _any_ user-facing change? yes, bug fix ### How was this patch tested? added ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #44526 from zhengruifeng/ml_als_reference. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 1 + .../org/apache/spark/ml/recommendation/ALS.scala | 21 +++++-- python/pyspark/ml/tests/test_als.py | 68 ++++++++++++++++++++++ 3 files changed, 84 insertions(+), 6 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4ccef788ce8..8595e7ec0e6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -627,6 +627,7 @@ pyspark_ml = Module( "pyspark.ml.tuning", # unittests "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_als", "pyspark.ml.tests.test_base", "pyspark.ml.tests.test_evaluation", "pyspark.ml.tests.test_feature", diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 65c7d399a88..1e6be16ef62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -324,13 +324,22 @@ class ALSModel private[ml] ( // create a new column named map(predictionCol) by running the predict UDF. 
val validatedUsers = checkIntegers(dataset, $(userCol)) val validatedItems = checkIntegers(dataset, $(itemCol)) + + val validatedInputAlias = Identifiable.randomUID("__als_validated_input") + val itemFactorsAlias = Identifiable.randomUID("__als_item_factors") + val userFactorsAlias = Identifiable.randomUID("__als_user_factors") + val predictions = dataset - .join(userFactors, - validatedUsers === userFactors("id"), "left") - .join(itemFactors, - validatedItems === itemFactors("id"), "left") - .select(dataset("*"), - predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + .withColumns(Seq($(userCol), $(itemCol)), Seq(validatedUsers, validatedItems)) + .alias(validatedInputAlias) + .join(userFactors.alias(userFactorsAlias), + col(s"${validatedInputAlias}.${$(userCol)}") === col(s"${userFactorsAlias}.id"), "left") + .join(itemFactors.alias(itemFactorsAlias), + col(s"${validatedInputAlias}.${$(itemCol)}") === col(s"${itemFactorsAlias}.id"), "left") + .select(col(s"${validatedInputAlias}.*"), + predict(col(s"${userFactorsAlias}.features"), col(s"${itemFactorsAlias}.features")) + .alias($(predictionCol))) + getColdStartStrategy match { case ALSModel.Drop => predictions.na.drop("all", Seq($(predictionCol))) diff --git a/python/pyspark/ml/tests/test_als.py b/python/pyspark/ml/tests/test_als.py new file mode 100644 index 00000000000..8eec0d93776 --- /dev/null +++ b/python/pyspark/ml/tests/test_als.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest + +import pyspark.sql.functions as sf +from pyspark.ml.recommendation import ALS, ALSModel +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ALSTest(ReusedSQLTestCase): + def test_ambiguous_column(self): + data = self.spark.createDataFrame( + [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]], + ["user", "item", "rating"], + ) + model = ALS( + userCol="user", + itemCol="item", + ratingCol="rating", + numUserBlocks=10, + numItemBlocks=10, + maxIter=1, + seed=42, + ).fit(data) + + with tempfile.TemporaryDirectory() as d: + model.write().overwrite().save(d) + loaded_model = ALSModel().load(d) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": False}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": True}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_als import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org