This is an automated email from the ASF dual-hosted git repository. holden pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 91e64e2 [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator 91e64e2 is described below commit 91e64e24d54287b1e4564358a2ef2bc8c0e6a22b Author: Huaxin Gao <huax...@us.ibm.com> AuthorDate: Fri Feb 8 09:46:54 2019 -0800 [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator ## What changes were proposed in this pull request? add weightCol for python version of MulticlassClassificationEvaluator and MulticlassMetrics ## How was this patch tested? add doc test Closes #23157 from huaxingao/spark-26185. Authored-by: Huaxin Gao <huax...@us.ibm.com> Signed-off-by: Holden Karau <hol...@pigscanfly.ca> --- python/pyspark/ml/evaluation.py | 23 +++++++++++++++------ python/pyspark/mllib/evaluation.py | 42 +++++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 8eaf076..f563a2d 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -22,7 +22,7 @@ from pyspark import since, keyword_only from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \ - HasFeaturesCol + HasFeaturesCol, HasWeightCol from pyspark.ml.common import inherit_doc from pyspark.ml.util import JavaMLReadable, JavaMLWritable @@ -257,7 +257,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, @inherit_doc -class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -279,6 +279,17 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path) >>> str(evaluator2.getPredictionCol()) 'prediction' + >>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0), + ... (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), + ... (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)] + >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", + ... weightCol="weight") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) + 0.66... .. versionadded:: 1.5.0 """ @@ -289,10 +300,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio @keyword_only def __init__(self, predictionCol="prediction", labelCol="label", - metricName="f1"): + metricName="f1", weightCol=None): """ __init__(self, predictionCol="prediction", labelCol="label", \ - metricName="f1") + metricName="f1", weightCol=None) """ super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( @@ -318,10 +329,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio @keyword_only @since("1.5.0") def setParams(self, predictionCol="prediction", labelCol="label", - metricName="f1"): + metricName="f1", weightCol=None): """ setParams(self, predictionCol="prediction", labelCol="label", \ - metricName="f1") + metricName="f1", weightCol=None) Sets params for multiclass classification evaluator. """ kwargs = self._input_kwargs diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 6ca6df6..b028394 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -162,7 +162,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels: an RDD of (prediction, label) pairs. + :param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) @@ -191,16 +191,48 @@ class MulticlassMetrics(JavaModelWrapper): 0.66... >>> metrics.weightedFMeasure(2.0) 0.65... + >>> predAndLabelsWithOptWeight = sc.parallelize([(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), + ... (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), + ... (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]) + >>> metrics = MulticlassMetrics(predAndLabelsWithOptWeight) + >>> metrics.confusionMatrix().toArray() + array([[ 2., 1., 1.], + [ 1., 3., 0.], + [ 0., 0., 1.]]) + >>> metrics.falsePositiveRate(0.0) + 0.2... + >>> metrics.precision(1.0) + 0.75... + >>> metrics.recall(2.0) + 1.0... + >>> metrics.fMeasure(0.0, 2.0) + 0.52... + >>> metrics.accuracy + 0.66... + >>> metrics.weightedFalsePositiveRate + 0.19... + >>> metrics.weightedPrecision + 0.68... + >>> metrics.weightedRecall + 0.66... + >>> metrics.weightedFMeasure() + 0.66... + >>> metrics.weightedFMeasure(2.0) + 0.65... .. versionadded:: 1.4.0 """ - def __init__(self, predictionAndLabels): - sc = predictionAndLabels.ctx + def __init__(self, predAndLabelsWithOptWeight): + sc = predAndLabelsWithOptWeight.ctx sql_ctx = SQLContext.getOrCreate(sc) - df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ + numCol = len(predAndLabelsWithOptWeight.first()) + schema = StructType([ StructField("prediction", DoubleType(), nullable=False), - StructField("label", DoubleType(), nullable=False)])) + StructField("label", DoubleType(), nullable=False)]) + if (numCol == 3): + schema.add("weight", DoubleType(), False) + df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics java_model = java_class(df._jdf) super(MulticlassMetrics, self).__init__(java_model) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org