[SPARK-7899] [PYSPARK] Fix Python 3 pyspark/sql/types module conflict

This PR makes the `pyspark/sql/types` module work with pylint static analysis by removing the dynamic renaming of the `pyspark/sql/_types` module to `pyspark/sql/types`.
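For context, the workaround being removed is the import-time aliasing block in `python/pyspark/sql/__init__.py` (visible in the diff further below); it registered the real `_types` module under a second name, roughly as in this sketch. Because `pyspark.sql.types` only ever existed as a `sys.modules` entry created at runtime, static analyzers such as pylint could not resolve imports from it.

    # Removed Python 3 workaround (taken from the python/pyspark/sql/__init__.py hunk below).
    # The real implementation lived in pyspark/sql/_types.py and was re-registered
    # under the public name "pyspark.sql.types" when the package was imported.
    import sys
    from . import _types as types

    modname = __name__ + '.types'
    types.__name__ = modname
    # Update __module__ on exported objects so they stay picklable under the public name.
    for v in types.__dict__.values():
        if hasattr(v, "__module__") and v.__module__.endswith('._types'):
            v.__module__ = modname
    sys.modules[modname] = types
    del modname, sys

Moving the code back into a real `pyspark/sql/types.py` file makes the import target a module that exists on disk, which pylint can follow.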
Tests are now loaded using `$PYSPARK_DRIVER_PYTHON -m module` rather than `$PYSPARK_DRIVER_PYTHON module.py`. The old method adds the location of `module.py` to `sys.path`, so this change prevents accidental use of relative paths in Python. Author: Michael Nazario <mnaza...@palantir.com> Closes #6439 from mnazario/feature/SPARK-7899 and squashes the following commits: 366ef30 [Michael Nazario] Remove hack on random.py bb8b04d [Michael Nazario] Make doctests consistent with other tests 6ee4f75 [Michael Nazario] Change test scripts to use "-m" 673528f [Michael Nazario] Move _types back to types Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c5b1982 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c5b1982 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c5b1982 Branch: refs/heads/master Commit: 1c5b19827a091b5aba69a967600e7ca35ed3bcfd Parents: 5f48e5c Author: Michael Nazario <mnaza...@palantir.com> Authored: Fri May 29 14:13:44 2015 -0700 Committer: Davies Liu <dav...@databricks.com> Committed: Fri May 29 14:13:44 2015 -0700 ---------------------------------------------------------------------- bin/pyspark | 6 +- python/pyspark/accumulators.py | 4 + python/pyspark/mllib/__init__.py | 8 - python/pyspark/mllib/rand.py | 409 ----------- python/pyspark/mllib/random.py | 409 +++++++++++ python/pyspark/sql/__init__.py | 12 - python/pyspark/sql/_types.py | 1306 --------------------------------- python/pyspark/sql/types.py | 1306 +++++++++++++++++++++++++++++++++ python/run-tests | 76 +- 9 files changed, 1758 insertions(+), 1778 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/bin/pyspark ---------------------------------------------------------------------- diff --git a/bin/pyspark b/bin/pyspark index 8acad61..7cb19c5 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/accumulators.py ---------------------------------------------------------------------- diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a13..adca90d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,7 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + doctest.testmod() http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2..b11aed2 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -28,11 +28,3 @@ if numpy.version.version < '1.4': __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . 
import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/rand.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py deleted file mode 100644 index 06fbc0e..0000000 --- a/python/pyspark/mllib/rand.py +++ /dev/null @@ -1,409 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Python package for random data generation. -""" - -from functools import wraps - -from pyspark.mllib.common import callMLlibFunc - - -__all__ = ['RandomRDDs', ] - - -def toArray(f): - @wraps(f) - def func(sc, *a, **kw): - rdd = f(sc, *a, **kw) - return rdd.map(lambda vec: vec.toArray()) - return func - - -class RandomRDDs(object): - """ - Generator methods for creating RDDs comprised of i.i.d samples from - some distribution. - """ - - @staticmethod - def uniformRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the - uniform distribution U(0.0, 1.0). - - To transform the distribution in the generated RDD from U(0.0, 1.0) - to U(a, b), use - C{RandomRDDs.uniformRDD(sc, n, p, seed)\ - .map(lambda v: a + (b - a) * v)} - - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. - - >>> x = RandomRDDs.uniformRDD(sc, 100).collect() - >>> len(x) - 100 - >>> max(x) <= 1.0 and min(x) >= 0.0 - True - >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() - 4 - >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() - >>> parts == sc.defaultParallelism - True - """ - return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) - - @staticmethod - def normalRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the standard normal - distribution. - - To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma^2), use - C{RandomRDDs.normal(sc, n, p, seed)\ - .map(lambda v: mean + sigma * v)} - - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). 
- - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) - >>> stats = x.stats() - >>> stats.count() - 1000 - >>> abs(stats.mean() - 0.0) < 0.1 - True - >>> abs(stats.stdev() - 1.0) < 0.1 - True - """ - return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) - - @staticmethod - def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the log normal - distribution with the input mean and standard distribution. - - :param sc: SparkContext used to create the RDD. - :param mean: mean for the log Normal distribution - :param std: std for the log Normal distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std). - - >>> from math import sqrt, exp - >>> mean = 0.0 - >>> std = 1.0 - >>> expMean = exp(mean + 0.5 * std * std) - >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) - >>> stats = x.stats() - >>> stats.count() - 1000 - >>> abs(stats.mean() - expMean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - expStd) < 0.5 - True - """ - return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std), - size, numPartitions, seed) - - @staticmethod - def poissonRDD(sc, mean, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Poisson - distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). - - >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) - >>> stats = x.stats() - >>> stats.count() - 1000 - >>> abs(stats.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) - - @staticmethod - def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Exponential - distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). - - >>> mean = 2.0 - >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) - >>> stats = x.stats() - >>> stats.count() - 1000 - >>> abs(stats.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) - - @staticmethod - def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Gamma - distribution with the input shape and scale. - - :param sc: SparkContext used to create the RDD. 
- :param shape: shape (> 0) parameter for the Gamma distribution - :param scale: scale (> 0) parameter for the Gamma distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). - - >>> from math import sqrt - >>> shape = 1.0 - >>> scale = 2.0 - >>> expMean = shape * scale - >>> expStd = sqrt(shape * scale * scale) - >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) - >>> stats = x.stats() - >>> stats.count() - 1000 - >>> abs(stats.mean() - expMean) < 0.5 - True - >>> abs(stats.stdev() - expStd) < 0.5 - True - """ - return callMLlibFunc("gammaRDD", sc._jsc, float(shape), - float(scale), size, numPartitions, seed) - - @staticmethod - @toArray - def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the uniform distribution U(0.0, 1.0). - - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD. - :param seed: Seed for the RNG that generates the seed for the generator in each partition. - :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. - - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) - >>> mat.shape - (10, 10) - >>> mat.max() <= 1.0 and mat.min() >= 0.0 - True - >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() - 4 - """ - return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the standard normal distribution. - - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. - - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - 0.0) < 0.1 - True - >>> abs(mat.std() - 1.0) < 0.1 - True - """ - return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the log normal distribution. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean of the log normal distribution - :param std: Standard Deviation of the log normal distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. 
- - >>> import numpy as np - >>> from math import sqrt, exp - >>> mean = 0.0 - >>> std = 1.0 - >>> expMean = exp(mean + 0.5 * std * std) - >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() - >>> mat = np.matrix(m) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 - True - >>> abs(mat.std() - expStd) < 0.1 - True - """ - return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std), - numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Poisson distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). - - >>> import numpy as np - >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, - numPartitions, seed) - - @staticmethod - @toArray - def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Exponential distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). - - >>> import numpy as np - >>> mean = 0.5 - >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols, - numPartitions, seed) - - @staticmethod - @toArray - def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Gamma distribution. - - :param sc: SparkContext used to create the RDD. - :param shape: Shape (> 0) of the Gamma distribution - :param scale: Scale (> 0) of the Gamma distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). 
- - >>> import numpy as np - >>> from math import sqrt - >>> shape = 1.0 - >>> scale = 2.0 - >>> expMean = shape * scale - >>> expStd = sqrt(shape * scale * scale) - >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 - True - >>> abs(mat.std() - expStd) < 0.1 - True - """ - return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale), - numRows, numCols, numPartitions, seed) - - -def _test(): - import doctest - from pyspark.context import SparkContext - globs = globals().copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/random.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 0000000..06fbc0e --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,409 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + +from functools import wraps + +from pyspark.mllib.common import callMLlibFunc + + +__all__ = ['RandomRDDs', ] + + +def toArray(f): + @wraps(f) + def func(sc, *a, **kw): + rdd = f(sc, *a, **kw) + return rdd.map(lambda vec: vec.toArray()) + return func + + +class RandomRDDs(object): + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution U(0.0, 1.0). + + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. 
+ + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma^2), use + C{RandomRDDs.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the log normal + distribution with the input mean and standard distribution. + + :param sc: SparkContext used to create the RDD. + :param mean: mean for the log Normal distribution + :param std: std for the log Normal distribution + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std). + + >>> from math import sqrt, exp + >>> mean = 0.0 + >>> std = 1.0 + >>> expMean = exp(mean + 0.5 * std * std) + >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) + >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - expMean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - expStd) < 0.5 + True + """ + return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std), + size, numPartitions, seed) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Poisson + distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + + >>> mean = 100.0 + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) + + @staticmethod + def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. 
samples from the Exponential + distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or 1 / lambda, for the Exponential distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). + + >>> mean = 2.0 + >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) + + @staticmethod + def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Gamma + distribution with the input shape and scale. + + :param sc: SparkContext used to create the RDD. + :param shape: shape (> 0) parameter for the Gamma distribution + :param scale: scale (> 0) parameter for the Gamma distribution + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). + + >>> from math import sqrt + >>> shape = 1.0 + >>> scale = 2.0 + >>> expMean = shape * scale + >>> expStd = sqrt(shape * scale * scale) + >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - expMean) < 0.5 + True + >>> abs(stats.stdev() - expStd) < 0.5 + True + """ + return callMLlibFunc("gammaRDD", sc._jsc, float(shape), + float(scale), size, numPartitions, seed) + + @staticmethod + @toArray + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the standard normal distribution. + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. 
+ + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the log normal distribution. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean of the log normal distribution + :param std: Standard Deviation of the log normal distribution + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. + + >>> import numpy as np + >>> from math import sqrt, exp + >>> mean = 0.0 + >>> std = 1.0 + >>> expMean = exp(mean + 0.5 * std * std) + >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) + >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() + >>> mat = np.matrix(m) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - expMean) < 0.1 + True + >>> abs(mat.std() - expStd) < 0.1 + True + """ + return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std), + numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Poisson distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) + + @staticmethod + @toArray + def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Exponential distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or 1 / lambda, for the Exponential distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). 
+ + >>> import numpy as np + >>> mean = 0.5 + >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) + + @staticmethod + @toArray + def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Gamma distribution. + + :param sc: SparkContext used to create the RDD. + :param shape: Shape (> 0) of the Gamma distribution + :param scale: Scale (> 0) of the Gamma distribution + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). + + >>> import numpy as np + >>> from math import sqrt + >>> shape = 1.0 + >>> scale = 2.0 + >>> expMean = shape * scale + >>> expStd = sqrt(shape * scale * scale) + >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - expMean) < 0.1 + True + >>> abs(mat.std() - expStd) < 0.1 + True + """ + return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale), + numRows, numCols, numPartitions, seed) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/sql/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 8fee92a..726d288 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -50,18 +50,6 @@ def since(version): return f return deco -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/sql/_types.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py deleted file mode 100644 index 9e7e9f0..0000000 --- a/python/pyspark/sql/_types.py +++ /dev/null @@ -1,1306 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -import decimal -import time -import datetime -import keyword -import warnings -import json -import re -import weakref -from array import array -from operator import itemgetter - -if sys.version >= "3": - long = int - unicode = str - -from py4j.protocol import register_input_converter -from py4j.java_gateway import JavaClass - -__all__ = [ - "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", - "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", - "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] - - -class DataType(object): - """Base class for data types.""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def simpleString(self): - return self.typeName() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -# This singleton pattern does not work with pickle, you will get -# another object after pickle and unpickle -class DataTypeSingleton(type): - """Metaclass for DataType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(DataTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class NullType(DataType): - """Null type. - - The data type representing None, used for the types that cannot be inferred. - """ - - __metaclass__ = DataTypeSingleton - - -class AtomicType(DataType): - """An internal type used to represent everything that is not - null, UDTs, arrays, structs, and maps.""" - - __metaclass__ = DataTypeSingleton - - -class NumericType(AtomicType): - """Numeric data types. - """ - - -class IntegralType(NumericType): - """Integral data types. - """ - - -class FractionalType(NumericType): - """Fractional data types. - """ - - -class StringType(AtomicType): - """String data type. - """ - - -class BinaryType(AtomicType): - """Binary (byte array) data type. - """ - - -class BooleanType(AtomicType): - """Boolean data type. - """ - - -class DateType(AtomicType): - """Date (datetime.date) data type. - """ - - -class TimestampType(AtomicType): - """Timestamp (datetime.datetime) data type. - """ - - -class DecimalType(FractionalType): - """Decimal (decimal.Decimal) data type. 
- """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def simpleString(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal(10,0)" - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(FractionalType): - """Double data type, representing double precision floats. - """ - - -class FloatType(FractionalType): - """Float data type, representing single precision floats. - """ - - -class ByteType(IntegralType): - """Byte data type, i.e. a signed integer in a single byte. - """ - def simpleString(self): - return 'tinyint' - - -class IntegerType(IntegralType): - """Int data type, i.e. a signed 32-bit integer. - """ - def simpleString(self): - return 'int' - - -class LongType(IntegralType): - """Long data type, i.e. a signed 64-bit integer. - - If the values are beyond the range of [-9223372036854775808, 9223372036854775807], - please use :class:`DecimalType`. - """ - def simpleString(self): - return 'bigint' - - -class ShortType(IntegralType): - """Short data type, i.e. a signed 16-bit integer. - """ - def simpleString(self): - return 'smallint' - - -class ArrayType(DataType): - """Array data type. - - :param elementType: :class:`DataType` of each element in the array. - :param containsNull: boolean, whether the array can contain null (None) values. - """ - - def __init__(self, elementType, containsNull=True): - """ - >>> ArrayType(StringType()) == ArrayType(StringType(), True) - True - >>> ArrayType(StringType(), False) == ArrayType(StringType()) - False - """ - assert isinstance(elementType, DataType), "elementType should be DataType" - self.elementType = elementType - self.containsNull = containsNull - - def simpleString(self): - return 'array<%s>' % self.elementType.simpleString() - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - """Map data type. - - :param keyType: :class:`DataType` of the keys in the map. - :param valueType: :class:`DataType` of the values in the map. - :param valueContainsNull: indicates whether values can contain null (None) values. - - Keys in a map data type are not allowed to be null (None). - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """ - >>> (MapType(StringType(), IntegerType()) - ... == MapType(StringType(), IntegerType(), True)) - True - >>> (MapType(StringType(), IntegerType(), False) - ... 
== MapType(StringType(), FloatType())) - False - """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def simpleString(self): - return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - """A field in :class:`StructType`. - - :param name: string, name of the field. - :param dataType: :class:`DataType` of the field. - :param nullable: boolean, whether the field can be null (None) or not. - :param metadata: a dict from string to simple type that can be serialized to JSON automatically - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """ - >>> (StructField("f1", StringType(), True) - ... == StructField("f1", StringType(), True)) - True - >>> (StructField("f1", StringType(), True) - ... == StructField("f2", StringType(), True)) - False - """ - assert isinstance(dataType, DataType), "dataType should be DataType" - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def simpleString(self): - return '%s:%s' % (self.name, self.dataType.simpleString()) - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - """Struct type, consisting of a list of :class:`StructField`. - - This is the data type representing a :class:`Row`. - """ - - def __init__(self, fields): - """ - >>> struct1 = StructType([StructField("f1", StringType(), True)]) - >>> struct2 = StructType([StructField("f1", StringType(), True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType(), True)]) - >>> struct2 = StructType([StructField("f1", StringType(), True), - ... StructField("f2", IntegerType(), False)]) - >>> struct1 == struct2 - False - """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields - - def simpleString(self): - return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """User-defined type (UDT). - - .. 
note:: WARN: Spark Internal Use Only - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def simpleString(self): - return 'udt' - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, - ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] -_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> import pickle - >>> def check_datatype(datatype): - ... pickled = pickle.loads(pickle.dumps(datatype)) - ... assert datatype == pickled - ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... assert datatype == python_datatype - >>> for cls in _all_atomic_types.values(): - ... check_datatype(cls()) - - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... 
complex_arraytype, False) - >>> check_datatype(complex_maptype) - - >>> check_datatype(ExamplePointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if not isinstance(json_value, dict): - if json_value in _all_atomic_types.keys(): - return _all_atomic_types[json_value]() - elif json_value == 'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: LongType, - float: DoubleType, - str: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - -if sys.version < "3": - _type_mappings.update({ - unicode: StringType, - long: LongType, - }) - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - return NullType() - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.items(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except TypeError: - raise TypeError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, (tuple, list)): - if hasattr(row, "__fields__"): # Row - items = zip(row.__fields__, tuple(row)) - elif hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - else: - names = ['_%d' % i for i in range(1, len(row) + 1)] - items = zip(names, row) - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise TypeError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... 
StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int 
-> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _need_converter(dataType): - if isinstance(dataType, StructType): - return True - elif isinstance(dataType, ArrayType): - return _need_converter(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) - elif isinstance(dataType, NullType): - return True - else: - return False - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if not _need_converter(dataType): - return lambda x: x - - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: [conv(v) for v in row] - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, (tuple, list)): - if convert_fields: - return tuple(conv(v) for v, conv in zip(obj, converters)) - else: - return tuple(obj) - - if isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise TypeError("Unexpected obj type: %s" % type(obj)) - - if convert_fields: - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - else: - return tuple([d.get(name) for name in names]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,NullType,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,NullType,true),StructField(d... 
- >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(NullType,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, NullType(), True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,NullType,true))),true) - """ - s = s.strip() - if not s: - return NullType() - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(NullType(), _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types inferred from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...LongType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,LongType... - """ - if isinstance(dataType, NullType): - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = next(iter(obj.items())) - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise TypeError("Unexpected dataType: %s" % type(dataType)) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date, datetime.datetime), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, LongType()) - >>> _verify_type(list(range(3)), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... 
- >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.items(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - -_cached_cls = weakref.WeakValueDictionary() - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None or cls.__datatype is not dataType: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - cls.__datatype = dataType - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. 
- if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for atomic types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __datatype = dataType - __fields__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__fields__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__fields__)) - - def __reduce__(self): - return (_restore_object, (self.__datatype, tuple(self))) - - return Row - - -def _create_row(fields, values): - row = Row(*values) - row.__fields__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. 
The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - <Row(name, age)> - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - row = tuple.__new__(self, [kwargs[n] for n in names]) - row.__fields__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__fields__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__fields__, self)) - - # let object acts like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__fields__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - except ValueError: - raise AttributeError(item) - - def __reduce__(self): - """Returns a tuple so Python knows how to pickle Row.""" - if hasattr(self, "__fields__"): - return (_create_row, (self.__fields__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - """Printable representation of Row used in Python REPL.""" - if hasattr(self, "__fields__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__fields__, tuple(self))) - else: - return "<Row(%s)>" % ", ".join(self) - - -class DateConverter(object): - def can_convert(self, obj): - return isinstance(obj, datetime.date) - - def convert(self, obj, gateway_client): - Date = JavaClass("java.sql.Date", gateway_client) - return Date.valueOf(obj.strftime("%Y-%m-%d")) - - -class DatetimeConverter(object): - def can_convert(self, obj): - return isinstance(obj, datetime.datetime) - - def convert(self, obj, gateway_client): - Timestamp = JavaClass("java.sql.Timestamp", gateway_client) - return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) - - -# datetime is a subclass of date, we should register DatetimeConverter first -register_input_converter(DatetimeConverter()) -register_input_converter(DateConverter()) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() 
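----------------------------------------------------------------------
A side note on the schema helpers in the file above: _merge_type is one of the few without a doctest. Below is a minimal sketch of its behavior, assuming these names are importable from pyspark.sql.types once this change lands; the example is illustrative only and not part of the patch:

    from pyspark.sql.types import (StructType, StructField, LongType,
                                   StringType, NullType, _merge_type)

    # Schema inferred from a row whose "name" field was None ...
    a = StructType([StructField("id", LongType()), StructField("name", NullType())])
    # ... and from a row where "name" held a string.
    b = StructType([StructField("id", LongType()), StructField("name", StringType())])

    # NullType on either side gives way to the concrete type from the other side;
    # two different concrete types would raise TypeError ("Can not merge type ...").
    print(_merge_type(a, b))

For ArrayType and MapType the merged result is always marked as allowing null elements/values, as the code above shows.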