http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/__init__.py b/python/pyspark/sql/tests/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/python/pyspark/sql/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_appsubmit.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_appsubmit.py b/python/pyspark/sql/tests/test_appsubmit.py
new file mode 100644
index 0000000..3c71151
--- /dev/null
+++ b/python/pyspark/sql/tests/test_appsubmit.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import subprocess
+import tempfile
+
+import py4j
+
+from pyspark import SparkContext
+from pyspark.tests import SparkSubmitTests
+
+
+class HiveSparkSubmitTests(SparkSubmitTests):
+
+    @classmethod
+    def setUpClass(cls):
+        # get a SparkContext to check for availability of Hive
+        sc = SparkContext('local[4]', cls.__name__)
+        cls.hive_available = True
+        try:
+            sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+        except py4j.protocol.Py4JError:
+            cls.hive_available = False
+        except TypeError:
+            cls.hive_available = False
+        finally:
+            # we don't need this SparkContext for the test
+            sc.stop()
+
+    def setUp(self):
+        super(HiveSparkSubmitTests, self).setUp()
+        if not self.hive_available:
+            self.skipTest("Hive is not available.")
+
+    def test_hivecontext(self):
+        # This test checks that HiveContext is using Hive metastore 
(SPARK-16224).
+        # It sets a metastore url and checks if there is a derby dir created by
+        # Hive metastore. If this derby dir exists, HiveContext is using
+        # Hive metastore.
+        metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db")
+        metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true"
+        hive_site_dir = os.path.join(self.programDir, "conf")
+        hive_site_file = self.createTempFile("hive-site.xml", ("""
+            |<configuration>
+            |  <property>
+            |  <name>javax.jdo.option.ConnectionURL</name>
+            |  <value>%s</value>
+            |  </property>
+            |</configuration>
+            """ % metastore_URL).lstrip(), "conf")
+        script = self.createTempFile("test.py", """
+            |import os
+            |
+            |from pyspark.conf import SparkConf
+            |from pyspark.context import SparkContext
+            |from pyspark.sql import HiveContext
+            |
+            |conf = SparkConf()
+            |sc = SparkContext(conf=conf)
+            |hive_context = HiveContext(sc)
+            |print(hive_context.sql("show databases").collect())
+            """)
+        proc = subprocess.Popen(
+            self.sparkSubmit + ["--master", "local-cluster[1,1,1024]",
+                                "--driver-class-path", hive_site_dir, script],
+            stdout=subprocess.PIPE)
+        out, err = proc.communicate()
+        self.assertEqual(0, proc.returncode)
+        self.assertIn("default", out.decode('utf-8'))
+        self.assertTrue(os.path.exists(metastore_path))
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_appsubmit import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
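
Every module added in this patch ends with the same __main__ block: it prefers xmlrunner so the CI run can emit JUnit-style XML under target/test-reports, and falls back to the plain unittest text runner when xmlrunner is not installed. A minimal standalone sketch of that pattern (the trivial test case below is a placeholder, not part of the patch):

    import unittest

    class PlaceholderTest(unittest.TestCase):
        # Placeholder test so the runner has something to execute; purely illustrative.
        def test_nothing(self):
            self.assertTrue(True)

    if __name__ == "__main__":
        try:
            import xmlrunner  # optional dependency used for XML test reports
            unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
        except ImportError:
            # Fall back to the default text runner when xmlrunner is unavailable.
            unittest.main(verbosity=2)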

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_arrow.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
new file mode 100644
index 0000000..44f7035
--- /dev/null
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -0,0 +1,399 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import os
+import threading
+import time
+import unittest
+import warnings
+
+from pyspark.sql import Row
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
+    pandas_requirement_message, pyarrow_requirement_message
+from pyspark.tests import QuietTest
+from pyspark.util import _exception_message
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message)
+class ArrowTests(ReusedSQLTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        from datetime import date, datetime
+        from decimal import Decimal
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+        super(ArrowTests, cls).setUpClass()
+        cls.warnings_lock = threading.Lock()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        # Disable fallback by default to easily detect the failures.
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", 
"false")
+        cls.schema = StructType([
+            StructField("1_str_t", StringType(), True),
+            StructField("2_int_t", IntegerType(), True),
+            StructField("3_long_t", LongType(), True),
+            StructField("4_float_t", FloatType(), True),
+            StructField("5_double_t", DoubleType(), True),
+            StructField("6_decimal_t", DecimalType(38, 18), True),
+            StructField("7_date_t", DateType(), True),
+            StructField("8_timestamp_t", TimestampType(), True)])
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
+                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
+                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
+            cls.schema.add(StructField("9_binary_t", BinaryType(), True))
+            cls.data[0] = cls.data[0] + (bytearray(b"a"),)
+            cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
+            cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        super(ArrowTests, cls).tearDownClass()
+
+    def create_pandas_data_frame(self):
+        import pandas as pd
+        import numpy as np
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        # need to convert these to numpy types first
+        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+        return pd.DataFrame(data=data_dict)
+
+    def test_toPandas_fallback_enabled(self):
+        import pandas as pd
+
+        with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+            with QuietTest(self.sc):
+                with self.warnings_lock:
+                    with warnings.catch_warnings(record=True) as warns:
+                        # we want the warnings to appear even if this test is run from a subclass
+                        warnings.simplefilter("always")
+                        pdf = df.toPandas()
+                        # Catch and check the last UserWarning.
+                        user_warns = [
+                            warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                        self.assertTrue(len(user_warns) > 0)
+                        self.assertTrue(
+                            "Attempting non-optimization" in _exception_message(user_warns[-1]))
+                        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+
+    def test_toPandas_fallback_disabled(self):
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+
+        schema = StructType([StructField("map", MapType(StringType(), 
IntegerType()), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
+        with QuietTest(self.sc):
+            with self.warnings_lock:
+                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                    df.toPandas()
+
+        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            schema = StructType([StructField("binary", BinaryType(), True)])
+            df = self.spark.createDataFrame([(None,)], schema=schema)
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'):
+                    df.toPandas()
+
+    def test_null_conversion(self):
+        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
+                                             self.data)
+        pdf = df_null.toPandas()
+        null_counts = pdf.isnull().sum().tolist()
+        self.assertTrue(all([c == 1 for c in null_counts]))
+
+    def _toPandas_arrow_toggle(self, df):
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
+            pdf = df.toPandas()
+
+        pdf_arrow = df.toPandas()
+
+        return pdf, pdf_arrow
+
+    def test_toPandas_arrow_toggle(self):
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        expected = self.create_pandas_data_frame()
+        self.assertPandasEqual(expected, pdf)
+        self.assertPandasEqual(expected, pdf_arrow)
+
+    def test_toPandas_respect_session_timezone(self):
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
+
+            self.assertFalse(pdf_ny.equals(pdf_la))
+
+            from pyspark.sql.types import _check_series_convert_timestamps_local_tz
+            pdf_la_corrected = pdf_la.copy()
+            for field in self.schema:
+                if isinstance(field.dataType, TimestampType):
+                    pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
+                        pdf_la_corrected[field.name], timezone)
+            self.assertPandasEqual(pdf_ny, pdf_la_corrected)
+
+    def test_pandas_round_trip(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+        pdf_arrow = df.toPandas()
+        self.assertPandasEqual(pdf_arrow, pdf)
+
+    def test_filtered_frame(self):
+        df = self.spark.range(3).toDF("i")
+        pdf = df.filter("i < 0").toPandas()
+        self.assertEqual(len(pdf.columns), 1)
+        self.assertEqual(pdf.columns[0], "i")
+        self.assertTrue(pdf.empty)
+
+    def _createDataFrame_toggle(self, pdf, schema=None):
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
+            df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
+        df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
+        return df_no_arrow, df_arrow
+
+    def test_createDataFrame_toggle(self):
+        pdf = self.create_pandas_data_frame()
+        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
+        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+    def test_createDataFrame_respect_session_timezone(self):
+        from datetime import timedelta
+        pdf = self.create_pandas_data_frame()
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_ny = df_no_arrow_ny.collect()
+            result_arrow_ny = df_arrow_ny.collect()
+            self.assertEqual(result_ny, result_arrow_ny)
+
+            self.assertNotEqual(result_ny, result_la)
+
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
+                                          for k, v in row.asDict().items()})
+                                   for row in result_la]
+            self.assertEqual(result_ny, result_la_corrected)
+
+    def test_createDataFrame_with_schema(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertEquals(self.schema, df.schema)
+        pdf_arrow = df.toPandas()
+        self.assertPandasEqual(pdf_arrow, pdf)
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        pdf = self.create_pandas_data_frame()
+        fields = list(self.schema)
+        fields[0], fields[7] = fields[7], fields[0]  # swap str with timestamp
+        wrong_schema = StructType(fields)
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, ".*No 
cast.*string.*timestamp.*"):
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+    def test_createDataFrame_with_names(self):
+        pdf = self.create_pandas_data_frame()
+        new_names = list(map(str, range(len(self.schema.fieldNames()))))
+        # Test that schema as a list of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=list(new_names))
+        self.assertEquals(df.schema.fieldNames(), new_names)
+        # Test that schema as tuple of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
+        self.assertEquals(df.schema.fieldNames(), new_names)
+
+    def test_createDataFrame_column_name_encoding(self):
+        import pandas as pd
+        pdf = pd.DataFrame({u'a': [1]})
+        columns = self.spark.createDataFrame(pdf).columns
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertEquals(columns[0], 'a')
+        columns = self.spark.createDataFrame(pdf, [u'b']).columns
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertEquals(columns[0], 'b')
+
+    def test_createDataFrame_with_single_data_type(self):
+        import pandas as pd
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
+                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+    def test_createDataFrame_does_not_modify_input(self):
+        import pandas as pd
+        # Some series get converted for Spark to consume, this makes sure input is unchanged
+        pdf = self.create_pandas_data_frame()
+        # Use a nanosecond value to make sure it is not truncated
+        pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
+        # Integers with nulls will get NaNs filled with 0 and will be casted
+        pdf.ix[1, '2_int_t'] = None
+        pdf_copy = pdf.copy(deep=True)
+        self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertTrue(pdf.equals(pdf_copy))
+
+    def test_schema_conversion_roundtrip(self):
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        arrow_schema = to_arrow_schema(self.schema)
+        schema_rt = from_arrow_schema(arrow_schema)
+        self.assertEquals(self.schema, schema_rt)
+
+    def test_createDataFrame_with_array_type(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", 
u"z"]]})
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        result = df.collect()
+        result_arrow = df_arrow.collect()
+        expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
+    def test_toPandas_with_array_type(self):
+        expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
+        array_schema = StructType([StructField("a", ArrayType(IntegerType())),
+                                   StructField("b", ArrayType(StringType()))])
+        df = self.spark.createDataFrame(expected, schema=array_schema)
+        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
+    def test_createDataFrame_with_int_col_names(self):
+        import numpy as np
+        import pandas as pd
+        pdf = pd.DataFrame(np.random.rand(4, 2))
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        pdf_col_names = [str(c) for c in pdf.columns]
+        self.assertEqual(pdf_col_names, df.columns)
+        self.assertEqual(pdf_col_names, df_arrow.columns)
+
+    def test_createDataFrame_fallback_enabled(self):
+        import pandas as pd
+
+        with QuietTest(self.sc):
+            with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+                with warnings.catch_warnings(record=True) as warns:
+                    # we want the warnings to appear even if this test is run from a subclass
+                    warnings.simplefilter("always")
+                    df = self.spark.createDataFrame(
+                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+                    # Catch and check the last UserWarning.
+                    user_warns = [
+                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                    self.assertTrue(len(user_warns) > 0)
+                    self.assertTrue(
+                        "Attempting non-optimization" in _exception_message(user_warns[-1]))
+                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+
+    def test_createDataFrame_fallback_disabled(self):
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import pyarrow as pa
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
+                self.spark.createDataFrame(
+                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+
+        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'):
+                    self.spark.createDataFrame(
+                        pd.DataFrame([[{'a': b'aaa'}]]), "a: binary")
+
+    # Regression test for SPARK-23314
+    def test_timestamp_dst(self):
+        import pandas as pd
+        # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
+        dt = [datetime.datetime(2015, 11, 1, 0, 30),
+              datetime.datetime(2015, 11, 1, 1, 30),
+              datetime.datetime(2015, 11, 1, 2, 30)]
+        pdf = pd.DataFrame({'time': dt})
+
+        df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+        df_from_pandas = self.spark.createDataFrame(pdf)
+
+        self.assertPandasEqual(pdf, df_from_python.toPandas())
+        self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+
+
+class EncryptionArrowTests(ArrowTests):
+
+    @classmethod
+    def conf(cls):
+        return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_arrow import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
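
ArrowTests works by flipping spark.sql.execution.arrow.enabled around toPandas() and createDataFrame() and comparing the Arrow and non-Arrow code paths. A rough standalone sketch of the same toggle, assuming a local SparkSession with pandas and pyarrow available (app name and data are illustrative only, not from the patch):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("arrow-toggle-sketch").getOrCreate()
    df = spark.range(100).toDF("i")

    # Arrow-backed conversion (the configuration key used by the tests in this patch).
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    pdf_arrow = df.toPandas()

    # Row-by-row conversion for comparison.
    spark.conf.set("spark.sql.execution.arrow.enabled", "false")
    pdf_plain = df.toPandas()

    assert pdf_arrow.equals(pdf_plain)
    spark.stop()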

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
new file mode 100644
index 0000000..23d2577
--- /dev/null
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -0,0 +1,199 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class CatalogTests(ReusedSQLTestCase):
+
+    def test_current_database(self):
+        spark = self.spark
+        with self.database("some_db"):
+            self.assertEquals(spark.catalog.currentDatabase(), "default")
+            spark.sql("CREATE DATABASE some_db")
+            spark.catalog.setCurrentDatabase("some_db")
+            self.assertEquals(spark.catalog.currentDatabase(), "some_db")
+            self.assertRaisesRegexp(
+                AnalysisException,
+                "does_not_exist",
+                lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
+
+    def test_list_databases(self):
+        spark = self.spark
+        with self.database("some_db"):
+            databases = [db.name for db in spark.catalog.listDatabases()]
+            self.assertEquals(databases, ["default"])
+            spark.sql("CREATE DATABASE some_db")
+            databases = [db.name for db in spark.catalog.listDatabases()]
+            self.assertEquals(sorted(databases), ["default", "some_db"])
+
+    def test_list_tables(self):
+        from pyspark.sql.catalog import Table
+        spark = self.spark
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            with self.table("tab1", "some_db.tab2"):
+                with self.tempView("temp_tab"):
+                    self.assertEquals(spark.catalog.listTables(), [])
+                    self.assertEquals(spark.catalog.listTables("some_db"), [])
+                    spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
+                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
+                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
+                    tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+                    tablesDefault = \
+                        sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
+                    tablesSomeDb = \
+                        sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
+                    self.assertEquals(tables, tablesDefault)
+                    self.assertEquals(len(tables), 2)
+                    self.assertEquals(len(tablesSomeDb), 2)
+                    self.assertEquals(tables[0], Table(
+                        name="tab1",
+                        database="default",
+                        description=None,
+                        tableType="MANAGED",
+                        isTemporary=False))
+                    self.assertEquals(tables[1], Table(
+                        name="temp_tab",
+                        database=None,
+                        description=None,
+                        tableType="TEMPORARY",
+                        isTemporary=True))
+                    self.assertEquals(tablesSomeDb[0], Table(
+                        name="tab2",
+                        database="some_db",
+                        description=None,
+                        tableType="MANAGED",
+                        isTemporary=False))
+                    self.assertEquals(tablesSomeDb[1], Table(
+                        name="temp_tab",
+                        database=None,
+                        description=None,
+                        tableType="TEMPORARY",
+                        isTemporary=True))
+                    self.assertRaisesRegexp(
+                        AnalysisException,
+                        "does_not_exist",
+                        lambda: spark.catalog.listTables("does_not_exist"))
+
+    def test_list_functions(self):
+        from pyspark.sql.catalog import Function
+        spark = self.spark
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            functions = dict((f.name, f) for f in spark.catalog.listFunctions())
+            functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
+            self.assertTrue(len(functions) > 200)
+            self.assertTrue("+" in functions)
+            self.assertTrue("like" in functions)
+            self.assertTrue("month" in functions)
+            self.assertTrue("to_date" in functions)
+            self.assertTrue("to_timestamp" in functions)
+            self.assertTrue("to_unix_timestamp" in functions)
+            self.assertTrue("current_database" in functions)
+            self.assertEquals(functions["+"], Function(
+                name="+",
+                description=None,
+                className="org.apache.spark.sql.catalyst.expressions.Add",
+                isTemporary=True))
+            self.assertEquals(functions, functionsDefault)
+
+            with self.function("func1", "some_db.func2"):
+                spark.catalog.registerFunction("temp_func", lambda x: str(x))
+                spark.sql("CREATE FUNCTION func1 AS 
'org.apache.spark.data.bricks'")
+                spark.sql("CREATE FUNCTION some_db.func2 AS 
'org.apache.spark.data.bricks'")
+                newFunctions = dict((f.name, f) for f in 
spark.catalog.listFunctions())
+                newFunctionsSomeDb = \
+                    dict((f.name, f) for f in 
spark.catalog.listFunctions("some_db"))
+                self.assertTrue(set(functions).issubset(set(newFunctions)))
+                
self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
+                self.assertTrue("temp_func" in newFunctions)
+                self.assertTrue("func1" in newFunctions)
+                self.assertTrue("func2" not in newFunctions)
+                self.assertTrue("temp_func" in newFunctionsSomeDb)
+                self.assertTrue("func1" not in newFunctionsSomeDb)
+                self.assertTrue("func2" in newFunctionsSomeDb)
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "does_not_exist",
+                    lambda: spark.catalog.listFunctions("does_not_exist"))
+
+    def test_list_columns(self):
+        from pyspark.sql.catalog import Column
+        spark = self.spark
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            with self.table("tab1", "some_db.tab2"):
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING 
parquet")
+                spark.sql(
+                    "CREATE TABLE some_db.tab2 (nickname STRING, tolerance 
FLOAT) USING parquet")
+                columns = sorted(spark.catalog.listColumns("tab1"), key=lambda 
c: c.name)
+                columnsDefault = \
+                    sorted(spark.catalog.listColumns("tab1", "default"), 
key=lambda c: c.name)
+                self.assertEquals(columns, columnsDefault)
+                self.assertEquals(len(columns), 2)
+                self.assertEquals(columns[0], Column(
+                    name="age",
+                    description=None,
+                    dataType="int",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertEquals(columns[1], Column(
+                    name="name",
+                    description=None,
+                    dataType="string",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                columns2 = \
+                    sorted(spark.catalog.listColumns("tab2", "some_db"), 
key=lambda c: c.name)
+                self.assertEquals(len(columns2), 2)
+                self.assertEquals(columns2[0], Column(
+                    name="nickname",
+                    description=None,
+                    dataType="string",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertEquals(columns2[1], Column(
+                    name="tolerance",
+                    description=None,
+                    dataType="float",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "tab2",
+                    lambda: spark.catalog.listColumns("tab2"))
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "does_not_exist",
+                    lambda: spark.catalog.listColumns("does_not_exist"))
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_catalog import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
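
CatalogTests exercises spark.catalog end to end: databases, tables, functions, and columns. For orientation, a bare-bones session hitting the same listing calls might look like the sketch below; the database and table names are arbitrary and not taken from the patch:

    from pyspark.sql import SparkSession

    # Sketch of the catalog calls covered by CatalogTests (names are arbitrary).
    spark = SparkSession.builder.master("local[2]").appName("catalog-sketch").getOrCreate()

    spark.sql("CREATE DATABASE IF NOT EXISTS demo_db")
    spark.catalog.setCurrentDatabase("demo_db")
    spark.sql("CREATE TABLE IF NOT EXISTS people (name STRING, age INT) USING parquet")

    print([db.name for db in spark.catalog.listDatabases()])                    # includes 'demo_db'
    print([t.name for t in spark.catalog.listTables("demo_db")])                # includes 'people'
    print([(c.name, c.dataType) for c in spark.catalog.listColumns("people")])  # e.g. [('name', 'string'), ('age', 'int')]

    spark.stop()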

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_column.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
new file mode 100644
index 0000000..faadde9
--- /dev/null
+++ b/python/pyspark/sql/tests/test_column.py
@@ -0,0 +1,157 @@
+# -*- encoding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+from pyspark.sql import Column, Row
+from pyspark.sql.types import *
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ColumnTests(ReusedSQLTestCase):
+
+    def test_column_name_encoding(self):
+        """Ensure that created columns has `str` type consistently."""
+        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
+        self.assertEqual(columns, ['name', 'age'])
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertTrue(isinstance(columns[1], str))
+
+    def test_and_in_expression(self):
+        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
+        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
+        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
+        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
+        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
+        self.assertRaises(ValueError, lambda: not self.df.key == 1)
+
+    def test_validate_column_types(self):
+        from pyspark.sql.functions import udf, to_json
+        from pyspark.sql.column import _to_java_column
+
+        self.assertTrue("Column" in _to_java_column("a").getClass().toString())
+        self.assertTrue("Column" in 
_to_java_column(u"a").getClass().toString())
+        self.assertTrue("Column" in 
_to_java_column(self.spark.range(1).id).getClass().toString())
+
+        self.assertRaisesRegexp(
+            TypeError,
+            "Invalid argument, not a string or column",
+            lambda: _to_java_column(1))
+
+        class A():
+            pass
+
+        self.assertRaises(TypeError, lambda: _to_java_column(A()))
+        self.assertRaises(TypeError, lambda: _to_java_column([]))
+
+        self.assertRaisesRegexp(
+            TypeError,
+            "Invalid argument, not a string or column",
+            lambda: udf(lambda x: x)(None))
+        self.assertRaises(TypeError, lambda: to_json(1))
+
+    def test_column_operators(self):
+        ci = self.df.key
+        cs = self.df.value
+        c = ci == cs
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbool = (ci & ci), (ci | ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbool))
+        css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\
+            cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+        self.assertRaisesRegexp(ValueError,
+                                "Cannot apply 'in' operator against a column",
+                                lambda: 1 in cs)
+
+    def test_column_getitem(self):
+        from pyspark.sql.functions import col
+
+        self.assertIsInstance(col("foo")[1:3], Column)
+        self.assertIsInstance(col("foo")[0], Column)
+        self.assertIsInstance(col("foo")["bar"], Column)
+        self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
+
+    def test_column_select(self):
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_access_column(self):
+        df = self.df
+        self.assertTrue(isinstance(df.key, Column))
+        self.assertTrue(isinstance(df['key'], Column))
+        self.assertTrue(isinstance(df[0], Column))
+        self.assertRaises(IndexError, lambda: df[2])
+        self.assertRaises(AnalysisException, lambda: df["bad_key"])
+        self.assertRaises(TypeError, lambda: df[{}])
+
+    def test_column_name_with_non_ascii(self):
+        if sys.version >= '3':
+            columnName = "数量"
+            self.assertTrue(isinstance(columnName, str))
+        else:
+            columnName = unicode("数量", "utf-8")
+            self.assertTrue(isinstance(columnName, unicode))
+        schema = StructType([StructField(columnName, LongType(), True)])
+        df = self.spark.createDataFrame([(1,)], schema)
+        self.assertEqual(schema, df.schema)
+        self.assertEqual("DataFrame[数量: bigint]", str(df))
+        self.assertEqual([("数量", 'bigint')], df.dtypes)
+        self.assertEqual(1, df.select("数量").first()[0])
+        self.assertEqual(1, df.select(df["数量"]).first()[0])
+
+    def test_field_accessor(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.r["a"]).first()[0])
+        self.assertEqual(1, df.select(df["r.a"]).first()[0])
+        self.assertEqual("b", df.select(df.r["b"]).first()[0])
+        self.assertEqual("b", df.select(df["r.b"]).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
+    def test_bitwise_operations(self):
+        from pyspark.sql import functions
+        row = Row(a=170, b=75)
+        df = self.spark.createDataFrame([row])
+        result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
+        self.assertEqual(170 & 75, result['(a & b)'])
+        result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
+        self.assertEqual(170 | 75, result['(a | b)'])
+        result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
+        self.assertEqual(170 ^ 75, result['(a ^ b)'])
+        result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
+        self.assertEqual(~75, result['~b'])
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_column import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
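
One rule worth calling out from ColumnTests.test_and_in_expression: Column expressions must be combined with the bitwise operators &, | and ~ (with parentheses around each comparison), because Python's and/or/not force a Column to a boolean and PySpark raises ValueError. A small sketch of that behaviour, with made-up data:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[2]").appName("column-ops-sketch").getOrCreate()
    df = spark.createDataFrame([(1, "1"), (20, "3")], ["key", "value"])

    # Correct: & builds a combined Column expression.
    ok = df.filter((col("key") <= 10) & (col("value") <= "2")).count()

    # Incorrect: `and` calls bool() on a Column, which PySpark rejects.
    try:
        df.filter(col("key") <= 10 and col("value") <= "2")
    except ValueError as e:
        print("as expected:", e)

    spark.stop()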

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_conf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py
new file mode 100644
index 0000000..f5d68a8
--- /dev/null
+++ b/python/pyspark/sql/tests/test_conf.py
@@ -0,0 +1,55 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ConfTests(ReusedSQLTestCase):
+
+    def test_conf(self):
+        spark = self.spark
+        spark.conf.set("bogo", "sipeo")
+        self.assertEqual(spark.conf.get("bogo"), "sipeo")
+        spark.conf.set("bogo", "ta")
+        self.assertEqual(spark.conf.get("bogo"), "ta")
+        self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
+        self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
+        self.assertRaisesRegexp(Exception, "not.set", lambda: 
spark.conf.get("not.set"))
+        spark.conf.unset("bogo")
+        self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
+
+        self.assertEqual(spark.conf.get("hyukjin", None), None)
+
+        # This returns 'STATIC' because it's the default value of
+        # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
+        # `spark.conf.get` is unset.
+        self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
+
+        # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
+        # `defaultValue` in `spark.conf.get` is set to None.
+        self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_conf import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
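
ConfTests pins down the semantics of the runtime configuration API: get with only a key raises for unknown keys, get with a default (including an explicit None) never does, and unset restores the built-in default. A rough user-facing sketch of the same surface (key names here are throwaway examples):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("conf-sketch").getOrCreate()

    spark.conf.set("bogo", "sipeo")
    assert spark.conf.get("bogo") == "sipeo"
    assert spark.conf.get("bogo", "fallback") == "sipeo"          # default ignored when the key is set
    assert spark.conf.get("never.set", "fallback") == "fallback"  # default returned for unknown keys
    assert spark.conf.get("never.set", None) is None              # explicit None default suppresses the error

    spark.conf.unset("bogo")
    assert spark.conf.get("bogo", "colombia") == "colombia"

    spark.stop()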

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py
new file mode 100644
index 0000000..d9d408a
--- /dev/null
+++ b/python/pyspark/sql/tests/test_context.py
@@ -0,0 +1,263 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+
+import py4j
+
+from pyspark import HiveContext, Row
+from pyspark.sql.types import *
+from pyspark.sql.window import Window
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        cls.hive_available = True
+        try:
+            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+        except py4j.protocol.Py4JError:
+            cls.hive_available = False
+        except TypeError:
+            cls.hive_available = False
+        os.unlink(cls.tempdir.name)
+        if cls.hive_available:
+            cls.spark = HiveContext._createForTesting(cls.sc)
+            cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+            cls.df = cls.sc.parallelize(cls.testData).toDF()
+
+    def setUp(self):
+        if not self.hive_available:
+            self.skipTest("Hive is not available.")
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_save_and_load_table(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
+        actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+        self.spark.sql("DROP TABLE externalJsonTable")
+
+        df.write.saveAsTable("savedJsonTable", "json", "overwrite", 
path=tmpPath)
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.spark.createExternalTable("externalJsonTable", 
source="json",
+                                                schema=schema, path=tmpPath,
+                                                noUse="this options will not 
be used")
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.spark.sql("SELECT * FROM 
savedJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()),
+                         sorted(self.spark.sql("SELECT * FROM 
externalJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()), 
sorted(actual.collect()))
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
+
+        defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
+                                                   "org.apache.spark.sql.parquet")
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
+    def test_window_functions(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        w = Window.partitionBy("value").orderBy("key")
+        from pyspark.sql import functions as F
+        sel = df.select(df.value, df.key,
+                        F.max("key").over(w.rowsBetween(0, 1)),
+                        F.min("key").over(w.rowsBetween(0, 1)),
+                        F.count("key").over(w.rowsBetween(float('-inf'), 
float('inf'))),
+                        F.row_number().over(w),
+                        F.rank().over(w),
+                        F.dense_rank().over(w),
+                        F.ntile(2).over(w))
+        rs = sorted(sel.collect())
+        expected = [
+            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
+            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
+            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
+            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
+        ]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+    def test_window_functions_without_partitionBy(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        w = Window.orderBy("key", df.value)
+        from pyspark.sql import functions as F
+        sel = df.select(df.value, df.key,
+                        F.max("key").over(w.rowsBetween(0, 1)),
+                        F.min("key").over(w.rowsBetween(0, 1)),
+                        F.count("key").over(w.rowsBetween(float('-inf'), 
float('inf'))),
+                        F.row_number().over(w),
+                        F.rank().over(w),
+                        F.dense_rank().over(w),
+                        F.ntile(2).over(w))
+        rs = sorted(sel.collect())
+        expected = [
+            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
+            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
+            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
+            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
+        ]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+    def test_window_functions_cumulative_sum(self):
+        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", 
"value"])
+        from pyspark.sql import functions as F
+
+        # Test cumulative sum
+        sel = df.select(
+            df.key,
+            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
+        rs = sorted(sel.collect())
+        expected = [("one", 1), ("two", 3)]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+        # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
+        sel = df.select(
+            df.key,
+            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
+        rs = sorted(sel.collect())
+        expected = [("one", 1), ("two", 3)]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+        # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
+        frame_end = Window.unboundedFollowing + 1
+        sel = df.select(
+            df.key,
+            F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
+        rs = sorted(sel.collect())
+        expected = [("one", 3), ("two", 2)]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+    def test_collect_functions(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        from pyspark.sql import functions
+
+        self.assertEqual(
+            sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
+            [1, 2])
+        self.assertEqual(
+            sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
+            [1, 1, 1, 2])
+        self.assertEqual(
+            sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
+            ["1", "2"])
+        self.assertEqual(
+            sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
+            ["1", "2", "2", "2"])
+
+    def test_limit_and_take(self):
+        df = self.spark.range(1, 1000, numPartitions=10)
+
+        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
+            tracker = self.sc.statusTracker()
+            self.sc.setJobGroup(job_group_name, description="")
+            f()
+            jobs = tracker.getJobIdsForGroup(job_group_name)
+            self.assertEqual(1, len(jobs))
+            stages = tracker.getJobInfo(jobs[0]).stageIds
+            self.assertEqual(1, len(stages))
+            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
+
+        # Regression test for SPARK-10731: take should delegate to Scala implementation
+        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
+        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+
+    def test_datetime_functions(self):
+        from pyspark.sql import functions
+        from datetime import date
+        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
+        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
+        self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])
+
+    @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking")
+    def test_unbounded_frames(self):
+        from unittest.mock import patch
+        from pyspark.sql import functions as F
+        from pyspark.sql import window
+        import importlib
+
+        df = self.spark.range(0, 3)
+
+        def rows_frame_match():
+            return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" 
in df.select(
+                F.count("*").over(window.Window.rowsBetween(-sys.maxsize, 
sys.maxsize))
+            ).columns[0]
+
+        def range_frame_match():
+            return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" 
in df.select(
+                F.count("*").over(window.Window.rangeBetween(-sys.maxsize, 
sys.maxsize))
+            ).columns[0]
+
+        with patch("sys.maxsize", 2 ** 31 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        with patch("sys.maxsize", 2 ** 63 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        with patch("sys.maxsize", 2 ** 127 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        importlib.reload(window)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_context import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
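
The window-function tests above build frames with Window.partitionBy/orderBy plus rowsBetween, and check that boundaries beyond the JVM's Long range are clamped to unbounded frames. As a quick reference, a cumulative-sum query in the same style (the data below is illustrative only):

    from pyspark.sql import SparkSession, Window, functions as F

    spark = SparkSession.builder.master("local[2]").appName("window-sketch").getOrCreate()
    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

    # Running total per key: everything from the start of the partition up to the current row.
    w = Window.partitionBy("key").orderBy("value").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df.select("key", "value", F.sum("value").over(w).alias("running_total")).show()

    spark.stop()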

