Repository: spark Updated Branches: refs/heads/master 2fb12b0a3 -> 1db8feab8
[SPARK-15402][ML][PYSPARK] PySpark ml.evaluation should support save/load ## What changes were proposed in this pull request? Since `ml.evaluation` already supports save/load on the Scala side, supporting it on the Python side is straightforward. ## How was this patch tested? Added Python doctests. Author: Yanbo Liang <yblia...@gmail.com> Closes #13194 from yanboliang/spark-15402. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1db8feab Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1db8feab Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1db8feab Branch: refs/heads/master Commit: 1db8feab8c564053c05e8bdc1a7f5026fd637d4f Parents: 2fb12b0 Author: Yanbo Liang <yblia...@gmail.com> Authored: Fri Oct 14 04:17:03 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Fri Oct 14 04:17:03 2016 -0700 ---------------------------------------------------------------------- python/pyspark/ml/evaluation.py | 45 ++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1db8feab/python/pyspark/ml/evaluation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 1fe8772..7aa16fa 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -22,6 +22,7 @@ from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.common import inherit_doc +from pyspark.ml.util import JavaMLReadable, JavaMLWritable __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', 'MulticlassClassificationEvaluator'] @@ -103,7 +104,8 @@ class JavaEvaluator(JavaParams, 
Evaluator): @inherit_doc -class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): +class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) 0.83... + >>> bce_path = temp_path + "/bce" + >>> evaluator.save(bce_path) + >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path) + >>> str(evaluator2.getRawPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -172,7 +179,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction @inherit_doc -class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) 2.649... + >>> re_path = temp_path + "/re" + >>> evaluator.save(re_path) + >>> evaluator2 = RegressionEvaluator.load(re_path) + >>> str(evaluator2.getPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -244,7 +257,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): @inherit_doc -class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) 0.66... 
+ >>> mce_path = temp_path + "/mce" + >>> evaluator.save(mce_path) + >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path) + >>> str(evaluator2.getPredictionCol()) + 'prediction' .. versionadded:: 1.5.0 """ @@ -311,19 +330,27 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio if __name__ == "__main__": import doctest + import tempfile + import pyspark.ml.evaluation from pyspark.sql import SparkSession - globs = globals().copy() + globs = pyspark.ml.evaluation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: spark = SparkSession.builder\ .master("local[2]")\ .appName("ml.evaluation tests")\ .getOrCreate() - sc = spark.sparkContext - globs['sc'] = sc globs['spark'] = spark - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) - spark.stop() + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org