Repository: spark
Updated Branches:
  refs/heads/branch-2.0 220b9a08e -> f3162b96d


[SPARK-15464][ML][MLLIB][SQL][TESTS] Replace SQLContext and SparkContext with 
SparkSession using builder pattern in python test code

## What changes were proposed in this pull request?

Replace SQLContext and SparkContext with a SparkSession, constructed via the 
builder pattern, in the Python test code.

## How was this patch tested?

Existing tests.

Author: WeichenXu <weichenxu...@outlook.com>

Closes #13242 from WeichenXu123/python_doctest_update_sparksession.

(cherry picked from commit a15ca5533db91fefaf3248255a59c4d94eeda1a9)
Signed-off-by: Andrew Or <and...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f3162b96
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f3162b96
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f3162b96

Branch: refs/heads/branch-2.0
Commit: f3162b96da4f61524c11150904916734c0e949ab
Parents: 220b9a0
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Mon May 23 18:14:48 2016 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Mon May 23 18:14:58 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py        |  38 +++---
 python/pyspark/ml/clustering.py            |  22 ++--
 python/pyspark/ml/evaluation.py            |  20 ++--
 python/pyspark/ml/feature.py               |  66 +++++-----
 python/pyspark/ml/recommendation.py        |  18 +--
 python/pyspark/ml/regression.py            |  46 +++----
 python/pyspark/ml/tuning.py                |  18 +--
 python/pyspark/mllib/classification.py     |  10 +-
 python/pyspark/mllib/evaluation.py         |  10 +-
 python/pyspark/mllib/feature.py            |  10 +-
 python/pyspark/mllib/fpm.py                |   9 +-
 python/pyspark/mllib/linalg/distributed.py |  12 +-
 python/pyspark/mllib/random.py             |  10 +-
 python/pyspark/mllib/regression.py         |  10 +-
 python/pyspark/mllib/stat/_statistics.py   |  10 +-
 python/pyspark/mllib/tree.py               |   9 +-
 python/pyspark/mllib/util.py               |  10 +-
 python/pyspark/sql/catalog.py              |  14 ++-
 python/pyspark/sql/column.py               |  12 +-
 python/pyspark/sql/conf.py                 |  11 +-
 python/pyspark/sql/functions.py            | 153 ++++++++++++------------
 python/pyspark/sql/group.py                |  12 +-
 22 files changed, 298 insertions(+), 232 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py 
b/python/pyspark/ml/classification.py
index a1c3f72..ea660d7 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -498,7 +498,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
 
     >>> from pyspark.ml.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -512,7 +512,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     1
     >>> model.featureImportances
     SparseVector(1, {0: 1.0})
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
     0.0
@@ -520,7 +520,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     DenseVector([1.0, 0.0])
     >>> result.rawPrediction
     DenseVector([1.0, 0.0])
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
 
@@ -627,7 +627,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     >>> from numpy import allclose
     >>> from pyspark.ml.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -639,7 +639,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
     0.0
@@ -647,7 +647,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     0
     >>> numpy.argmax(result.rawPrediction)
     0
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
     >>> rfc_path = temp_path + "/rfc"
@@ -754,7 +754,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol
     >>> from numpy import allclose
     >>> from pyspark.ml.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -766,10 +766,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol
     SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
     >>> gbtc_path = temp_path + "gbtc"
@@ -885,7 +885,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol, H
 
     >>> from pyspark.sql import Row
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
     ...     Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
     ...     Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
@@ -1029,7 +1029,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, 
HasFeaturesCol, HasLabelCol,
     Number of outputs has to be equal to the total number of labels.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (0.0, Vectors.dense([0.0, 0.0])),
     ...     (1.0, Vectors.dense([0.0, 1.0])),
     ...     (1.0, Vectors.dense([1.0, 0.0])),
@@ -1040,7 +1040,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, 
HasFeaturesCol, HasLabelCol,
     [2, 5, 2]
     >>> model.weights.size
     27
-    >>> testDF = sqlContext.createDataFrame([
+    >>> testDF = spark.createDataFrame([
     ...     (Vectors.dense([1.0, 0.0]),),
     ...     (Vectors.dense([0.0, 0.0]),)], ["features"])
     >>> model.transform(testDF).show()
@@ -1467,21 +1467,23 @@ class OneVsRestModel(Model, OneVsRestParams, 
MLReadable, MLWritable):
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.classification
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = pyspark.ml.classification.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.classification tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.classification tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     import tempfile
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        sc.stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index ac7183d..a457904 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -73,7 +73,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, 
HasPredictionCol, HasMaxIte
     ...         (Vectors.dense([0.75, 0.935]),),
     ...         (Vectors.dense([-0.83, -0.68]),),
     ...         (Vectors.dense([-0.91, -0.76]),)]
-    >>> df = sqlContext.createDataFrame(data, ["features"])
+    >>> df = spark.createDataFrame(data, ["features"])
     >>> gm = GaussianMixture(k=3, tol=0.0001,
     ...                      maxIter=10, seed=10)
     >>> model = gm.fit(df)
@@ -197,7 +197,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, 
HasPredictionCol, HasMaxIter, HasTol
     >>> from pyspark.ml.linalg import Vectors
     >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
     ...         (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
-    >>> df = sqlContext.createDataFrame(data, ["features"])
+    >>> df = spark.createDataFrame(data, ["features"])
     >>> kmeans = KMeans(k=2, seed=1)
     >>> model = kmeans.fit(df)
     >>> centers = model.clusterCenters()
@@ -350,7 +350,7 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, 
HasPredictionCol, HasMaxIte
     >>> from pyspark.ml.linalg import Vectors
     >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
     ...         (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
-    >>> df = sqlContext.createDataFrame(data, ["features"])
+    >>> df = spark.createDataFrame(data, ["features"])
     >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0)
     >>> model = bkm.fit(df)
     >>> centers = model.clusterCenters()
@@ -627,7 +627,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, 
HasSeed, HasCheckpointInter
 
     >>> from pyspark.ml.linalg import Vectors, SparseVector
     >>> from pyspark.ml.clustering import LDA
-    >>> df = sqlContext.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
+    >>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
     ...      [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
     >>> lda = LDA(k=2, seed=1, optimizer="em")
     >>> model = lda.fit(df)
@@ -933,21 +933,23 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, 
HasSeed, HasCheckpointInter
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.clustering
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = pyspark.ml.clustering.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.clustering tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.clustering tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     import tempfile
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        sc.stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 16029dc..b8b2b37 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -114,7 +114,7 @@ class BinaryClassificationEvaluator(JavaEvaluator, 
HasLabelCol, HasRawPrediction
     >>> from pyspark.ml.linalg import Vectors
     >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), 
x[1]),
     ...    [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 
1.0), (0.8, 1.0)])
-    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
+    >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
     ...
     >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
     >>> evaluator.evaluate(dataset)
@@ -181,7 +181,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, 
HasPredictionCol):
 
     >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
     ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
-    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
+    >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
     ...
     >>> evaluator = RegressionEvaluator(predictionCol="raw")
     >>> evaluator.evaluate(dataset)
@@ -253,7 +253,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, 
HasLabelCol, HasPredictio
 
     >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 
0.0)]
-    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", 
"label"])
+    >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", 
"label"])
     ...
     >>> evaluator = 
MulticlassClassificationEvaluator(predictionCol="prediction")
     >>> evaluator.evaluate(dataset)
@@ -313,17 +313,19 @@ class MulticlassClassificationEvaluator(JavaEvaluator, 
HasLabelCol, HasPredictio
 
 if __name__ == "__main__":
     import doctest
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.evaluation tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.evaluation tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     (failure_count, test_count) = doctest.testmod(
         globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    spark.stop()
     if failure_count:
         exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 497f2ad..93745c7 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -66,7 +66,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, 
JavaMLReadable, Java
 
     Binarize a column of continuous features given a threshold.
 
-    >>> df = sqlContext.createDataFrame([(0.5,)], ["values"])
+    >>> df = spark.createDataFrame([(0.5,)], ["values"])
     >>> binarizer = Binarizer(threshold=1.0, inputCol="values", 
outputCol="features")
     >>> binarizer.transform(df).head().features
     0.0
@@ -131,7 +131,7 @@ class Bucketizer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable, Jav
 
     Maps a column of continuous features to a column of feature buckets.
 
-    >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], 
["values"])
+    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], 
["values"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
     ...     inputCol="values", outputCol="buckets")
     >>> bucketed = bucketizer.transform(df).collect()
@@ -206,7 +206,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadable,
 
     Extracts a vocabulary from document collections and generates a 
:py:attr:`CountVectorizerModel`.
 
-    >>> df = sqlContext.createDataFrame(
+    >>> df = spark.createDataFrame(
     ...    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
     ...    ["label", "raw"])
     >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors")
@@ -381,7 +381,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, 
JavaMLReadable, JavaMLWrit
     <https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II 
Wikipedia>`_.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], 
["vec"])
+    >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], 
["vec"])
     >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec")
     >>> df2 = dct.transform(df1)
     >>> df2.head().resultVec
@@ -448,7 +448,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReada
     by a scalar multiplier.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], 
["values"])
+    >>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], 
["values"])
     >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]),
     ...     inputCol="values", outputCol="eprod")
     >>> ep.transform(df).head().eprod
@@ -516,7 +516,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, 
HasNumFeatures, Java
     it is advisable to use a power of two as the numFeatures parameter;
     otherwise the features will not be mapped evenly to the columns.
 
-    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["words"])
+    >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
     >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", 
outputCol="features")
     >>> hashingTF.transform(df).head().features
     SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
@@ -583,7 +583,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, 
JavaMLReadable, JavaMLWritab
     Compute the Inverse Document Frequency (IDF) given a collection of 
documents.
 
     >>> from pyspark.ml.linalg import DenseVector
-    >>> df = sqlContext.createDataFrame([(DenseVector([1.0, 2.0]),),
+    >>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),),
     ...     (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
     >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf")
     >>> model = idf.fit(df)
@@ -671,7 +671,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadable, Jav
     any sparsity.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([1.0]),), 
(Vectors.dense([2.0]),)], ["a"])
+    >>> df = spark.createDataFrame([(Vectors.dense([1.0]),), 
(Vectors.dense([2.0]),)], ["a"])
     >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled")
     >>> model = maScaler.fit(df)
     >>> model.transform(df).show()
@@ -758,7 +758,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadable, Jav
     transformer will be DenseVector even for sparse input.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), 
(Vectors.dense([2.0]),)], ["a"])
+    >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), 
(Vectors.dense([2.0]),)], ["a"])
     >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled")
     >>> model = mmScaler.fit(df)
     >>> model.originalMin
@@ -889,7 +889,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, 
JavaMLReadable, JavaMLWr
     When the input array length is less than n (number of elements per 
n-gram), no n-grams are
     returned.
 
-    >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", 
"e"])])
+    >>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", 
"e"])])
     >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
     >>> ngram.transform(df).head()
     Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', 
u'c d', u'd e'])
@@ -963,7 +963,7 @@ class Normalizer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable, Jav
 
     >>> from pyspark.ml.linalg import Vectors
     >>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0})
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], 
["dense", "sparse"])
+    >>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], 
["dense", "sparse"])
     >>> normalizer = Normalizer(p=2.0, inputCol="dense", outputCol="features")
     >>> normalizer.transform(df).head().features
     DenseVector([0.6, -0.8])
@@ -1115,7 +1115,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLRead
     `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, 
y, x * y, y * y)`.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([0.5, 2.0]),)], 
["dense"])
+    >>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
     >>> px = PolynomialExpansion(degree=2, inputCol="dense", 
outputCol="expanded")
     >>> px.transform(df).head().expanded
     DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
@@ -1182,7 +1182,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, 
HasOutputCol, HasSeed, Jav
     covering all real values. This attempts to find numBuckets partitions 
based on a sample of data,
     but it may find fewer depending on the data sample values.
 
-    >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], 
["values"])
+    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], 
["values"])
     >>> qds = QuantileDiscretizer(numBuckets=2,
     ...     inputCol="values", outputCol="buckets", seed=123)
     >>> qds.getSeed()
@@ -1272,7 +1272,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable,
     length.
     It returns an array of strings that can be empty.
 
-    >>> df = sqlContext.createDataFrame([("A B  c",)], ["text"])
+    >>> df = spark.createDataFrame([("A B  c",)], ["text"])
     >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words")
     >>> reTokenizer.transform(df).head()
     Row(text=u'A B  c', words=[u'a', u'b', u'c'])
@@ -1400,7 +1400,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, 
JavaMLWritable):
     Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
     where '__THIS__' represents the underlying table of the input dataset.
 
-    >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", 
"v1", "v2"])
+    >>> df = spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", 
"v1", "v2"])
     >>> sqlTrans = SQLTransformer(
     ...     statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM 
__THIS__")
     >>> sqlTrans.transform(df).head()
@@ -1461,7 +1461,7 @@ class StandardScaler(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadable, J
     statistics on the samples in the training set.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), 
(Vectors.dense([2.0]),)], ["a"])
+    >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), 
(Vectors.dense([2.0]),)], ["a"])
     >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
     >>> model = standardScaler.fit(df)
     >>> model.mean
@@ -1718,7 +1718,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadabl
     A feature transformer that filters out stop words from input.
     Note: null values from input array are preserved unless adding null to 
stopWords explicitly.
 
-    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["text"])
+    >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])
     >>> remover = StopWordsRemover(inputCol="text", outputCol="words", 
stopWords=["b"])
     >>> remover.transform(df).head().words == ['a', 'c']
     True
@@ -1810,7 +1810,7 @@ class Tokenizer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable, Java
     A tokenizer that converts the input string to lowercase and then
     splits it by white spaces.
 
-    >>> df = sqlContext.createDataFrame([("a b c",)], ["text"])
+    >>> df = spark.createDataFrame([("a b c",)], ["text"])
     >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
     >>> tokenizer.transform(df).head()
     Row(text=u'a b c', words=[u'a', u'b', u'c'])
@@ -1864,7 +1864,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, 
HasOutputCol, JavaMLReadabl
 
     A feature transformer that merges multiple columns into a vector column.
 
-    >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
+    >>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
     >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], 
outputCol="features")
     >>> vecAssembler.transform(df).head().features
     DenseVector([1.0, 0.0, 3.0])
@@ -1944,7 +1944,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadable, Ja
       - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
+    >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
     ...     (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"])
     >>> indexer = VectorIndexer(maxCategories=2, inputCol="a", 
outputCol="indexed")
     >>> model = indexer.fit(df)
@@ -2074,7 +2074,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable, J
     followed by the selected names (in the order given).
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
     ...     (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
     ...     (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
@@ -2161,7 +2161,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, 
HasSeed, HasInputCol, Has
     natural language processing or machine learning process.
 
     >>> sent = ("a b " * 100 + "a c " * 10).split(" ")
-    >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
+    >>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
     >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", 
outputCol="model")
     >>> model = word2Vec.fit(doc)
     >>> model.getVectors().show()
@@ -2345,7 +2345,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, 
JavaMLReadable, JavaMLWritab
     >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
     ...     (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
     ...     (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
-    >>> df = sqlContext.createDataFrame(data,["features"])
+    >>> df = spark.createDataFrame(data,["features"])
     >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
     >>> model = pca.fit(df)
     >>> model.transform(df).collect()[0].pca_features
@@ -2447,7 +2447,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, 
HasLabelCol, JavaMLReadable, JavaM
     operators, including '~', '.', ':', '+', and '-'. Also see the `R formula 
docs
     <http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html>`_.
 
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, 1.0, "a"),
     ...     (0.0, 2.0, "b"),
     ...     (0.0, 0.0, "a")
@@ -2561,7 +2561,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, 
HasOutputCol, HasLabelCol, Ja
     categorical label.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame(
+    >>> df = spark.createDataFrame(
     ...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
     ...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
     ...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
@@ -2656,8 +2656,7 @@ if __name__ == "__main__":
     import tempfile
 
     import pyspark.ml.feature
-    from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext
+    from pyspark.sql import Row, SparkSession
 
     globs = globals().copy()
     features = pyspark.ml.feature.__dict__.copy()
@@ -2665,19 +2664,22 @@ if __name__ == "__main__":
 
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.feature tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.feature tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"),
                                Row(id=2, label="c"), Row(id=3, label="a"),
                                Row(id=4, label="a"), Row(id=5, label="c")], 2)
-    globs['stringIndDf'] = sqlContext.createDataFrame(testData)
+    globs['stringIndDf'] = spark.createDataFrame(testData)
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        sc.stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py 
b/python/pyspark/ml/recommendation.py
index 86c00d9..bac2a30 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -65,7 +65,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, 
HasPredictionCol, Ha
     indicated user preferences rather than explicit ratings given to
     items.
 
-    >>> df = sqlContext.createDataFrame(
+    >>> df = spark.createDataFrame(
     ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), 
(2, 2, 5.0)],
     ...     ["user", "item", "rating"])
     >>> als = ALS(rank=10, maxIter=5)
@@ -74,7 +74,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, 
HasPredictionCol, Ha
     10
     >>> model.userFactors.orderBy("id").collect()
     [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
-    >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", 
"item"])
+    >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", 
"item"])
     >>> predictions = sorted(model.transform(test).collect(), key=lambda r: 
r[0])
     >>> predictions[0]
     Row(user=0, item=2, prediction=-0.13807615637779236)
@@ -370,21 +370,23 @@ class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.recommendation
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = pyspark.ml.recommendation.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.recommendation tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.recommendation tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     import tempfile
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        sc.stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index e21dd83..8f58594 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -55,19 +55,19 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPrediction
      - L2 + L1 (elastic net)
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, 2.0, Vectors.dense(1.0)),
     ...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", 
"features"])
     >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", 
weightCol="weight")
     >>> model = lr.fit(df)
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
     True
     >>> abs(model.coefficients[0] - 1.0) < 0.001
     True
     >>> abs(model.intercept - 0.0) < 0.001
     True
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001
     True
     >>> lr.setParams("vector")
@@ -413,12 +413,12 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredicti
     Only univariate (single feature) algorithm supported.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> ir = IsotonicRegression()
     >>> model = ir.fit(df)
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
     >>> model.boundaries
@@ -643,7 +643,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredi
     It supports both continuous and categorical features.
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
@@ -654,10 +654,10 @@ class DecisionTreeRegressor(JavaEstimator, 
HasFeaturesCol, HasLabelCol, HasPredi
     3
     >>> model.featureImportances
     SparseVector(1, {0: 1.0})
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
     >>> dtr_path = temp_path + "/dtr"
@@ -809,7 +809,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredi
 
     >>> from numpy import allclose
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
@@ -818,10 +818,10 @@ class RandomForestRegressor(JavaEstimator, 
HasFeaturesCol, HasLabelCol, HasPredi
     SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 1.0])
     True
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     0.5
     >>> rfr_path = temp_path + "/rfr"
@@ -921,7 +921,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol,
 
     >>> from numpy import allclose
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
@@ -932,10 +932,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol,
     SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
-    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], 
["features"])
+    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
-    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
+    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
     >>> gbtr_path = temp_path + "gbtr"
@@ -1056,7 +1056,7 @@ class AFTSurvivalRegression(JavaEstimator, 
HasFeaturesCol, HasLabelCol, HasPredi
     .. seealso:: `AFT Model 
<https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(1.0), 1.0),
     ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", 
"censor"])
     >>> aftsr = AFTSurvivalRegression()
@@ -1257,7 +1257,7 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
     .. seealso:: `GLM 
<https://en.wikipedia.org/wiki/Generalized_linear_model>`_
 
     >>> from pyspark.ml.linalg import Vectors
-    >>> df = sqlContext.createDataFrame([
+    >>> df = spark.createDataFrame([
     ...     (1.0, Vectors.dense(0.0, 0.0)),
     ...     (1.0, Vectors.dense(1.0, 2.0)),
     ...     (2.0, Vectors.dense(0.0, 0.0)),
@@ -1603,21 +1603,23 @@ class 
GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSumm
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.regression
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = pyspark.ml.regression.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.regression tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.regression tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     import tempfile
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        sc.stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 4f7a6b0..fe87b6c 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -152,7 +152,7 @@ class CrossValidator(Estimator, ValidatorParams):
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
-    >>> dataset = sqlContext.createDataFrame(
+    >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
     ...      (Vectors.dense([0.5]), 0.0),
@@ -311,7 +311,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
-    >>> dataset = sqlContext.createDataFrame(
+    >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
     ...      (Vectors.dense([0.5]), 0.0),
@@ -456,17 +456,19 @@ class TrainValidationSplitModel(Model, ValidatorParams):
 if __name__ == "__main__":
     import doctest
 
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
 
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    sc = SparkContext("local[2]", "ml.tuning tests")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("ml.tuning tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = sqlContext
+    globs['spark'] = spark
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    spark.stop()
     if failure_count:
         exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py 
b/python/pyspark/mllib/classification.py
index fe5b684..f186217 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -756,12 +756,16 @@ class 
StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
     import pyspark.mllib.classification
     globs = pyspark.mllib.classification.__dict__.copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.classification tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/evaluation.py 
b/python/pyspark/mllib/evaluation.py
index 22e68ea..5f32f09 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -516,12 +516,16 @@ class MultilabelMetrics(JavaModelWrapper):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
     import pyspark.mllib.evaluation
     globs = pyspark.mllib.evaluation.__dict__.copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest')
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.evaluation tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 90559f6..e31c75c 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -732,11 +732,15 @@ class ElementwiseProduct(VectorTransformer):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.feature tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/fpm.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index f339e50..ab4066f 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -183,16 +183,21 @@ class PrefixSpan(object):
 
 def _test():
     import doctest
+    from pyspark.sql import SparkSession
     import pyspark.mllib.fpm
     globs = pyspark.mllib.fpm.__dict__.copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest')
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.fpm tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     import tempfile
 
     temp_path = tempfile.mkdtemp()
     globs['temp_path'] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-        globs['sc'].stop()
+        spark.stop()
     finally:
         from shutil import rmtree
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/linalg/distributed.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg/distributed.py 
b/python/pyspark/mllib/linalg/distributed.py
index af34ce3..ea4f27c 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -1184,16 +1184,18 @@ class BlockMatrix(DistributedMatrix):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     from pyspark.mllib.linalg import Matrices
     import pyspark.mllib.linalg.distributed
     globs = pyspark.mllib.linalg.distributed.__dict__.copy()
-    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
-    globs['sqlContext'] = SQLContext(globs['sc'])
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("mllib.linalg.distributed tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     globs['Matrices'] = Matrices
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/random.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 6a3c643..61213dd 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -409,13 +409,17 @@ class RandomRDDs(object):
 
 def _test():
     import doctest
-    from pyspark.context import SparkContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("mllib.random tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/regression.py 
b/python/pyspark/mllib/regression.py
index 639c5ea..43d9072 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -824,12 +824,16 @@ class 
StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
     import pyspark.mllib.regression
     globs = pyspark.mllib.regression.__dict__.copy()
-    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("mllib.regression tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/stat/_statistics.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/stat/_statistics.py 
b/python/pyspark/mllib/stat/_statistics.py
index 36c8f48..b0a8524 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -306,11 +306,15 @@ class Statistics(object):
 
 def _test():
     import doctest
-    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.stat.statistics tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index f7ea466..8be76fc 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -657,9 +657,14 @@ class GradientBoostedTrees(object):
 def _test():
     import doctest
     globs = globals().copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    from pyspark.sql import SparkSession
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("mllib.tree tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/mllib/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 39bc658..a316ee1 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -347,13 +347,17 @@ class LinearDataGenerator(object):
 
 def _test():
     import doctest
-    from pyspark.context import SparkContext
+    from pyspark.sql import SparkSession
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
+    spark = SparkSession.builder\
+        .master("local[2]")\
+        .appName("mllib.util tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
     (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 812dbba..3033f14 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -244,21 +244,23 @@ class Catalog(object):
 def _test():
     import os
     import doctest
-    from pyspark.context import SparkContext
-    from pyspark.sql.session import SparkSession
+    from pyspark.sql import SparkSession
     import pyspark.sql.catalog
 
     os.chdir(os.environ["SPARK_HOME"])
 
     globs = pyspark.sql.catalog.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
-    globs['sc'] = sc
-    globs['spark'] = SparkSession(sc)
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.catalog tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
+    globs['spark'] = spark
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.catalog,
         globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/sql/column.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 5b26e94..4b99f30 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -434,13 +434,15 @@ class Column(object):
 
 def _test():
     import doctest
-    from pyspark.context import SparkContext
-    from pyspark.sql import SQLContext
+    from pyspark.sql import SparkSession
     import pyspark.sql.column
     globs = pyspark.sql.column.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.column tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
     globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))
@@ -448,7 +450,7 @@ def _test():
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.column, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | 
doctest.REPORT_NDIFF)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/sql/conf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 609d882..792c420 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -71,11 +71,14 @@ def _test():
     os.chdir(os.environ["SPARK_HOME"])
 
     globs = pyspark.sql.conf.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
-    globs['sc'] = sc
-    globs['spark'] = SparkSession(sc)
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.conf tests")\
+        .getOrCreate()
+    globs['sc'] = spark.sparkContext
+    globs['spark'] = spark
     (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, 
globs=globs)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 716b16f..1f15eec 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -212,7 +212,7 @@ def broadcast(df):
 def coalesce(*cols):
     """Returns the first column that is not null.
 
-    >>> cDf = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], 
("a", "b"))
+    >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], 
("a", "b"))
     >>> cDf.show()
     +----+----+
     |   a|   b|
@@ -252,7 +252,7 @@ def corr(col1, col2):
 
     >>> a = range(20)
     >>> b = [2 * x for x in range(20)]
-    >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"])
+    >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
     >>> df.agg(corr("a", "b").alias('c')).collect()
     [Row(c=1.0)]
     """
@@ -267,7 +267,7 @@ def covar_pop(col1, col2):
 
     >>> a = [1] * 10
     >>> b = [1] * 10
-    >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"])
+    >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
     >>> df.agg(covar_pop("a", "b").alias('c')).collect()
     [Row(c=0.0)]
     """
@@ -282,7 +282,7 @@ def covar_samp(col1, col2):
 
     >>> a = [1] * 10
     >>> b = [1] * 10
-    >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"])
+    >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
     >>> df.agg(covar_samp("a", "b").alias('c')).collect()
     [Row(c=0.0)]
     """
@@ -373,7 +373,7 @@ def input_file_name():
 def isnan(col):
     """An expression that returns true iff the column is NaN.
 
-    >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 
2.0)], ("a", "b"))
+    >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], 
("a", "b"))
     >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect()
     [Row(r1=False, r2=False), Row(r1=True, r2=True)]
     """
@@ -385,7 +385,7 @@ def isnan(col):
 def isnull(col):
     """An expression that returns true iff the column is null.
 
-    >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b"))
+    >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b"))
     >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect()
     [Row(r1=False, r2=False), Row(r1=True, r2=True)]
     """
@@ -432,7 +432,7 @@ def nanvl(col1, col2):
 
     Both inputs should be floating point columns (DoubleType or FloatType).
 
-    >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 
2.0)], ("a", "b"))
+    >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], 
("a", "b"))
     >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, 
df.b).alias("r2")).collect()
     [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)]
     """
@@ -470,7 +470,7 @@ def round(col, scale=0):
     Round the given value to `scale` decimal places using HALF_UP rounding 
mode if `scale` >= 0
     or at integral part when `scale` < 0.
 
-    >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 
0).alias('r')).collect()
+    >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 
0).alias('r')).collect()
     [Row(r=3.0)]
     """
     sc = SparkContext._active_spark_context
@@ -483,7 +483,7 @@ def bround(col, scale=0):
     Round the given value to `scale` decimal places using HALF_EVEN rounding 
mode if `scale` >= 0
     or at integral part when `scale` < 0.
 
-    >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 
0).alias('r')).collect()
+    >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 
0).alias('r')).collect()
     [Row(r=2.0)]
     """
     sc = SparkContext._active_spark_context
@@ -494,7 +494,7 @@ def bround(col, scale=0):
 def shiftLeft(col, numBits):
     """Shift the given value numBits left.
 
-    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 
1).alias('r')).collect()
+    >>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 
1).alias('r')).collect()
     [Row(r=42)]
     """
     sc = SparkContext._active_spark_context
@@ -505,7 +505,7 @@ def shiftLeft(col, numBits):
 def shiftRight(col, numBits):
     """Shift the given value numBits right.
 
-    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 
1).alias('r')).collect()
+    >>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 
1).alias('r')).collect()
     [Row(r=21)]
     """
     sc = SparkContext._active_spark_context
@@ -517,7 +517,7 @@ def shiftRight(col, numBits):
 def shiftRightUnsigned(col, numBits):
     """Unsigned shift the given value numBits right.
 
-    >>> df = sqlContext.createDataFrame([(-42,)], ['a'])
+    >>> df = spark.createDataFrame([(-42,)], ['a'])
     >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
     [Row(r=9223372036854775787)]
     """
@@ -575,7 +575,7 @@ def greatest(*cols):
     Returns the greatest value of the list of column names, skipping null 
values.
     This function takes at least 2 parameters. It will return null iff all 
parameters are null.
 
-    >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
+    >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
     >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
     [Row(greatest=4)]
     """
@@ -591,7 +591,7 @@ def least(*cols):
     Returns the least value of the list of column names, skipping null values.
     This function takes at least 2 parameters. It will return null iff all 
parameters are null.
 
-    >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
+    >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
     >>> df.select(least(df.a, df.b, df.c).alias("least")).collect()
     [Row(least=1)]
     """
@@ -647,7 +647,7 @@ def log(arg1, arg2=None):
 def log2(col):
     """Returns the base-2 logarithm of the argument.
 
-    >>> sqlContext.createDataFrame([(4,)], 
['a']).select(log2('a').alias('log2')).collect()
+    >>> spark.createDataFrame([(4,)], 
['a']).select(log2('a').alias('log2')).collect()
     [Row(log2=2.0)]
     """
     sc = SparkContext._active_spark_context
@@ -660,7 +660,7 @@ def conv(col, fromBase, toBase):
     """
     Convert a number in a string column from one base to another.
 
-    >>> df = sqlContext.createDataFrame([("010101",)], ['n'])
+    >>> df = spark.createDataFrame([("010101",)], ['n'])
     >>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
     [Row(hex=u'15')]
     """
@@ -673,7 +673,7 @@ def factorial(col):
     """
     Computes the factorial of the given value.
 
-    >>> df = sqlContext.createDataFrame([(5,)], ['n'])
+    >>> df = spark.createDataFrame([(5,)], ['n'])
     >>> df.select(factorial(df.n).alias('f')).collect()
     [Row(f=120)]
     """
@@ -765,7 +765,7 @@ def date_format(date, format):
     NOTE: Use when ever possible specialized functions like `year`. These benefit from a
     specialized implementation.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect()
     [Row(date=u'04/08/2015')]
     """
@@ -778,7 +778,7 @@ def year(col):
     """
     Extract the year of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(year('a').alias('year')).collect()
     [Row(year=2015)]
     """
@@ -791,7 +791,7 @@ def quarter(col):
     """
     Extract the quarter of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(quarter('a').alias('quarter')).collect()
     [Row(quarter=2)]
     """
@@ -804,7 +804,7 @@ def month(col):
     """
     Extract the month of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(month('a').alias('month')).collect()
     [Row(month=4)]
    """
@@ -817,7 +817,7 @@ def dayofmonth(col):
     """
     Extract the day of the month of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(dayofmonth('a').alias('day')).collect()
     [Row(day=8)]
     """
@@ -830,7 +830,7 @@ def dayofyear(col):
     """
     Extract the day of the year of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(dayofyear('a').alias('day')).collect()
     [Row(day=98)]
     """
@@ -843,7 +843,7 @@ def hour(col):
     """
     Extract the hours of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
     >>> df.select(hour('a').alias('hour')).collect()
     [Row(hour=13)]
     """
@@ -856,7 +856,7 @@ def minute(col):
     """
     Extract the minutes of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
     >>> df.select(minute('a').alias('minute')).collect()
     [Row(minute=8)]
     """
@@ -869,7 +869,7 @@ def second(col):
     """
     Extract the seconds of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
     >>> df.select(second('a').alias('second')).collect()
     [Row(second=15)]
     """
@@ -882,7 +882,7 @@ def weekofyear(col):
     """
     Extract the week number of a given date as integer.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
     >>> df.select(weekofyear(df.a).alias('week')).collect()
     [Row(week=15)]
     """
@@ -895,7 +895,7 @@ def date_add(start, days):
     """
     Returns the date that is `days` days after `start`
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
     >>> df.select(date_add(df.d, 1).alias('d')).collect()
     [Row(d=datetime.date(2015, 4, 9))]
     """
@@ -908,7 +908,7 @@ def date_sub(start, days):
     """
     Returns the date that is `days` days before `start`
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
     >>> df.select(date_sub(df.d, 1).alias('d')).collect()
     [Row(d=datetime.date(2015, 4, 7))]
     """
@@ -921,7 +921,7 @@ def datediff(end, start):
     """
     Returns the number of days from `start` to `end`.
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
+    >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
     >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
     [Row(diff=32)]
     """
@@ -934,7 +934,7 @@ def add_months(start, months):
     """
     Returns the date that is `months` months after `start`
 
-    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
     >>> df.select(add_months(df.d, 1).alias('d')).collect()
     [Row(d=datetime.date(2015, 5, 8))]
     """
@@ -947,7 +947,7 @@ def months_between(date1, date2):
     """
     Returns the number of months between date1 and date2.
 
-    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
     >>> df.select(months_between(df.t, df.d).alias('months')).collect()
     [Row(months=3.9495967...)]
     """
@@ -960,7 +960,7 @@ def to_date(col):
     """
     Converts the column of StringType or TimestampType into DateType.
 
-    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
     >>> df.select(to_date(df.t).alias('date')).collect()
     [Row(date=datetime.date(1997, 2, 28))]
     """
@@ -975,7 +975,7 @@ def trunc(date, format):
 
     :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
 
-    >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
+    >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
     >>> df.select(trunc(df.d, 'year').alias('year')).collect()
     [Row(year=datetime.date(1997, 1, 1))]
     >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
@@ -993,7 +993,7 @@ def next_day(date, dayOfWeek):
     Day of the week parameter is case insensitive, and accepts:
         "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
 
-    >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d'])
+    >>> df = spark.createDataFrame([('2015-07-27',)], ['d'])
     >>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
     [Row(date=datetime.date(2015, 8, 2))]
     """
@@ -1006,7 +1006,7 @@ def last_day(date):
     """
     Returns the last day of the month which the given date belongs to.
 
-    >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d'])
+    >>> df = spark.createDataFrame([('1997-02-10',)], ['d'])
     >>> df.select(last_day(df.d).alias('date')).collect()
     [Row(date=datetime.date(1997, 2, 28))]
     """
@@ -1045,7 +1045,7 @@ def from_utc_timestamp(timestamp, tz):
     """
     Assumes given timestamp is UTC and converts to given timezone.
 
-    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
     >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect()
     [Row(t=datetime.datetime(1997, 2, 28, 2, 30))]
     """
@@ -1058,7 +1058,7 @@ def to_utc_timestamp(timestamp, tz):
     """
     Assumes given timestamp is in given timezone and converts to UTC.
 
-    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
     >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect()
     [Row(t=datetime.datetime(1997, 2, 28, 18, 30))]
     """
@@ -1087,7 +1087,7 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None):
     The output column will be a struct called 'window' by default with the nested columns 'start'
     and 'end', where 'start' and 'end' will be of `TimestampType`.
 
-    >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
+    >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
     >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))
     >>> w.select(w.window.start.cast("string").alias("start"),
     ...          w.window.end.cast("string").alias("end"), "sum").collect()
@@ -1124,7 +1124,7 @@ def crc32(col):
     Calculates the cyclic redundancy check value  (CRC32) of a binary column and
     returns the value as a bigint.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
+    >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
     [Row(crc32=2743272264)]
     """
     sc = SparkContext._active_spark_context
@@ -1136,7 +1136,7 @@ def crc32(col):
 def md5(col):
     """Calculates the MD5 digest and returns the value as a 32 character hex string.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
+    >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
     [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
     """
     sc = SparkContext._active_spark_context
@@ -1149,7 +1149,7 @@ def md5(col):
 def sha1(col):
     """Returns the hex string result of SHA-1.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+    >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
     [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
     """
     sc = SparkContext._active_spark_context
@@ -1179,7 +1179,7 @@ def sha2(col, numBits):
 def hash(*cols):
     """Calculates the hash code of given columns, and returns the result as a int column.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
+    >>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
     [Row(hash=-757602832)]
     """
     sc = SparkContext._active_spark_context
@@ -1215,7 +1215,7 @@ def concat(*cols):
     """
     Concatenates multiple input string columns together into a single string column.
 
-    >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
     >>> df.select(concat(df.s, df.d).alias('s')).collect()
     [Row(s=u'abcd123')]
     """
@@ -1230,7 +1230,7 @@ def concat_ws(sep, *cols):
     Concatenates multiple input string columns together into a single string column,
     using the given separator.
 
-    >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
     >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
     [Row(s=u'abcd-123')]
     """
@@ -1268,7 +1268,7 @@ def format_number(col, d):
     :param col: the column name of the numeric value to be formatted
     :param d: the N decimal places
 
-    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+    >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
     [Row(v=u'5.0000')]
     """
     sc = SparkContext._active_spark_context
@@ -1284,7 +1284,7 @@ def format_string(format, *cols):
     :param col: the column name of the numeric value to be formatted
     :param d: the N decimal places
 
-    >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b'])
+    >>> df = spark.createDataFrame([(5, "hello")], ['a', 'b'])
     >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
     [Row(v=u'5 hello')]
     """
@@ -1301,7 +1301,7 @@ def instr(str, substr):
     NOTE: The position is not zero based, but 1 based index, returns 0 if substr
     could not be found in str.
 
-    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(instr(df.s, 'b').alias('s')).collect()
     [Row(s=2)]
     """
@@ -1317,7 +1317,7 @@ def substring(str, pos, len):
     returns the slice of byte array that starts at `pos` in byte and is of length `len`
     when str is Binary type
 
-    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
     [Row(s=u'ab')]
     """
@@ -1334,7 +1334,7 @@ def substring_index(str, delim, count):
     returned. If count is negative, every to the right of the final delimiter (counting from the
     right) is returned. substring_index performs a case-sensitive match when searching for delim.
 
-    >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
+    >>> df = spark.createDataFrame([('a.b.c.d',)], ['s'])
     >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
     [Row(s=u'a.b')]
     >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
@@ -1349,7 +1349,7 @@ def substring_index(str, delim, count):
 def levenshtein(left, right):
     """Computes the Levenshtein distance of the two given strings.
 
-    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
+    >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
     >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
     [Row(d=3)]
     """
@@ -1370,7 +1370,7 @@ def locate(substr, str, pos=0):
     :param str: a Column of StringType
     :param pos: start position (zero based)
 
-    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(locate('b', df.s, 1).alias('s')).collect()
     [Row(s=2)]
     """
@@ -1384,7 +1384,7 @@ def lpad(col, len, pad):
     """
     Left-pad the string column to width `len` with `pad`.
 
-    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
     [Row(s=u'##abcd')]
     """
@@ -1398,7 +1398,7 @@ def rpad(col, len, pad):
     """
     Right-pad the string column to width `len` with `pad`.
 
-    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
     [Row(s=u'abcd##')]
     """
@@ -1412,7 +1412,7 @@ def repeat(col, n):
     """
     Repeats a string column n times, and returns it as a new string column.
 
-    >>> df = sqlContext.createDataFrame([('ab',)], ['s',])
+    >>> df = spark.createDataFrame([('ab',)], ['s',])
     >>> df.select(repeat(df.s, 3).alias('s')).collect()
     [Row(s=u'ababab')]
     """
@@ -1428,7 +1428,7 @@ def split(str, pattern):
 
     NOTE: pattern is a string represent the regular expression.
 
-    >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',])
+    >>> df = spark.createDataFrame([('ab12cd',)], ['s',])
     >>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
     [Row(s=[u'ab', u'cd'])]
     """
@@ -1441,7 +1441,7 @@ def split(str, pattern):
 def regexp_extract(str, pattern, idx):
     """Extract a specific(idx) group identified by a java regex, from the specified string column.
 
-    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+    >>> df = spark.createDataFrame([('100-200',)], ['str'])
     >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
     [Row(d=u'100')]
     """
@@ -1455,7 +1455,7 @@ def regexp_extract(str, pattern, idx):
 def regexp_replace(str, pattern, replacement):
     """Replace all substrings of the specified string value that match regexp with rep.
 
-    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+    >>> df = spark.createDataFrame([('100-200',)], ['str'])
     >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect()
     [Row(d=u'-----')]
     """
@@ -1469,7 +1469,7 @@ def regexp_replace(str, pattern, replacement):
 def initcap(col):
     """Translate the first letter of each word to upper case in the sentence.
 
-    >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
+    >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
     [Row(v=u'Ab Cd')]
     """
     sc = SparkContext._active_spark_context
@@ -1482,7 +1482,7 @@ def soundex(col):
     """
     Returns the SoundEx encoding for a string
 
-    >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
+    >>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
     >>> df.select(soundex(df.name).alias("soundex")).collect()
     [Row(soundex=u'P362'), Row(soundex=u'U612')]
     """
@@ -1509,7 +1509,7 @@ def hex(col):
     """Computes hex value of the given column, which could be StringType,
     BinaryType, IntegerType or LongType.
 
-    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
+    >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
     [Row(hex(a)=u'414243', hex(b)=u'3')]
     """
     sc = SparkContext._active_spark_context
@@ -1523,7 +1523,7 @@ def unhex(col):
     """Inverse of hex. Interprets each pair of characters as a hexadecimal number
     and converts to the byte representation of number.
 
-    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
+    >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
     [Row(unhex(a)=bytearray(b'ABC'))]
     """
     sc = SparkContext._active_spark_context
@@ -1535,7 +1535,7 @@ def unhex(col):
 def length(col):
     """Calculates the length of a string or binary expression.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
+    >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
     [Row(length=3)]
     """
     sc = SparkContext._active_spark_context
@@ -1550,7 +1550,7 @@ def translate(srcCol, matching, replace):
     The translate will happen when any character in the string matching with the character
     in the `matching`.
 
-    >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
+    >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
     .alias('r')).collect()
     [Row(r=u'1a2s3ae')]
     """
@@ -1608,7 +1608,7 @@ def array_contains(col, value):
     :param col: name of column containing array
     :param value: value to check for in array
 
-    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
     >>> df.select(array_contains(df.data, "a")).collect()
     [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
     """
@@ -1621,7 +1621,7 @@ def explode(col):
     """Returns a new row for each element in the given array or map.
 
     >>> from pyspark.sql import Row
-    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+    >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
     >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
     [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
 
@@ -1648,7 +1648,7 @@ def get_json_object(col, path):
     :param path: path to the json object to extract
 
     >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
-    >>> df = sqlContext.createDataFrame(data, ("key", "jstring"))
+    >>> df = spark.createDataFrame(data, ("key", "jstring"))
     >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \
                           get_json_object(df.jstring, '$.f2').alias("c1") ).collect()
     [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)]
@@ -1667,7 +1667,7 @@ def json_tuple(col, *fields):
     :param fields: list of fields to extract
 
     >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
-    >>> df = sqlContext.createDataFrame(data, ("key", "jstring"))
+    >>> df = spark.createDataFrame(data, ("key", "jstring"))
     >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
     [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)]
     """
@@ -1683,7 +1683,7 @@ def size(col):
 
     :param col: name of column or expression
 
-    >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+    >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
     >>> df.select(size(df.data)).collect()
     [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
     """
@@ -1698,7 +1698,7 @@ def sort_array(col, asc=True):
 
     :param col: name of column or expression
 
-    >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
+    >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
     >>> df.select(sort_array(df.data).alias('r')).collect()
     [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
     >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
@@ -1775,18 +1775,21 @@ __all__.sort()
 
 def _test():
     import doctest
-    from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext
+    from pyspark.sql import Row, SparkSession
     import pyspark.sql.functions
     globs = pyspark.sql.functions.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.functions tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
+    globs['spark'] = spark
     globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.functions, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3162b96/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index ee734cb..6987af6 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -195,13 +195,15 @@ class GroupedData(object):
 
 def _test():
     import doctest
-    from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext
+    from pyspark.sql import Row, SparkSession
     import pyspark.sql.group
     globs = pyspark.sql.group.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.group tests")\
+        .getOrCreate()
+    sc = spark.sparkContext
     globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
     globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))
@@ -216,7 +218,7 @@ def _test():
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.group, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
-    globs['sc'].stop()
+    spark.stop()
     if failure_count:
         exit(-1)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to