http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
deleted file mode 100644
index ea02691..0000000
--- a/python/pyspark/sql/tests.py
+++ /dev/null
@@ -1,7079 +0,0 @@
-# -*- encoding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Unit tests for pyspark.sql; additional tests are implemented as doctests in
-individual modules.
-"""
-import os
-import sys
-import subprocess
-import pydoc
-import shutil
-import tempfile
-import threading
-import pickle
-import functools
-import time
-import datetime
-import array
-import ctypes
-import warnings
-import py4j
-from contextlib import contextmanager
-
-try:
-    import xmlrunner
-except ImportError:
-    xmlrunner = None
-
-if sys.version_info[:2] <= (2, 6):
-    try:
-        import unittest2 as unittest
-    except ImportError:
-        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
-        sys.exit(1)
-else:
-    import unittest
-
-from pyspark.util import _exception_message
-
-_pandas_requirement_message = None
-try:
-    from pyspark.sql.utils import require_minimum_pandas_version
-    require_minimum_pandas_version()
-except ImportError as e:
-    # If Pandas version requirement is not satisfied, skip related tests.
-    _pandas_requirement_message = _exception_message(e)
-
-_pyarrow_requirement_message = None
-try:
-    from pyspark.sql.utils import require_minimum_pyarrow_version
-    require_minimum_pyarrow_version()
-except ImportError as e:
-    # If Arrow version requirement is not satisfied, skip related tests.
-    _pyarrow_requirement_message = _exception_message(e)
-
-_test_not_compiled_message = None
-try:
-    from pyspark.sql.utils import require_test_compiled
-    require_test_compiled()
-except Exception as e:
-    _test_not_compiled_message = _exception_message(e)
-
-_have_pandas = _pandas_requirement_message is None
-_have_pyarrow = _pyarrow_requirement_message is None
-_test_compiled = _test_not_compiled_message is None
-
-from pyspark import SparkConf, SparkContext
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
-from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
-from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
-from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
-from pyspark.sql.types import _merge_type
-from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit
-from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
-
-
-class UTCOffsetTimezone(datetime.tzinfo):
-    """
-    Specifies timezone in UTC offset
-    """
-
-    def __init__(self, offset=0):
-        self.ZERO = datetime.timedelta(hours=offset)
-
-    def utcoffset(self, dt):
-        return self.ZERO
-
-    def dst(self, dt):
-        return self.ZERO
-
-
-class ExamplePointUDT(UserDefinedType):
-    """
-    User-defined type (UDT) for ExamplePoint.
-    """
-
-    @classmethod
-    def sqlType(self):
-        return ArrayType(DoubleType(), False)
-
-    @classmethod
-    def module(cls):
-        return 'pyspark.sql.tests'
-
-    @classmethod
-    def scalaUDT(cls):
-        return 'org.apache.spark.sql.test.ExamplePointUDT'
-
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
-    """
-    An example class to demonstrate UDT in Scala, Java, and Python.
-    """
-
-    __UDT__ = ExamplePointUDT()
-
-    def __init__(self, x, y):
-        self.x = x
-        self.y = y
-
-    def __repr__(self):
-        return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
-    def __str__(self):
-        return "(%s,%s)" % (self.x, self.y)
-
-    def __eq__(self, other):
-        return isinstance(other, self.__class__) and \
-            other.x == self.x and other.y == self.y
-
-
-class PythonOnlyUDT(UserDefinedType):
-    """
-    User-defined type (UDT) for ExamplePoint.
-    """
-
-    @classmethod
-    def sqlType(self):
-        return ArrayType(DoubleType(), False)
-
-    @classmethod
-    def module(cls):
-        return '__main__'
-
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return PythonOnlyPoint(datum[0], datum[1])
-
-    @staticmethod
-    def foo():
-        pass
-
-    @property
-    def props(self):
-        return {}
-
-
-class PythonOnlyPoint(ExamplePoint):
-    """
-    An example class to demonstrate UDT in only Python
-    """
-    __UDT__ = PythonOnlyUDT()
-
-
-class MyObject(object):
-    def __init__(self, key, value):
-        self.key = key
-        self.value = value
-
-
-class SQLTestUtils(object):
-    """
-    This util assumes the instance of this class has a 'spark' attribute, holding a spark session.
-    It is usually used with the 'ReusedSQLTestCase' class but can be used if you feel sure that
-    the implementation of this class has a 'spark' attribute.
-    """
-
-    @contextmanager
-    def sql_conf(self, pairs):
-        """
-        A convenient context manager to test some configuration specific logic. This sets
-        `value` to the configuration `key` and then restores it back when it exits.
-        """
-        assert isinstance(pairs, dict), "pairs should be a dictionary."
-        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
-        keys = pairs.keys()
-        new_values = pairs.values()
-        old_values = [self.spark.conf.get(key, None) for key in keys]
-        for key, new_value in zip(keys, new_values):
-            self.spark.conf.set(key, new_value)
-        try:
-            yield
-        finally:
-            for key, old_value in zip(keys, old_values):
-                if old_value is None:
-                    self.spark.conf.unset(key)
-                else:
-                    self.spark.conf.set(key, old_value)
-
-    @contextmanager
-    def database(self, *databases):
-        """
-        A convenient context manager to test with some specific databases. This drops the given
-        databases if exist and sets current database to "default" when it exits.
-        """
-        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
-        try:
-            yield
-        finally:
-            for db in databases:
-                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
-            self.spark.catalog.setCurrentDatabase("default")
-
-    @contextmanager
-    def table(self, *tables):
-        """
-        A convenient context manager to test with some specific tables. This drops the given tables
-        if exist when it exits.
-        """
-        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
-        try:
-            yield
-        finally:
-            for t in tables:
-                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
-
-    @contextmanager
-    def tempView(self, *views):
-        """
-        A convenient context manager to test with some specific views. This drops the given views
-        if exist when it exits.
-        """
-        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
-        try:
-            yield
-        finally:
-            for v in views:
-                self.spark.catalog.dropTempView(v)
-
-    @contextmanager
-    def function(self, *functions):
-        """
-        A convenient context manager to test with some specific functions. This drops the given
-        functions if exist when it exits.
-        """
-        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
-        try:
-            yield
-        finally:
-            for f in functions:
-                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
-
-
-class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
-    @classmethod
-    def setUpClass(cls):
-        super(ReusedSQLTestCase, cls).setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        super(ReusedSQLTestCase, cls).tearDownClass()
-        cls.spark.stop()
-
-    def assertPandasEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
-               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
-        self.assertTrue(expected.equals(result), msg=msg)
-
-
-class DataTypeTests(unittest.TestCase):
-    # regression test for SPARK-6055
-    def test_data_type_eq(self):
-        lt = LongType()
-        lt2 = pickle.loads(pickle.dumps(LongType()))
-        self.assertEqual(lt, lt2)
-
-    # regression test for SPARK-7978
-    def test_decimal_type(self):
-        t1 = DecimalType()
-        t2 = DecimalType(10, 2)
-        self.assertTrue(t2 is not t1)
-        self.assertNotEqual(t1, t2)
-        t3 = DecimalType(8)
-        self.assertNotEqual(t2, t3)
-
-    # regression test for SPARK-10392
-    def test_datetype_equal_zero(self):
-        dt = DateType()
-        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
-
-    # regression test for SPARK-17035
-    def test_timestamp_microsecond(self):
-        tst = TimestampType()
-        self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)
-
-    def test_empty_row(self):
-        row = Row()
-        self.assertEqual(len(row), 0)
-
-    def test_struct_field_type_name(self):
-        struct_field = StructField("a", IntegerType())
-        self.assertRaises(TypeError, struct_field.typeName)
-
-    def test_invalid_create_row(self):
-        row_class = Row("c1", "c2")
-        self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
-
-
-class SparkSessionBuilderTests(unittest.TestCase):
-
-    def test_create_spark_context_first_then_spark_session(self):
-        sc = None
-        session = None
-        try:
-            conf = SparkConf().set("key1", "value1")
-            sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
-            session = SparkSession.builder.config("key2", 
"value2").getOrCreate()
-
-            self.assertEqual(session.conf.get("key1"), "value1")
-            self.assertEqual(session.conf.get("key2"), "value2")
-            self.assertEqual(session.sparkContext, sc)
-
-            self.assertFalse(sc.getConf().contains("key2"))
-            self.assertEqual(sc.getConf().get("key1"), "value1")
-        finally:
-            if session is not None:
-                session.stop()
-            if sc is not None:
-                sc.stop()
-
-    def test_another_spark_session(self):
-        session1 = None
-        session2 = None
-        try:
-            session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
-            session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
-
-            self.assertEqual(session1.conf.get("key1"), "value1")
-            self.assertEqual(session2.conf.get("key1"), "value1")
-            self.assertEqual(session1.conf.get("key2"), "value2")
-            self.assertEqual(session2.conf.get("key2"), "value2")
-            self.assertEqual(session1.sparkContext, session2.sparkContext)
-
-            self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
-            self.assertFalse(session1.sparkContext.getConf().contains("key2"))
-        finally:
-            if session1 is not None:
-                session1.stop()
-            if session2 is not None:
-                session2.stop()
-
-
-class SQLTests(ReusedSQLTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-        cls.spark.catalog._reset()
-        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        os.unlink(cls.tempdir.name)
-        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
-        cls.df = cls.spark.createDataFrame(cls.testData)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedSQLTestCase.tearDownClass()
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-
-    def test_sqlcontext_reuses_sparksession(self):
-        sqlContext1 = SQLContext(self.sc)
-        sqlContext2 = SQLContext(self.sc)
-        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
-
-    def test_row_should_be_read_only(self):
-        row = Row(a=1, b=2)
-        self.assertEqual(1, row.a)
-
-        def foo():
-            row.a = 3
-        self.assertRaises(Exception, foo)
-
-        row2 = self.spark.range(10).first()
-        self.assertEqual(0, row2.id)
-
-        def foo2():
-            row2.id = 2
-        self.assertRaises(Exception, foo2)
-
-    def test_range(self):
-        self.assertEqual(self.spark.range(1, 1).count(), 0)
-        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
-        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
-        self.assertEqual(self.spark.range(-2).count(), 0)
-        self.assertEqual(self.spark.range(3).count(), 3)
-
-    def test_duplicated_column_names(self):
-        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
-        row = df.select('*').first()
-        self.assertEqual(1, row[0])
-        self.assertEqual(2, row[1])
-        self.assertEqual("Row(c=1, c=2)", str(row))
-        # Cannot access columns
-        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
-        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
-        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
-
-    def test_column_name_encoding(self):
-        """Ensure that created columns has `str` type consistently."""
-        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
-        self.assertEqual(columns, ['name', 'age'])
-        self.assertTrue(isinstance(columns[0], str))
-        self.assertTrue(isinstance(columns[1], str))
-
-    def test_explode(self):
-        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
-        d = [
-            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
-            Row(a=1, intlist=[], mapfield={}),
-            Row(a=1, intlist=None, mapfield=None),
-        ]
-        rdd = self.sc.parallelize(d)
-        data = self.spark.createDataFrame(rdd)
-
-        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
-        self.assertEqual(result[0][0], 1)
-        self.assertEqual(result[1][0], 2)
-        self.assertEqual(result[2][0], 3)
-
-        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
-        self.assertEqual(result[0][0], "a")
-        self.assertEqual(result[0][1], "b")
-
-        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
-        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
-
-        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
-        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
-
-        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
-        self.assertEqual(result, [1, 2, 3, None, None])
-
-        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
-        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
-
-    def test_and_in_expression(self):
-        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
-        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
-        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
-        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
-        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
-        self.assertRaises(ValueError, lambda: not self.df.key == 1)
-
-    def test_udf_with_callable(self):
-        d = [Row(number=i, squared=i**2) for i in range(10)]
-        rdd = self.sc.parallelize(d)
-        data = self.spark.createDataFrame(rdd)
-
-        class PlusFour:
-            def __call__(self, col):
-                if col is not None:
-                    return col + 4
-
-        call = PlusFour()
-        pudf = UserDefinedFunction(call, LongType())
-        res = data.select(pudf(data['number']).alias('plus_four'))
-        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
-
-    def test_udf_with_partial_function(self):
-        d = [Row(number=i, squared=i**2) for i in range(10)]
-        rdd = self.sc.parallelize(d)
-        data = self.spark.createDataFrame(rdd)
-
-        def some_func(col, param):
-            if col is not None:
-                return col + param
-
-        pfunc = functools.partial(some_func, param=4)
-        pudf = UserDefinedFunction(pfunc, LongType())
-        res = data.select(pudf(data['number']).alias('plus_four'))
-        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
-
-    def test_udf(self):
-        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + 
y, IntegerType())
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
-
-        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
-        sqlContext = self.spark._wrapped
-        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
-        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
-        self.assertEqual(row[0], 4)
-
-    def test_udf2(self):
-        with self.tempView("test"):
-            self.spark.catalog.registerFunction("strlen", lambda string: 
len(string), IntegerType())
-            self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
-                .createOrReplaceTempView("test")
-            [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) 
> 1").collect()
-            self.assertEqual(4, res[0])
-
-    def test_udf3(self):
-        two_args = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
-        self.assertEqual(two_args.deterministic, True)
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], u'5')
-
-    def test_udf_registration_return_type_none(self):
-        two_args = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, 
"integer"), None)
-        self.assertEqual(two_args.deterministic, True)
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
-
-    def test_udf_registration_return_type_not_none(self):
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
-                self.spark.catalog.registerFunction(
-                    "f", UserDefinedFunction(lambda x, y: len(x) + y, 
StringType()), StringType())
-
-    def test_nondeterministic_udf(self):
-        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-        from pyspark.sql.functions import udf
-        import random
-        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
-        self.assertEqual(udf_random_col.deterministic, False)
-        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
-        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
-        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
-        self.assertEqual(row[0] + 10, row[1])
-
-    def test_nondeterministic_udf2(self):
-        import random
-        from pyspark.sql.functions import udf
-        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
-        self.assertEqual(random_udf.deterministic, False)
-        random_udf1 = self.spark.catalog.registerFunction("randInt", 
random_udf)
-        self.assertEqual(random_udf1.deterministic, False)
-        [row] = self.spark.sql("SELECT randInt()").collect()
-        self.assertEqual(row[0], 6)
-        [row] = self.spark.range(1).select(random_udf1()).collect()
-        self.assertEqual(row[0], 6)
-        [row] = self.spark.range(1).select(random_udf()).collect()
-        self.assertEqual(row[0], 6)
-        # render_doc() reproduces the help() exception without printing output
-        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
-        pydoc.render_doc(random_udf)
-        pydoc.render_doc(random_udf1)
-        pydoc.render_doc(udf(lambda x: x).asNondeterministic)
-
-    def test_nondeterministic_udf3(self):
-        # regression test for SPARK-23233
-        from pyspark.sql.functions import udf
-        f = udf(lambda x: x)
-        # Here we cache the JVM UDF instance.
-        self.spark.range(1).select(f("id"))
-        # This should reset the cache to set the deterministic status correctly.
-        f = f.asNondeterministic()
-        # Check the deterministic status of udf.
-        df = self.spark.range(1).select(f("id"))
-        deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
-        self.assertFalse(deterministic)
-
-    def test_nondeterministic_udf_in_aggregate(self):
-        from pyspark.sql.functions import udf, sum
-        import random
-        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
-        df = self.spark.range(10)
-
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
-                df.groupby('id').agg(sum(udf_random_col())).collect()
-            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
-                df.agg(sum(udf_random_col())).collect()
-
-    def test_chained_udf(self):
-        self.spark.catalog.registerFunction("double", lambda x: x + x, 
IntegerType())
-        [row] = self.spark.sql("SELECT double(1)").collect()
-        self.assertEqual(row[0], 2)
-        [row] = self.spark.sql("SELECT double(double(1))").collect()
-        self.assertEqual(row[0], 4)
-        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
-        self.assertEqual(row[0], 6)
-
-    def test_single_udf_with_repeated_argument(self):
-        # regression test for SPARK-20685
-        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
-        row = self.spark.sql("SELECT add(1, 1)").first()
-        self.assertEqual(tuple(row), (2, ))
-
-    def test_multiple_udfs(self):
-        self.spark.catalog.registerFunction("double", lambda x: x * 2, 
IntegerType())
-        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
-        self.assertEqual(tuple(row), (2, 4))
-        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 
2)").collect()
-        self.assertEqual(tuple(row), (4, 12))
-        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
-        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 
1)").collect()
-        self.assertEqual(tuple(row), (6, 5))
-
-    def test_udf_in_filter_on_top_of_outer_join(self):
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1)])
-        right = self.spark.createDataFrame([Row(a=1)])
-        df = left.join(right, on='a', how='left_outer')
-        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
-        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
-
-    def test_udf_in_filter_on_top_of_join(self):
-        # regression test for SPARK-18589
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1)])
-        right = self.spark.createDataFrame([Row(b=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-        df = left.crossJoin(right).filter(f("a", "b"))
-        self.assertEqual(df.collect(), [Row(a=1, b=1)])
-
-    def test_udf_in_join_condition(self):
-        # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1)])
-        right = self.spark.createDataFrame([Row(b=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-        df = left.join(right, f("a", "b"))
-        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
-            df.collect()
-        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
-            self.assertEqual(df.collect(), [Row(a=1, b=1)])
-
-    def test_udf_in_left_semi_join_condition(self):
-        # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
-        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-        df = left.join(right, f("a", "b"), "leftsemi")
-        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
-            df.collect()
-        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
-            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
-
-    def test_udf_and_common_filter_in_join_condition(self):
-        # regression test for SPARK-25314
-        # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
-        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
-        # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
-
-    def test_udf_and_common_filter_in_left_semi_join_condition(self):
-        # regression test for SPARK-25314
-        # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
-        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
-        # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
-
-    def test_udf_not_supported_in_join_condition(self):
-        # regression test for SPARK-25314
-        # test python udf is not supported in join type besides left_semi and inner join.
-        from pyspark.sql.functions import udf
-        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
-        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
-        f = udf(lambda a, b: a == b, BooleanType())
-
-        def runWithJoinType(join_type, type_string):
-            with self.assertRaisesRegexp(
-                    AnalysisException,
-                    'Using PythonUDF.*%s is not supported.' % type_string):
-                left.join(right, [f("a", "b"), left.a1 == right.b1], 
join_type).collect()
-        runWithJoinType("full", "FullOuter")
-        runWithJoinType("left", "LeftOuter")
-        runWithJoinType("right", "RightOuter")
-        runWithJoinType("leftanti", "LeftAnti")
-
-    def test_udf_without_arguments(self):
-        self.spark.catalog.registerFunction("foo", lambda: "bar")
-        [row] = self.spark.sql("SELECT foo()").collect()
-        self.assertEqual(row[0], "bar")
-
-    def test_udf_with_array_type(self):
-        with self.tempView("test"):
-            d = [Row(l=list(range(3)), d={"key": list(range(5))})]
-            rdd = self.sc.parallelize(d)
-            self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
-            self.spark.catalog.registerFunction(
-                "copylist", lambda l: list(l), ArrayType(IntegerType()))
-            self.spark.catalog.registerFunction("maplen", lambda d: len(d), 
IntegerType())
-            [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from 
test").collect()
-            self.assertEqual(list(range(3)), l1)
-            self.assertEqual(1, l2)
-
-    def test_broadcast_in_udf(self):
-        bar = {"a": "aa", "b": "bb", "c": "abc"}
-        foo = self.sc.broadcast(bar)
-        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if 
x else '')
-        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
-        self.assertEqual("abc", res[0])
-        [res] = self.spark.sql("SELECT MYUDF('')").collect()
-        self.assertEqual("", res[0])
-
-    def test_udf_with_filter_function(self):
-        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
-        from pyspark.sql.types import BooleanType
-
-        my_filter = udf(lambda a: a < 2, BooleanType())
-        sel = df.select(col("key"), 
col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
-        self.assertEqual(sel.collect(), [Row(key=1, value='1')])
-
-    def test_udf_with_aggregate_function(self):
-        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col, sum
-        from pyspark.sql.types import BooleanType
-
-        my_filter = udf(lambda a: a == 1, BooleanType())
-        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
-        self.assertEqual(sel.collect(), [Row(key=1)])
-
-        my_copy = udf(lambda x: x, IntegerType())
-        my_add = udf(lambda a, b: int(a + b), IntegerType())
-        my_strlen = udf(lambda x: len(x), IntegerType())
-        sel = df.groupBy(my_copy(col("key")).alias("k"))\
-            .agg(sum(my_strlen(col("value"))).alias("s"))\
-            .select(my_add(col("k"), col("s")).alias("t"))
-        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
-
-    def test_udf_in_generate(self):
-        from pyspark.sql.functions import udf, explode
-        df = self.spark.range(5)
-        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
-        row = df.select(explode(f(*df))).groupBy().sum().first()
-        self.assertEqual(row[0], 10)
-
-        df = self.spark.range(3)
-        res = df.select("id", explode(f(df.id))).collect()
-        self.assertEqual(res[0][0], 1)
-        self.assertEqual(res[0][1], 0)
-        self.assertEqual(res[1][0], 2)
-        self.assertEqual(res[1][1], 0)
-        self.assertEqual(res[2][0], 2)
-        self.assertEqual(res[2][1], 1)
-
-        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
-        res = df.select("id", explode(range_udf(df.id))).collect()
-        self.assertEqual(res[0][0], 0)
-        self.assertEqual(res[0][1], -1)
-        self.assertEqual(res[1][0], 0)
-        self.assertEqual(res[1][1], 0)
-        self.assertEqual(res[2][0], 1)
-        self.assertEqual(res[2][1], 0)
-        self.assertEqual(res[3][0], 1)
-        self.assertEqual(res[3][1], 1)
-
-    def test_udf_with_order_by_and_limit(self):
-        from pyspark.sql.functions import udf
-        my_copy = udf(lambda x: x, IntegerType())
-        df = self.spark.range(10).orderBy("id")
-        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
-        res.explain(True)
-        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
-
-    def test_udf_registration_returns_udf(self):
-        df = self.spark.range(10)
-        add_three = self.spark.udf.register("add_three", lambda x: x + 3, 
IntegerType())
-
-        self.assertListEqual(
-            df.selectExpr("add_three(id) AS plus_three").collect(),
-            df.select(add_three("id").alias("plus_three")).collect()
-        )
-
-        # This is to check if a 'SQLContext.udf' can call its alias.
-        sqlContext = self.spark._wrapped
-        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, 
IntegerType())
-
-        self.assertListEqual(
-            df.selectExpr("add_four(id) AS plus_four").collect(),
-            df.select(add_four("id").alias("plus_four")).collect()
-        )
-
-    def test_non_existed_udf(self):
-        spark = self.spark
-        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udf",
-                                lambda: spark.udf.registerJavaFunction("udf1", 
"non_existed_udf"))
-
-        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
-        sqlContext = spark._wrapped
-        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udf",
-                                lambda: 
sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
-
-    def test_non_existed_udaf(self):
-        spark = self.spark
-        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udaf",
-                                lambda: spark.udf.registerJavaUDAF("udaf1", 
"non_existed_udaf"))
-
-    def test_linesep_text(self):
-        df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", 
lineSep=",")
-        expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
-                    Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
-                    Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
-                    Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
-        self.assertEqual(df.collect(), expected)
-
-        tpath = tempfile.mkdtemp()
-        shutil.rmtree(tpath)
-        try:
-            df.write.text(tpath, lineSep="!")
-            expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
-                        Row(value=u'Tom!30!"My name is Tom"'),
-                        Row(value=u'Hyukjin!25!"I am Hyukjin'),
-                        Row(value=u''), Row(value=u'I love Spark!"'),
-                        Row(value=u'!')]
-            readback = self.spark.read.text(tpath)
-            self.assertEqual(readback.collect(), expected)
-        finally:
-            shutil.rmtree(tpath)
-
-    def test_multiline_json(self):
-        people1 = self.spark.read.json("python/test_support/sql/people.json")
-        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
-                                            multiLine=True)
-        self.assertEqual(people1.collect(), people_array.collect())
-
-    def test_encoding_json(self):
-        people_array = self.spark.read\
-            .json("python/test_support/sql/people_array_utf16le.json",
-                  multiLine=True, encoding="UTF-16LE")
-        expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
-        self.assertEqual(people_array.collect(), expected)
-
-    def test_linesep_json(self):
-        df = self.spark.read.json("python/test_support/sql/people.json", 
lineSep=",")
-        expected = [Row(_corrupt_record=None, name=u'Michael'),
-                    Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
-                    Row(_corrupt_record=u' "age":19}\n', name=None)]
-        self.assertEqual(df.collect(), expected)
-
-        tpath = tempfile.mkdtemp()
-        shutil.rmtree(tpath)
-        try:
-            df = self.spark.read.json("python/test_support/sql/people.json")
-            df.write.json(tpath, lineSep="!!")
-            readback = self.spark.read.json(tpath, lineSep="!!")
-            self.assertEqual(readback.collect(), df.collect())
-        finally:
-            shutil.rmtree(tpath)
-
-    def test_multiline_csv(self):
-        ages_newlines = self.spark.read.csv(
-            "python/test_support/sql/ages_newlines.csv", multiLine=True)
-        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
-                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
-                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
-        self.assertEqual(ages_newlines.collect(), expected)
-
-    def test_ignorewhitespace_csv(self):
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.spark.createDataFrame([[" a", "b  ", " c "]]).write.csv(
-            tmpPath,
-            ignoreLeadingWhiteSpace=False,
-            ignoreTrailingWhiteSpace=False)
-
-        expected = [Row(value=u' a,b  , c ')]
-        readback = self.spark.read.text(tmpPath)
-        self.assertEqual(readback.collect(), expected)
-        shutil.rmtree(tmpPath)
-
-    def test_read_multiple_orc_file(self):
-        df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
-                                  "python/test_support/sql/orc_partitioned/b=1/c=1"])
-        self.assertEqual(2, df.count())
-
-    def test_udf_with_input_file_name(self):
-        from pyspark.sql.functions import udf, input_file_name
-        sourceFile = udf(lambda path: path, StringType())
-        filePath = "python/test_support/sql/people1.json"
-        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
-        self.assertTrue(row[0].find("people1.json") != -1)
-
-    def test_udf_with_input_file_name_for_hadooprdd(self):
-        from pyspark.sql.functions import udf, input_file_name
-
-        def filename(path):
-            return path
-
-        sameText = udf(filename, StringType())
-
-        rdd = self.sc.textFile('python/test_support/sql/people.json')
-        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
-        row = df.select(sameText(df['file'])).first()
-        self.assertTrue(row[0].find("people.json") != -1)
-
-        rdd2 = self.sc.newAPIHadoopFile(
-            'python/test_support/sql/people.json',
-            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
-            'org.apache.hadoop.io.LongWritable',
-            'org.apache.hadoop.io.Text')
-
-        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
-        row2 = df2.select(sameText(df2['file'])).first()
-        self.assertTrue(row2[0].find("people.json") != -1)
-
-    def test_udf_defers_judf_initialization(self):
-        # This is separate from UDFInitializationTests
-        # to avoid context initialization
-        # when udf is called
-
-        from pyspark.sql.functions import UserDefinedFunction
-
-        f = UserDefinedFunction(lambda x: x, StringType())
-
-        self.assertIsNone(
-            f._judf_placeholder,
-            "judf should not be initialized before the first call."
-        )
-
-        self.assertIsInstance(f("foo"), Column, "UDF call should return a 
Column.")
-
-        self.assertIsNotNone(
-            f._judf_placeholder,
-            "judf should be initialized after UDF has been called."
-        )
-
-    def test_udf_with_string_return_type(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
-        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
-        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
-        make_array = UserDefinedFunction(
-            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
-
-        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
-        actual = (self.spark.range(1, 2).toDF("x")
-                  .select(add_one("x"), make_pair("x"), make_array("x"))
-                  .first())
-
-        self.assertTupleEqual(expected, actual)
-
-    def test_udf_shouldnt_accept_noncallable_object(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
-        non_callable = None
-        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
-
-    def test_udf_with_decorator(self):
-        from pyspark.sql.functions import lit, udf
-        from pyspark.sql.types import IntegerType, DoubleType
-
-        @udf(IntegerType())
-        def add_one(x):
-            if x is not None:
-                return x + 1
-
-        @udf(returnType=DoubleType())
-        def add_two(x):
-            if x is not None:
-                return float(x + 2)
-
-        @udf
-        def to_upper(x):
-            if x is not None:
-                return x.upper()
-
-        @udf()
-        def to_lower(x):
-            if x is not None:
-                return x.lower()
-
-        @udf
-        def substr(x, start, end):
-            if x is not None:
-                return x[start:end]
-
-        @udf("long")
-        def trunc(x):
-            return int(x)
-
-        @udf(returnType="double")
-        def as_double(x):
-            return float(x)
-
-        df = (
-            self.spark
-                .createDataFrame(
-                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", 
"float"))
-                .select(
-                    add_one("one"), add_two("one"),
-                    to_upper("Foo"), to_lower("Foo"),
-                    substr("foobar", lit(0), lit(3)),
-                    trunc("float"), as_double("one")))
-
-        self.assertListEqual(
-            [tpe for _, tpe in df.dtypes],
-            ["int", "double", "string", "string", "string", "bigint", "double"]
-        )
-
-        self.assertListEqual(
-            list(df.first()),
-            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
-        )
-
-    def test_udf_wrapper(self):
-        from pyspark.sql.functions import udf
-        from pyspark.sql.types import IntegerType
-
-        def f(x):
-            """Identity"""
-            return x
-
-        return_type = IntegerType()
-        f_ = udf(f, return_type)
-
-        self.assertTrue(f.__doc__ in f_.__doc__)
-        self.assertEqual(f, f_.func)
-        self.assertEqual(return_type, f_.returnType)
-
-        class F(object):
-            """Identity"""
-            def __call__(self, x):
-                return x
-
-        f = F()
-        return_type = IntegerType()
-        f_ = udf(f, return_type)
-
-        self.assertTrue(f.__doc__ in f_.__doc__)
-        self.assertEqual(f, f_.func)
-        self.assertEqual(return_type, f_.returnType)
-
-        f = functools.partial(f, x=1)
-        return_type = IntegerType()
-        f_ = udf(f, return_type)
-
-        self.assertTrue(f.__doc__ in f_.__doc__)
-        self.assertEqual(f, f_.func)
-        self.assertEqual(return_type, f_.returnType)
-
-    def test_validate_column_types(self):
-        from pyspark.sql.functions import udf, to_json
-        from pyspark.sql.column import _to_java_column
-
-        self.assertTrue("Column" in _to_java_column("a").getClass().toString())
-        self.assertTrue("Column" in 
_to_java_column(u"a").getClass().toString())
-        self.assertTrue("Column" in 
_to_java_column(self.spark.range(1).id).getClass().toString())
-
-        self.assertRaisesRegexp(
-            TypeError,
-            "Invalid argument, not a string or column",
-            lambda: _to_java_column(1))
-
-        class A():
-            pass
-
-        self.assertRaises(TypeError, lambda: _to_java_column(A()))
-        self.assertRaises(TypeError, lambda: _to_java_column([]))
-
-        self.assertRaisesRegexp(
-            TypeError,
-            "Invalid argument, not a string or column",
-            lambda: udf(lambda x: x)(None))
-        self.assertRaises(TypeError, lambda: to_json(1))
-
-    def test_basic_functions(self):
-        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.spark.read.json(rdd)
-        df.count()
-        df.collect()
-        df.schema
-
-        # cache and checkpoint
-        self.assertFalse(df.is_cached)
-        df.persist()
-        df.unpersist(True)
-        df.cache()
-        self.assertTrue(df.is_cached)
-        self.assertEqual(2, df.count())
-
-        with self.tempView("temp"):
-            df.createOrReplaceTempView("temp")
-            df = self.spark.sql("select foo from temp")
-            df.count()
-            df.collect()
-
-    def test_apply_schema_to_row(self):
-        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
-        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
-        self.assertEqual(df.collect(), df2.collect())
-
-        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        df3 = self.spark.createDataFrame(rdd, df.schema)
-        self.assertEqual(10, df3.count())
-
-    def test_infer_schema_to_local(self):
-        input = [{"a": 1}, {"b": "coffee"}]
-        rdd = self.sc.parallelize(input)
-        df = self.spark.createDataFrame(input)
-        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
-        self.assertEqual(df.schema, df2.schema)
-
-        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
-        df3 = self.spark.createDataFrame(rdd, df.schema)
-        self.assertEqual(10, df3.count())
-
-    def test_apply_schema_to_dict_and_rows(self):
-        schema = StructType().add("b", StringType()).add("a", IntegerType())
-        input = [{"a": 1}, {"b": "coffee"}]
-        rdd = self.sc.parallelize(input)
-        for verify in [False, True]:
-            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
-            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
-            self.assertEqual(df.schema, df2.schema)
-
-            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
-            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
-            self.assertEqual(10, df3.count())
-            input = [Row(a=x, b=str(x)) for x in range(10)]
-            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
-            self.assertEqual(10, df4.count())
-
-    def test_create_dataframe_schema_mismatch(self):
-        input = [Row(a=1)]
-        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
-        schema = StructType([StructField("a", IntegerType()), StructField("b", 
StringType())])
-        df = self.spark.createDataFrame(rdd, schema)
-        self.assertRaises(Exception, lambda: df.show())
-
-    def test_serialize_nested_array_and_map(self):
-        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
-        rdd = self.sc.parallelize(d)
-        df = self.spark.createDataFrame(rdd)
-        row = df.head()
-        self.assertEqual(1, len(row.l))
-        self.assertEqual(1, row.l[0].a)
-        self.assertEqual("2", row.d["key"].d)
-
-        l = df.rdd.map(lambda x: x.l).first()
-        self.assertEqual(1, len(l))
-        self.assertEqual('s', l[0].b)
-
-        d = df.rdd.map(lambda x: x.d).first()
-        self.assertEqual(1, len(d))
-        self.assertEqual(1.0, d["key"].c)
-
-        row = df.rdd.map(lambda x: x.d["key"]).first()
-        self.assertEqual(1.0, row.c)
-        self.assertEqual("2", row.d)
-
-    def test_infer_schema(self):
-        d = [Row(l=[], d={}, s=None),
-             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
-        rdd = self.sc.parallelize(d)
-        df = self.spark.createDataFrame(rdd)
-        self.assertEqual([], df.rdd.map(lambda r: r.l).first())
-        self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
-
-        with self.tempView("test"):
-            df.createOrReplaceTempView("test")
-            result = self.spark.sql("SELECT l[0].a from test where d['key'].d 
= '2'")
-            self.assertEqual(1, result.head()[0])
-
-        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
-        self.assertEqual(df.schema, df2.schema)
-        self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
-        self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
-
-        with self.tempView("test2"):
-            df2.createOrReplaceTempView("test2")
-            result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d 
= '2'")
-            self.assertEqual(1, result.head()[0])
-
-    def test_infer_schema_specification(self):
-        from decimal import Decimal
-
-        class A(object):
-            def __init__(self):
-                self.a = 1
-
-        data = [
-            True,
-            1,
-            "a",
-            u"a",
-            datetime.date(1970, 1, 1),
-            datetime.datetime(1970, 1, 1, 0, 0),
-            1.0,
-            array.array("d", [1]),
-            [1],
-            (1, ),
-            {"a": 1},
-            bytearray(1),
-            Decimal(1),
-            Row(a=1),
-            Row("a")(1),
-            A(),
-        ]
-
-        df = self.spark.createDataFrame([data])
-        actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
-        expected = [
-            'boolean',
-            'bigint',
-            'string',
-            'string',
-            'date',
-            'timestamp',
-            'double',
-            'array<double>',
-            'array<bigint>',
-            'struct<_1:bigint>',
-            'map<string,bigint>',
-            'binary',
-            'decimal(38,18)',
-            'struct<a:bigint>',
-            'struct<a:bigint>',
-            'struct<a:bigint>',
-        ]
-        self.assertEqual(actual, expected)
-
-        actual = list(df.first())
-        expected = [
-            True,
-            1,
-            'a',
-            u"a",
-            datetime.date(1970, 1, 1),
-            datetime.datetime(1970, 1, 1, 0, 0),
-            1.0,
-            [1.0],
-            [1],
-            Row(_1=1),
-            {"a": 1},
-            bytearray(b'\x00'),
-            Decimal('1.000000000000000000'),
-            Row(a=1),
-            Row(a=1),
-            Row(a=1),
-        ]
-        self.assertEqual(actual, expected)
-
-    def test_infer_schema_not_enough_names(self):
-        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
-        self.assertEqual(df.columns, ['col1', '_2'])
-
-    def test_infer_schema_fails(self):
-        with self.assertRaisesRegexp(TypeError, 'field a'):
-            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
-                                       schema=["a", "b"], samplingRatio=0.99)
-
-    def test_infer_nested_schema(self):
-        NestedRow = Row("f1", "f2")
-        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
-                                          NestedRow([2, 3], {"row2": 2.0})])
-        df = self.spark.createDataFrame(nestedRdd1)
-        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
-
-        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
-                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
-        df = self.spark.createDataFrame(nestedRdd2)
-        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
-
-        from collections import namedtuple
-        CustomRow = namedtuple('CustomRow', 'field1 field2')
-        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
-                                   CustomRow(field1=2, field2="row2"),
-                                   CustomRow(field1=3, field2="row3")])
-        df = self.spark.createDataFrame(rdd)
-        self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
-
-    def test_create_dataframe_from_dict_respects_schema(self):
-        df = self.spark.createDataFrame([{'a': 1}], ["b"])
-        self.assertEqual(df.columns, ['b'])
-
-    def test_create_dataframe_from_objects(self):
-        data = [MyObject(1, "1"), MyObject(2, "2")]
-        df = self.spark.createDataFrame(data)
-        self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
-        self.assertEqual(df.first(), Row(key=1, value="1"))
-
-    def test_select_null_literal(self):
-        df = self.spark.sql("select null as col")
-        self.assertEqual(Row(col=None), df.first())
-
-    def test_apply_schema(self):
-        from datetime import date, datetime
-        rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
-                                    {"a": 1}, (2,), [1, 2, 3], None)])
-        schema = StructType([
-            StructField("byte1", ByteType(), False),
-            StructField("byte2", ByteType(), False),
-            StructField("short1", ShortType(), False),
-            StructField("short2", ShortType(), False),
-            StructField("int1", IntegerType(), False),
-            StructField("float1", FloatType(), False),
-            StructField("date1", DateType(), False),
-            StructField("time1", TimestampType(), False),
-            StructField("map1", MapType(StringType(), IntegerType(), False), 
False),
-            StructField("struct1", StructType([StructField("b", ShortType(), 
False)]), False),
-            StructField("list1", ArrayType(ByteType(), False), False),
-            StructField("null1", DoubleType(), True)])
-        df = self.spark.createDataFrame(rdd, schema)
-        results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
-                             x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
-        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
-             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-        self.assertEqual(r, results.first())
-
-        with self.tempView("table2"):
-            df.createOrReplaceTempView("table2")
-            r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, 
" +
-                               "short1 + 1 AS short1, short2 - 1 AS short2, 
int1 - 1 AS int1, " +
-                               "float1 + 1.5 as float1 FROM table2").first()
-
-            self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
-
-    def test_struct_in_map(self):
-        d = [Row(m={Row(i=1): Row(s="")})]
-        df = self.sc.parallelize(d).toDF()
-        k, v = list(df.head().m.items())[0]
-        self.assertEqual(1, k.i)
-        self.assertEqual("", v.s)
-
-    def test_convert_row_to_dict(self):
-        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
-        self.assertEqual(1, row.asDict()['l'][0].a)
-        df = self.sc.parallelize([row]).toDF()
-
-        with self.tempView("test"):
-            df.createOrReplaceTempView("test")
-            row = self.spark.sql("select l, d from test").head()
-            self.assertEqual(1, row.asDict()["l"][0].a)
-            self.assertEqual(1.0, row.asDict()['d']['key'].c)
-
-    def test_udt(self):
-        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
-        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
-
-        def check_datatype(datatype):
-            pickled = pickle.loads(pickle.dumps(datatype))
-            assert datatype == pickled
-            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
-            python_datatype = _parse_datatype_json_string(scala_datatype.json())
-            assert datatype == python_datatype
-
-        check_datatype(ExamplePointUDT())
-        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
-                                          StructField("point", ExamplePointUDT(), False)])
-        check_datatype(structtype_with_udt)
-        p = ExamplePoint(1.0, 2.0)
-        self.assertEqual(_infer_type(p), ExamplePointUDT())
-        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
-        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
-
-        check_datatype(PythonOnlyUDT())
-        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
-                                          StructField("point", PythonOnlyUDT(), False)])
-        check_datatype(structtype_with_udt)
-        p = PythonOnlyPoint(1.0, 2.0)
-        self.assertEqual(_infer_type(p), PythonOnlyUDT())
-        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
-        self.assertRaises(
-            ValueError,
-            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
-
-    def test_simple_udt_in_df(self):
-        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
-        df = self.spark.createDataFrame(
-            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
-            schema=schema)
-        df.collect()
-
-    def test_nested_udt_in_df(self):
-        schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
-        df = self.spark.createDataFrame(
-            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
-            schema=schema)
-        df.collect()
-
-        schema = StructType().add("key", LongType()).add("val",
-                                                         MapType(LongType(), PythonOnlyUDT()))
-        df = self.spark.createDataFrame(
-            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
-            schema=schema)
-        df.collect()
-
-    def test_complex_nested_udt_in_df(self):
-        from pyspark.sql.functions import udf
-
-        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
-        df = self.spark.createDataFrame(
-            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
-            schema=schema)
-        df.collect()
-
-        gd = df.groupby("key").agg({"val": "collect_list"})
-        gd.collect()
-        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
-        gd.select(udf(*gd)).collect()
-
-    def test_udt_with_none(self):
-        df = self.spark.range(0, 10, 1, 1)
-
-        def myudf(x):
-            if x > 0:
-                return PythonOnlyPoint(float(x), float(x))
-
-        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
-        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
-        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
-
-    def test_nonparam_udf_with_aggregate(self):
-        import pyspark.sql.functions as f
-
-        df = self.spark.createDataFrame([(1, 2), (1, 2)])
-        f_udf = f.udf(lambda: "const_str")
-        rows = df.distinct().withColumn("a", f_udf()).collect()
-        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
-
-    def test_infer_schema_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.spark.createDataFrame([row])
-        schema = df.schema
-        field = [f for f in schema.fields if f.name == "point"][0]
-        self.assertEqual(type(field.dataType), ExamplePointUDT)
-
-        with self.tempView("labeled_point"):
-            df.createOrReplaceTempView("labeled_point")
-            point = self.spark.sql("SELECT point FROM labeled_point").head().point
-            self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.spark.createDataFrame([row])
-        schema = df.schema
-        field = [f for f in schema.fields if f.name == "point"][0]
-        self.assertEqual(type(field.dataType), PythonOnlyUDT)
-
-        with self.tempView("labeled_point"):
-            df.createOrReplaceTempView("labeled_point")
-            point = self.spark.sql("SELECT point FROM labeled_point").head().point
-            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    def test_apply_schema_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-        row = (1.0, ExamplePoint(1.0, 2.0))
-        schema = StructType([StructField("label", DoubleType(), False),
-                             StructField("point", ExamplePointUDT(), False)])
-        df = self.spark.createDataFrame([row], schema)
-        point = df.head().point
-        self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-        row = (1.0, PythonOnlyPoint(1.0, 2.0))
-        schema = StructType([StructField("label", DoubleType(), False),
-                             StructField("point", PythonOnlyUDT(), False)])
-        df = self.spark.createDataFrame([row], schema)
-        point = df.head().point
-        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    def test_udf_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.spark.createDataFrame([row])
-        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
-        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.spark.createDataFrame([row])
-        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
-        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-    def test_parquet_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df0 = self.spark.createDataFrame([row])
-        output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        df0.write.parquet(output_dir)
-        df1 = self.spark.read.parquet(output_dir)
-        point = df1.head().point
-        self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df0 = self.spark.createDataFrame([row])
-        df0.write.parquet(output_dir, mode='overwrite')
-        df1 = self.spark.read.parquet(output_dir)
-        point = df1.head().point
-        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    def test_union_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-        row1 = (1.0, ExamplePoint(1.0, 2.0))
-        row2 = (2.0, ExamplePoint(3.0, 4.0))
-        schema = StructType([StructField("label", DoubleType(), False),
-                             StructField("point", ExamplePointUDT(), False)])
-        df1 = self.spark.createDataFrame([row1], schema)
-        df2 = self.spark.createDataFrame([row2], schema)
-
-        result = df1.union(df2).orderBy("label").collect()
-        self.assertEqual(
-            result,
-            [
-                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
-                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
-            ]
-        )
-
-    def test_cast_to_string_with_udt(self):
-        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
-        from pyspark.sql.functions import col
-        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
-        schema = StructType([StructField("point", ExamplePointUDT(), False),
-                             StructField("pypoint", PythonOnlyUDT(), False)])
-        df = self.spark.createDataFrame([row], schema)
-
-        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
-        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
-
-    def test_column_operators(self):
-        ci = self.df.key
-        cs = self.df.value
-        c = ci == cs
-        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
-        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
-        self.assertTrue(all(isinstance(c, Column) for c in rcc))
-        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
-        self.assertTrue(all(isinstance(c, Column) for c in cb))
-        cbool = (ci & ci), (ci | ci), (~ci)
-        self.assertTrue(all(isinstance(c, Column) for c in cbool))
-        css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\
-            cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
-        self.assertTrue(all(isinstance(c, Column) for c in css))
-        self.assertTrue(isinstance(ci.cast(LongType()), Column))
-        self.assertRaisesRegexp(ValueError,
-                                "Cannot apply 'in' operator against a column",
-                                lambda: 1 in cs)
-
-    def test_column_getitem(self):
-        from pyspark.sql.functions import col
-
-        self.assertIsInstance(col("foo")[1:3], Column)
-        self.assertIsInstance(col("foo")[0], Column)
-        self.assertIsInstance(col("foo")["bar"], Column)
-        self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
-
-    def test_column_select(self):
-        df = self.df
-        self.assertEqual(self.testData, df.select("*").collect())
-        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
-        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
-
-    def test_freqItems(self):
-        vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
-        df = self.sc.parallelize(vals).toDF()
-        items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
-        self.assertTrue(1 in items[0])
-        self.assertTrue(-2.0 in items[1])
-
-    def test_aggregator(self):
-        df = self.df
-        g = df.groupBy()
-        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
-        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-
-        from pyspark.sql import functions
-        self.assertEqual((0, u'99'),
-                         tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
-        self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
-        self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
-
-    def test_first_last_ignorenulls(self):
-        from pyspark.sql import functions
-        df = self.spark.range(0, 100)
-        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
-        df3 = df2.select(functions.first(df2.id, False).alias('a'),
-                         functions.first(df2.id, True).alias('b'),
-                         functions.last(df2.id, False).alias('c'),
-                         functions.last(df2.id, True).alias('d'))
-        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
-
-    def test_approxQuantile(self):
-        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
-        for f in ["a", u"a"]:
-            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
-            self.assertTrue(isinstance(aq, list))
-            self.assertEqual(len(aq), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aq))
-        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aqs, list))
-        self.assertEqual(len(aqs), 2)
-        self.assertTrue(isinstance(aqs[0], list))
-        self.assertEqual(len(aqs[0]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
-        self.assertTrue(isinstance(aqs[1], list))
-        self.assertEqual(len(aqs[1]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
-        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aqt, list))
-        self.assertEqual(len(aqt), 2)
-        self.assertTrue(isinstance(aqt[0], list))
-        self.assertEqual(len(aqt[0]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
-        self.assertTrue(isinstance(aqt[1], list))
-        self.assertEqual(len(aqt[1]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
-
-    def test_corr(self):
-        import math
-        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
-        corr = df.stat.corr(u"a", "b")
-        self.assertTrue(abs(corr - 0.95734012) < 1e-6)
-
-    def test_sampleby(self):
-        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
-        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
-        self.assertTrue(sampled.count() == 3)
-
-    def test_cov(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        cov = df.stat.cov(u"a", "b")
-        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
-
-    def test_crosstab(self):
-        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
-        ct = df.stat.crosstab(u"a", "b").collect()
-        ct = sorted(ct, key=lambda x: x[0])
-        for i, row in enumerate(ct):
-            self.assertEqual(row[0], str(i))
-            self.assertEqual(row[1], 1)
-            self.assertEqual(row[2], 1)
-
-    def test_math_functions(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        from pyspark.sql import functions
-        import math
-
-        def get_values(l):
-            return [j[0] for j in l]
-
-        def assert_close(a, b):
-            c = get_values(b)
-            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
-            return sum(diff) == len(a)
-        assert_close([math.cos(i) for i in range(10)],
-                     df.select(functions.cos(df.a)).collect())
-        assert_close([math.cos(i) for i in range(10)],
-                     df.select(functions.cos("a")).collect())
-        assert_close([math.sin(i) for i in range(10)],
-                     df.select(functions.sin(df.a)).collect())
-        assert_close([math.sin(i) for i in range(10)],
-                     df.select(functions.sin(df['a'])).collect())
-        assert_close([math.pow(i, 2 * i) for i in range(10)],
-                     df.select(functions.pow(df.a, df.b)).collect())
-        assert_close([math.pow(i, 2) for i in range(10)],
-                     df.select(functions.pow(df.a, 2)).collect())
-        assert_close([math.pow(i, 2) for i in range(10)],
-                     df.select(functions.pow(df.a, 2.0)).collect())
-        assert_close([math.hypot(i, 2 * i) for i in range(10)],
-                     df.select(functions.hypot(df.a, df.b)).collect())
-
-    def test_rand_functions(self):
-        df = self.df
-        from pyspark.sql import functions
-        rnd = df.select('key', functions.rand()).collect()
-        for row in rnd:
-            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
-        rndn = df.select('key', functions.randn(5)).collect()
-        for row in rndn:
-            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
-
-        # If the specified seed is 0, we should use it.
-        # https://issues.apache.org/jira/browse/SPARK-9691
-        rnd1 = df.select('key', functions.rand(0)).collect()
-        rnd2 = df.select('key', functions.rand(0)).collect()
-        self.assertEqual(sorted(rnd1), sorted(rnd2))
-
-        rndn1 = df.select('key', functions.randn(0)).collect()
-        rndn2 = df.select('key', functions.randn(0)).collect()
-        self.assertEqual(sorted(rndn1), sorted(rndn2))
-
-    def test_string_functions(self):
-        from pyspark.sql.functions import col, lit
-        df = self.spark.createDataFrame([['nick']], schema=['name'])
-        self.assertRaisesRegexp(
-            TypeError,
-            "must be the same type",
-            lambda: df.select(col('name').substr(0, lit(1))))
-        if sys.version_info.major == 2:
-            self.assertRaises(
-                TypeError,
-                lambda: df.select(col('name').substr(long(0), long(1))))
-
-    def test_array_contains_function(self):
-        from pyspark.sql.functions import array_contains
-
-        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
-        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
-        self.assertEqual([Row(b=True), Row(b=False)], actual)
-
-    def test_between_function(self):
-        df = self.sc.parallelize([
-            Row(a=1, b=2, c=3),
-            Row(a=2, b=1, c=3),
-            Row(a=4, b=1, c=4)]).toDF()
-        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
-                         df.filter(df.a.between(df.b, df.c)).collect())
-
-    def test_struct_type(self):
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        struct2 = StructType([StructField("f1", StringType(), True),
-                              StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), struct2.names)
-        self.assertEqual(struct1, struct2)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), struct2.names)
-        self.assertNotEqual(struct1, struct2)
-
-        struct1 = (StructType().add(StructField("f1", StringType(), True))
-                   .add(StructField("f2", StringType(), True, None)))
-        struct2 = StructType([StructField("f1", StringType(), True),
-                              StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), struct2.names)
-        self.assertEqual(struct1, struct2)
-
-        struct1 = (StructType().add(StructField("f1", StringType(), True))
-                   .add(StructField("f2", StringType(), True, None)))
-        struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), struct2.names)
-        self.assertNotEqual(struct1, struct2)
-
-        # Catch exception raised during improper construction
-        self.assertRaises(ValueError, lambda: StructType().add("name"))
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        for field in struct1:
-            self.assertIsInstance(field, StructField)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        self.assertEqual(len(struct1), 2)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        self.assertIs(struct1["f1"], struct1.fields[0])
-        self.assertIs(struct1[0], struct1.fields[0])
-        self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        self.assertRaises(KeyError, lambda: struct1["f9"])
-        self.assertRaises(IndexError, lambda: struct1[9])
-        self.assertRaises(TypeError, lambda: struct1[9.9])
-
-    def test_parse_datatype_string(self):
-        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
-        for k, t in _all_atomic_types.items():
-            if t != NullType:
-                self.assertEqual(t(), _parse_datatype_string(k))
-        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
-        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1  ,1)"))
-        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
-        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
-        self.assertEqual(
-            ArrayType(IntegerType()),
-            _parse_datatype_string("array<int >"))
-        self.assertEqual(
-            MapType(IntegerType(), DoubleType()),
-            _parse_datatype_string("map< int, double  >"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("struct<a:int, c:double >"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("a:int, c:double"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("a INT, c DOUBLE"))
-
-    def test_metadata_null(self):
-        schema = StructType([StructField("f1", StringType(), True, None),
-                             StructField("f2", StringType(), True, {'a': None})])
-        rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
-        self.spark.createDataFrame(rdd, schema)
-
-    def test_save_and_load(self):
-        df = self.df
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.json(tmpPath, "overwrite")
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.save(format="json", mode="overwrite", path=tmpPath,
-                      noUse="this options will not be used in save.")
-        actual = self.spark.read.load(format="json", path=tmpPath,
-                                      noUse="this options will not be used in load.")
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        csvpath = os.path.join(tempfile.mkdtemp(), 'data')
-        df.write.option('quote', None).format('csv').save(csvpath)
-
-        shutil.rmtree(tmpPath)
-
-    def test_save_and_load_builder(self):
-        df = self.df
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
-                .option("noUse", "this option will not be used in save.")\
-                .format("json").save(path=tmpPath)
-        actual =\
-            self.spark.read.format("json")\
-                           .load(path=tmpPath, noUse="this options will not be used in load.")
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        shutil.rmtree(tmpPath)
-
-    def test_stream_trigger(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-
-        # Should take at least one arg
-        try:
-            df.writeStream.trigger()
-        except ValueError:
-            pass
-
-        # Should not take multiple args
-        try:
-            df.writeStream.trigger(once=True, processingTime='5 seconds')
-        except ValueError:
-            pass
-
-        # Should not take multiple args
-        try:
-            df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
-        except ValueError:
-            pass
-
-        # Should take only keyword args
-        try:
-            df.writeStream.trigger('5 seconds')
-            self.fail("Should have thrown an exception")
-        except TypeError:
-            pass
-
-    def test_stream_read_options(self):
-        schema = StructType([StructField("data", StringType(), False)])
-        df = self.spark.readStream\
-            .format('text')\
-            .option('path', 'python/test_support/sql/streaming')\
-            .schema(schema)\
-            .load()
-        self.assertTrue(df.isStreaming)
-        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
-    def test_stream_read_options_overwrite(self):
-        bad_schema = StructType([StructField("test", IntegerType(), False)])
-        schema = StructType([StructField("data", StringType(), False)])
-        df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
-            .schema(bad_schema)\
-            .load(path='python/test_support/sql/streaming', schema=schema, format='text')
-        self.assertTrue(df.isStreaming)
-        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
-    def test_stream_save_options(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
-            .withColumn('id', lit(1))
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
-            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
-        try:
-            self.assertEqual(q.name, 'this_query')
-            self.assertTrue(q.isActive)
-            q.processAllAvailable()
-            output_files = []
-            for _, _, files in os.walk(out):
-                output_files.extend([f for f in files if not f.startswith('.')])
-            self.assertTrue(len(output_files) > 0)
-            self.assertTrue(len(os.listdir(chk)) > 0)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_save_options_overwrite(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        fake1 = os.path.join(tmpPath, 'fake1')
-        fake2 = os.path.join(tmpPath, 'fake2')
-        q = df.writeStream.option('checkpointLocation', fake1)\
-            .format('memory').option('path', fake2) \
-            .queryName('fake_query').outputMode('append') \
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-
-        try:
-            self.assertEqual(q.name, 'this_query')
-            self.assertTrue(q.isActive)
-            q.processAllAvailable()
-            output_files = []
-            for _, _, files in os.walk(out):
-                output_files.extend([f for f in files if not f.startswith('.')])
-            self.assertTrue(len(output_files) > 0)
-            self.assertTrue(len(os.listdir(chk)) > 0)
-            self.assertFalse(os.path.isdir(fake1))  # should not have been created
-            self.assertFalse(os.path.isdir(fake2))  # should not have been created
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_status_and_progress(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-
-        def func(x):
-            time.sleep(1)
-            return x
-
-        from pyspark.sql.functions import col, udf
-        sleep_udf = udf(func)
-
-        # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
-        # were no updates.
-        q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-        try:
-            # "lastProgress" will return None in most cases. However, as it may be flaky when
-            # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
-            # may throw error with a high chance and make this test flaky, so we should still be
-            # able to detect broken codes.
-            q.lastProgress
-
-            q.processAllAvailable()
-            lastProgress = q.lastProgress
-            recentProgress = q.recentProgress
-            status = q.status
-            self.assertEqual(lastProgress['name'], q.name)
-            self.assertEqual(lastProgress['id'], q.id)
-            self.assertTrue(any(p == lastProgress for p in recentProgress))
-            self.assertTrue(
-                "message" in status and
-                "isDataAvailable" in status and
-                "isTriggerActive" in status)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_await_termination(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        q = df.writeStream\
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-        try:
-            self.assertTrue(q.isActive)
-            try:
-                q.awaitTermination("hello")
-                self.fail("Expected a value exception")
-            except ValueError:
-                pass
-            now = time.time()
-            # test should take at least 2 seconds
-            res = q.awaitTermination(2.6)
-            duration = time.time() - now
-            self.assertTrue(duration >= 2)
-            self.assertFalse(res)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_exception(self):
-        sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        sq = sdf.writeStream.format('memory'

<TRUNCATED>
