This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 4c79e10 [SPARK-35142][PYTHON][ML][3.0] Fix incorrect return type for `rawPredictionUDF` in `OneVsRestModel` 4c79e10 is described below commit 4c79e1074cc45c6189fbbc8270a40d3ff5006256 Author: harupy <17039389+har...@users.noreply.github.com> AuthorDate: Sat Apr 24 14:36:44 2021 -0700 [SPARK-35142][PYTHON][ML][3.0] Fix incorrect return type for `rawPredictionUDF` in `OneVsRestModel` ### What changes were proposed in this pull request? This PR backports https://github.com/apache/spark/pull/32245. Fixes incorrect return type for `rawPredictionUDF` in `OneVsRestModel`. ### Why are the changes needed? Bugfix: `rawPredictionUDF` was created with `udf(func)` and no explicit return type, so PySpark fell back to the default string return type and `OneVsRestModel` emitted `rawPrediction` as a string column instead of a vector column. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #32275 from harupy/backport-35142-3.0. Authored-by: harupy <17039389+har...@users.noreply.github.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- python/pyspark/ml/classification.py | 4 ++-- python/pyspark/ml/tests/test_algorithms.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1392bc7..5523b6c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -30,7 +30,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams, \ JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper from pyspark.ml.common import inherit_doc -from pyspark.ml.linalg import Vectors +from pyspark.ml.linalg import Vectors, VectorUDT from pyspark.sql import DataFrame from pyspark.sql.functions import udf, when from pyspark.sql.types import ArrayType, DoubleType @@ -2724,7 +2724,7 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): predArray.append(x) return Vectors.dense(predArray) - rawPredictionUDF = udf(func) + rawPredictionUDF = udf(func, VectorUDT()) aggregatedDataset = 
aggregatedDataset.withColumn( self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])) diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index 2faf2d9..90fe59f 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -25,7 +25,7 @@ from pyspark.ml.classification import FMClassifier, LogisticRegression, \ MultilayerPerceptronClassifier, OneVsRest from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel from pyspark.ml.fpm import FPGrowth -from pyspark.ml.linalg import Matrices, Vectors +from pyspark.ml.linalg import Matrices, Vectors, DenseVector from pyspark.ml.recommendation import ALS from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression from pyspark.sql import Row @@ -116,7 +116,19 @@ class OneVsRestTests(SparkSessionTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"]) - def test_parallelism_doesnt_change_output(self): + def test_raw_prediction_column_is_of_vector_type(self): + # SPARK-35142: `OneVsRestModel` outputs raw prediction as a string column + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, parallelism=1) + model = ovr.fit(df) + row = model.transform(df).head() + self.assertIsInstance(row["rawPrediction"], DenseVector) + + def test_parallelism_does_not_change_output(self): df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), (1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org