[ https://issues.apache.org/jira/browse/SPARK-42747?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Sean R. Owen resolved SPARK-42747.
----------------------------------
    Fix Version/s: 3.2.4
                   3.4.1
                   3.5.0
                   3.3.2
         Assignee: Ruifeng Zheng
       Resolution: Fixed

Resolved by https://github.com/apache/spark/pull/40367

> Fix incorrect internal status of LoR and AFT
> --------------------------------------------
>
>                 Key: SPARK-42747
>                 URL: https://issues.apache.org/jira/browse/SPARK-42747
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, PySpark
>    Affects Versions: 3.1.0, 3.2.0, 3.3.0, 3.4.0
>            Reporter: Ruifeng Zheng
>            Assignee: Ruifeng Zheng
>            Priority: Major
>             Fix For: 3.2.4, 3.4.1, 3.5.0, 3.3.2
>
>
> LoR and AFT keep an internal status to optimize prediction/transform, but the
> status is not correctly updated in some cases:
> {code:java}
> # Assumes `spark` is an active SparkSession, as in the pyspark shell.
> from pyspark.ml.classification import LogisticRegression
> from pyspark.ml.linalg import Vectors
>
> df = spark.createDataFrame(
>     [
>         (1.0, 1.0, Vectors.dense(0.0, 5.0)),
>         (0.0, 2.0, Vectors.dense(1.0, 2.0)),
>         (1.0, 3.0, Vectors.dense(2.0, 1.0)),
>         (0.0, 4.0, Vectors.dense(3.0, 3.0)),
>     ],
>     ["label", "weight", "features"],
> )
> lor = LogisticRegression(weightCol="weight")
> model = lor.fit(df)
>
> # status changes 1
> for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
>     model.setThreshold(t).transform(df)
>
> # status changes 2
> [model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]
>
> for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
>     print(t)
>     model.setThreshold(t).transform(df).show()
> # <- wrong results: the prediction column is identical for every threshold
> {code}
> results:
> {code:java}
> 0.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.1
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.2
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.5
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 1.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> {code}
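
To see why the five identical prediction columns in the quoted results are wrong, the expected predictions can be recomputed by hand from the probability column. The following is only a rough sketch and not Spark code: `expected_predictions`, `P0`, and `EXPECTED` are hypothetical names, the probabilities are rounded from the truncated values shown in the tables, and it assumes the binary rule is "predict 1.0 when P(label = 1) is strictly greater than the threshold".

```python
# Hypothetical helper, not part of Spark: recompute the prediction column
# that each threshold in the reproduction would be expected to produce.

# P(label = 0) per row, rounded from the truncated `probability` column.
P0 = [0.527, 0.297, 0.410, 0.912]
THRESHOLDS = [0.0, 0.1, 0.2, 0.5, 1.0]


def expected_predictions(p0_values, threshold):
    """Return 1.0 where P(label = 1) exceeds the threshold, else 0.0.

    Assumption: the comparison is strict (>). P(label = 1) is taken as
    1 - P(label = 0), which holds for a binary model.
    """
    return [1.0 if (1.0 - p0) > threshold else 0.0 for p0 in p0_values]


# Map each threshold to the predictions expected for the four rows.
EXPECTED = {t: expected_predictions(P0, t) for t in THRESHOLDS}

for t, preds in EXPECTED.items():
    print(t, preds)
```

Under that assumption the prediction column should change as the threshold moves, and only a threshold of 1.0 would give 0.0 for every row, so identical output for all five thresholds suggests the threshold change was not picked up by `transform`.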