This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 4399755  [SPARK-38243][PYTHON][ML] Fix pyspark.ml.LogisticRegression.getThreshold error message logic
4399755 is described below

commit 439975590cf4f21c2a548a2ac6231eb234e1a2f3
Author: zero323 <mszymkiew...@gmail.com>
AuthorDate: Fri Feb 18 11:08:33 2022 -0600

    [SPARK-38243][PYTHON][ML] Fix pyspark.ml.LogisticRegression.getThreshold error message logic

    ### What changes were proposed in this pull request?

    This PR replaces incorrect usage of `str.join` on a `List[float]` in `LogisticRegression.getThreshold`.

    ### Why are the changes needed?

    To avoid an unexpected failure when the method is used for multi-class classification.

    After this change, the following code:

    ```python
    from pyspark.ml.classification import LogisticRegression

    LogisticRegression(thresholds=[1.0, 2.0, 3.0]).getThreshold()
    ```

    raises

    ```python
    Traceback (most recent call last):
      Input In [4] in <module>
        model.getThreshold()
      File /path/to/spark/python/pyspark/ml/classification.py:999 in getThreshold
        raise ValueError(
    ValueError: Logistic Regression getThreshold only applies to binary classification, but thresholds has length != 2. thresholds: [1.0, 2.0, 3.0]
    ```

    instead of the current

    ```python
    Traceback (most recent call last):
      Input In [7] in <module>
        model.getThreshold()
      File /path/to/spark/python/pyspark/ml/classification.py:1003 in getThreshold
        + ",".join(ts)
    TypeError: sequence item 0: expected str instance, float found
    ```

    ### Does this PR introduce _any_ user-facing change?

    No. Bugfix.

    ### How was this patch tested?

    Manual testing.

    Closes #35558 from zero323/SPARK-38243.

    Authored-by: zero323 <mszymkiew...@gmail.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 python/pyspark/ml/classification.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 058740e..b791e6f 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -999,8 +999,7 @@ class _LogisticRegressionParams(
                 raise ValueError(
                     "Logistic Regression getThreshold only applies to"
                     + " binary classification, but thresholds has length != 2."
-                    + " thresholds: "
-                    + ",".join(ts)
+                    + " thresholds: {ts}".format(ts=ts)
                 )
             return 1.0 / (1.0 + ts[0] / ts[1])
         else:

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
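
The failure this commit fixes can be reproduced without Spark. The following is a minimal sketch in plain Python, not the PySpark code itself: the helper names `describe_thresholds_buggy` and `describe_thresholds_fixed` are hypothetical and only mirror the removed and added lines of the diff above.

```python
from typing import List


def describe_thresholds_buggy(ts: List[float]) -> str:
    # Mirrors the removed lines: str.join accepts only str items,
    # so passing a list of floats raises TypeError.
    return " thresholds: " + ",".join(ts)


def describe_thresholds_fixed(ts: List[float]) -> str:
    # Mirrors the added line: str.format converts the whole list with str().
    return " thresholds: {ts}".format(ts=ts)


ts = [1.0, 2.0, 3.0]

buggy_error = None
try:
    describe_thresholds_buggy(ts)
except TypeError as exc:
    # The message-building code itself fails, hiding the intended ValueError.
    buggy_error = str(exc)
    print("buggy:", buggy_error)

fixed_message = describe_thresholds_fixed(ts)
print("fixed:", fixed_message)
```

Converting each item first, as in `",".join(map(str, ts))`, would also avoid the `TypeError`. The commit uses `str.format` instead, so the list appears in the message with its brackets.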