Repository: spark Updated Branches: refs/heads/master 3a44aebd0 -> 27b98e99d
[SPARK-12380] [PYSPARK] use SQLContext.getOrCreate in mllib MLlib should use SQLContext.getOrCreate() instead of creating new SQLContext. Author: Davies Liu <dav...@databricks.com> Closes #10338 from davies/create_context. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27b98e99 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27b98e99 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27b98e99 Branch: refs/heads/master Commit: 27b98e99d21a0cc34955337f82a71a18f9220ab2 Parents: 3a44aeb Author: Davies Liu <dav...@databricks.com> Authored: Wed Dec 16 15:48:11 2015 -0800 Committer: Davies Liu <davies....@gmail.com> Committed: Wed Dec 16 15:48:11 2015 -0800 ---------------------------------------------------------------------- python/pyspark/mllib/common.py | 6 +++--- python/pyspark/mllib/evaluation.py | 10 +++++----- python/pyspark/mllib/feature.py | 4 +--- 3 files changed, 9 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/27b98e99/python/pyspark/mllib/common.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a439a48..9fda1b1 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -102,7 +102,7 @@ def _java2py(sc, r, encoding="bytes"): return RDD(jrdd, sc) if clsName == 'DataFrame': - return DataFrame(r, SQLContext(sc)) + return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) @@ -125,7 +125,7 @@ def callJavaFunc(sc, func, *args): def callMLlibFunc(name, *args): """ Call API in PythonMLLibAPI """ - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) @@ -135,7 +135,7 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ def __init__(self, java_model): - self._sc = SparkContext._active_spark_context + self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self): http://git-wip-us.apache.org/repos/asf/spark/blob/27b98e99/python/pyspark/mllib/evaluation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 8c87ee9..22e68ea 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -44,7 +44,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): def __init__(self, scoreAndLabels): sc = scoreAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ StructField("score", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -103,7 +103,7 @@ class RegressionMetrics(JavaModelWrapper): def __init__(self, predictionAndObservations): sc = predictionAndObservations.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("observation", DoubleType(), nullable=False)])) @@ -197,7 +197,7 @@ class MulticlassMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -338,7 +338,7 @@ class RankingMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_model = callMLlibFunc("newRankingMetrics", df._jdf) @@ -424,7 +424,7 @@ class MultilabelMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics http://git-wip-us.apache.org/repos/asf/spark/blob/27b98e99/python/pyspark/mllib/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7254679..acd7ec5 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ if sys.version >= '3': from py4j.protocol import Py4JJavaError -from pyspark import SparkContext, since +from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -100,8 +100,6 @@ class Normalizer(VectorTransformer): :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ - sc = SparkContext._active_spark_context - assert sc is not None, "SparkContext should be initialized first" if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org