[ https://issues.apache.org/jira/browse/SPARK-35142?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17326517#comment-17326517 ]
Apache Spark commented on SPARK-35142:
--------------------------------------

User 'harupy' has created a pull request for this issue:
https://github.com/apache/spark/pull/32275

> `OneVsRest` classifier uses incorrect data type for `rawPrediction` column
> --------------------------------------------------------------------------
>
>                 Key: SPARK-35142
>                 URL: https://issues.apache.org/jira/browse/SPARK-35142
>             Project: Spark
>          Issue Type: Bug
>          Components: ML
>    Affects Versions: 3.0.0, 3.0.2, 3.1.0, 3.1.1
>            Reporter: Harutaka Kawamura
>            Priority: Major
>
> `OneVsRest` classifier uses an incorrect data type for the `rawPrediction`
> column.
> Code to reproduce the issue:
> {code:java}
> from pyspark.ml.classification import LogisticRegression, OneVsRest
> from pyspark.ml.linalg import Vectors
> from pyspark.sql import SparkSession
> from sklearn.datasets import load_iris
>
> spark = SparkSession.builder.getOrCreate()
> X, y = load_iris(return_X_y=True)
> df = spark.createDataFrame(
>     [(Vectors.dense(features), int(label)) for features, label in zip(X, y)],
>     ["features", "label"]
> )
> train, test = df.randomSplit([0.8, 0.2])
> lor = LogisticRegression(maxIter=5)
> ovr = OneVsRest(classifier=lor)
> ovrModel = ovr.fit(train)
> pred = ovrModel.transform(test)
> pred.printSchema()
> # This prints out:
> # root
> #  |-- features: vector (nullable = true)
> #  |-- label: long (nullable = true)
> #  |-- rawPrediction: string (nullable = true)  # <- should not be string
> #  |-- prediction: double (nullable = true)
> # pred.show()  # this fails because of the incorrect datatype
> {code}
> I ran the code above using GitHub Actions:
> [https://github.com/harupy/SPARK-35142/pull/1]
>
> It looks like the UDF that computes the `rawPrediction` column is created
> without specifying the return type:
>
> [https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/python/pyspark/ml/classification.py#L3154]
> {code:java}
> rawPredictionUDF = udf(func)
> {code}