http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
new file mode 100644
index 0000000..b811047
--- /dev/null
+++ b/python/pyspark/sql/tests/test_session.py
@@ -0,0 +1,320 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.sql import SparkSession, SQLContext, Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.tests import PySparkTestCase
+
+
+class SparkSessionTests(ReusedSQLTestCase):
+    def test_sqlcontext_reuses_sparksession(self):
+        sqlContext1 = SQLContext(self.sc)
+        sqlContext2 = SQLContext(self.sc)
+        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
+
+
+class SparkSessionTests1(ReusedSQLTestCase):
+
+    # We can't include this test in SQLTests because we will stop the class's SparkContext and
+    # cause other tests to fail.
+    def test_sparksession_with_stopped_sparkcontext(self):
+        self.sc.stop()
+        sc = SparkContext('local[4]', self.sc.appName)
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            df = spark.createDataFrame([(1, 2)], ["c", "c"])
+            df.collect()
+        finally:
+            spark.stop()
+            sc.stop()
+
+
+class SparkSessionTests2(PySparkTestCase):
+
+    # This test is separate because it's closely related to the session's start and stop.
+    # See SPARK-23228.
+    def test_set_jvm_default_session(self):
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+        finally:
+            spark.stop()
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
+
+    def test_jvm_default_session_already_set(self):
+        # Here, we assume there is the default session already set in JVM.
+        jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
+        self.sc._jvm.SparkSession.setDefaultSession(jsession)
+
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+            # The session should be the same as the existing one.
+            self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
+        finally:
+            spark.stop()
+
+
+class SparkSessionTests3(unittest.TestCase):
+
+    def test_active_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            activeSession = SparkSession.getActiveSession()
+            df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
+            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
+        finally:
+            spark.stop()
+
+    def test_get_active_session_when_no_active_session(self):
+        active = SparkSession.getActiveSession()
+        self.assertEqual(active, None)
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        active = SparkSession.getActiveSession()
+        self.assertEqual(active, spark)
+        spark.stop()
+        active = SparkSession.getActiveSession()
+        self.assertEqual(active, None)
+
+    def test_SparkSession(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .config("some-config", "v2") \
+            .getOrCreate()
+        try:
+            self.assertEqual(spark.conf.get("some-config"), "v2")
+            self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
+            self.assertEqual(spark.version, spark.sparkContext.version)
+            spark.sql("CREATE DATABASE test_db")
+            spark.catalog.setCurrentDatabase("test_db")
+            self.assertEqual(spark.catalog.currentDatabase(), "test_db")
+            spark.sql("CREATE TABLE table1 (name STRING, age INT) USING 
parquet")
+            self.assertEqual(spark.table("table1").columns, ['name', 'age'])
+            self.assertEqual(spark.range(3).count(), 3)
+        finally:
+            spark.stop()
+
+    def test_global_default_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            self.assertEqual(SparkSession.builder.getOrCreate(), spark)
+        finally:
+            spark.stop()
+
+    def test_default_and_active_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        activeSession = spark._jvm.SparkSession.getActiveSession()
+        defaultSession = spark._jvm.SparkSession.getDefaultSession()
+        try:
+            self.assertEqual(activeSession, defaultSession)
+        finally:
+            spark.stop()
+
+    def test_config_option_propagated_to_existing_session(self):
+        session1 = SparkSession.builder \
+            .master("local") \
+            .config("spark-config1", "a") \
+            .getOrCreate()
+        self.assertEqual(session1.conf.get("spark-config1"), "a")
+        session2 = SparkSession.builder \
+            .config("spark-config1", "b") \
+            .getOrCreate()
+        try:
+            self.assertEqual(session1, session2)
+            self.assertEqual(session1.conf.get("spark-config1"), "b")
+        finally:
+            session1.stop()
+
+    def test_new_session(self):
+        session = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        newSession = session.newSession()
+        try:
+            self.assertNotEqual(session, newSession)
+        finally:
+            session.stop()
+            newSession.stop()
+
+    def test_create_new_session_if_old_session_stopped(self):
+        session = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        session.stop()
+        newSession = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            self.assertNotEqual(session, newSession)
+        finally:
+            newSession.stop()
+
+    def test_active_session_with_None_and_not_None_context(self):
+        from pyspark.context import SparkContext
+        from pyspark.conf import SparkConf
+        sc = None
+        session = None
+        try:
+            sc = SparkContext._active_spark_context
+            self.assertEqual(sc, None)
+            activeSession = SparkSession.getActiveSession()
+            self.assertEqual(activeSession, None)
+            sparkConf = SparkConf()
+            sc = SparkContext.getOrCreate(sparkConf)
+            activeSession = sc._jvm.SparkSession.getActiveSession()
+            self.assertFalse(activeSession.isDefined())
+            session = SparkSession(sc)
+            activeSession = sc._jvm.SparkSession.getActiveSession()
+            self.assertTrue(activeSession.isDefined())
+            activeSession2 = SparkSession.getActiveSession()
+            self.assertNotEqual(activeSession2, None)
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
+
+class SparkSessionTests4(ReusedSQLTestCase):
+
+    def test_get_active_session_after_create_dataframe(self):
+        session2 = None
+        try:
+            activeSession1 = SparkSession.getActiveSession()
+            session1 = self.spark
+            self.assertEqual(session1, activeSession1)
+            session2 = self.spark.newSession()
+            activeSession2 = SparkSession.getActiveSession()
+            self.assertEqual(session1, activeSession2)
+            self.assertNotEqual(session2, activeSession2)
+            session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
+            activeSession3 = SparkSession.getActiveSession()
+            self.assertEqual(session2, activeSession3)
+            session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
+            activeSession4 = SparkSession.getActiveSession()
+            self.assertEqual(session1, activeSession4)
+        finally:
+            if session2 is not None:
+                session2.stop()
+
+
+class SparkSessionBuilderTests(unittest.TestCase):
+
+    def test_create_spark_context_first_then_spark_session(self):
+        sc = None
+        session = None
+        try:
+            conf = SparkConf().set("key1", "value1")
+            sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
+            session = SparkSession.builder.config("key2", 
"value2").getOrCreate()
+
+            self.assertEqual(session.conf.get("key1"), "value1")
+            self.assertEqual(session.conf.get("key2"), "value2")
+            self.assertEqual(session.sparkContext, sc)
+
+            self.assertFalse(sc.getConf().contains("key2"))
+            self.assertEqual(sc.getConf().get("key1"), "value1")
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
+    def test_another_spark_session(self):
+        session1 = None
+        session2 = None
+        try:
+            session1 = SparkSession.builder.config("key1", 
"value1").getOrCreate()
+            session2 = SparkSession.builder.config("key2", 
"value2").getOrCreate()
+
+            self.assertEqual(session1.conf.get("key1"), "value1")
+            self.assertEqual(session2.conf.get("key1"), "value1")
+            self.assertEqual(session1.conf.get("key2"), "value2")
+            self.assertEqual(session2.conf.get("key2"), "value2")
+            self.assertEqual(session1.sparkContext, session2.sparkContext)
+
+            self.assertEqual(session1.sparkContext.getConf().get("key1"), 
"value1")
+            self.assertFalse(session1.sparkContext.getConf().contains("key2"))
+        finally:
+            if session1 is not None:
+                session1.stop()
+            if session2 is not None:
+                session2.stop()
+
+
+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because they use 'spark.sql.extensions', which is
+    # static and immutable. It can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_session import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_streaming.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
new file mode 100644
index 0000000..cc0cab4
--- /dev/null
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -0,0 +1,566 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import time
+
+from pyspark.sql.functions import lit
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class StreamingTests(ReusedSQLTestCase):
+
+    def test_stream_trigger(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+
+        # Should take at least one arg
+        try:
+            df.writeStream.trigger()
+        except ValueError:
+            pass
+
+        # Should not take multiple args
+        try:
+            df.writeStream.trigger(once=True, processingTime='5 seconds')
+        except ValueError:
+            pass
+
+        # Should not take multiple args
+        try:
+            df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
+        except ValueError:
+            pass
+
+        # Should take only keyword args
+        try:
+            df.writeStream.trigger('5 seconds')
+            self.fail("Should have thrown an exception")
+        except TypeError:
+            pass
+
+    def test_stream_read_options(self):
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.spark.readStream\
+            .format('text')\
+            .option('path', 'python/test_support/sql/streaming')\
+            .schema(schema)\
+            .load()
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_read_options_overwrite(self):
+        bad_schema = StructType([StructField("test", IntegerType(), False)])
+        schema = StructType([StructField("data", StringType(), False)])
+        df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
+            .schema(bad_schema)\
+            .load(path='python/test_support/sql/streaming', schema=schema, format='text')
+        self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+    def test_stream_save_options(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
+            .withColumn('id', lit(1))
+        for q in self.spark._wrapped.streams.active:
+            q.stop()
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
+            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
+        try:
+            self.assertEqual(q.name, 'this_query')
+            self.assertTrue(q.isActive)
+            q.processAllAvailable()
+            output_files = []
+            for _, _, files in os.walk(out):
+                output_files.extend([f for f in files if not f.startswith('.')])
+            self.assertTrue(len(output_files) > 0)
+            self.assertTrue(len(os.listdir(chk)) > 0)
+        finally:
+            q.stop()
+            shutil.rmtree(tmpPath)
+
+    def test_stream_save_options_overwrite(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+        for q in self.spark._wrapped.streams.active:
+            q.stop()
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        fake1 = os.path.join(tmpPath, 'fake1')
+        fake2 = os.path.join(tmpPath, 'fake2')
+        q = df.writeStream.option('checkpointLocation', fake1)\
+            .format('memory').option('path', fake2) \
+            .queryName('fake_query').outputMode('append') \
+            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+
+        try:
+            self.assertEqual(q.name, 'this_query')
+            self.assertTrue(q.isActive)
+            q.processAllAvailable()
+            output_files = []
+            for _, _, files in os.walk(out):
+                output_files.extend([f for f in files if not f.startswith('.')])
+            self.assertTrue(len(output_files) > 0)
+            self.assertTrue(len(os.listdir(chk)) > 0)
+            self.assertFalse(os.path.isdir(fake1))  # should not have been created
+            self.assertFalse(os.path.isdir(fake2))  # should not have been created
+        finally:
+            q.stop()
+            shutil.rmtree(tmpPath)
+
+    def test_stream_status_and_progress(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+        for q in self.spark._wrapped.streams.active:
+            q.stop()
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+
+        def func(x):
+            time.sleep(1)
+            return x
+
+        from pyspark.sql.functions import col, udf
+        sleep_udf = udf(func)
+
+        # Use "sleep_udf" to delay the progress update so that we can test 
`lastProgress` when there
+        # were no updates.
+        q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
+            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+        try:
+            # "lastProgress" will return None in most cases. However, as it 
may be flaky when
+            # Jenkins is very slow, we don't assert it. If there is something 
wrong, "lastProgress"
+            # may throw error with a high chance and make this test flaky, so 
we should still be
+            # able to detect broken codes.
+            q.lastProgress
+
+            q.processAllAvailable()
+            lastProgress = q.lastProgress
+            recentProgress = q.recentProgress
+            status = q.status
+            self.assertEqual(lastProgress['name'], q.name)
+            self.assertEqual(lastProgress['id'], q.id)
+            self.assertTrue(any(p == lastProgress for p in recentProgress))
+            self.assertTrue(
+                "message" in status and
+                "isDataAvailable" in status and
+                "isTriggerActive" in status)
+        finally:
+            q.stop()
+            shutil.rmtree(tmpPath)
+
+    def test_stream_await_termination(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+        for q in self.spark._wrapped.streams.active:
+            q.stop()
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        q = df.writeStream\
+            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+        try:
+            self.assertTrue(q.isActive)
+            try:
+                q.awaitTermination("hello")
+                self.fail("Expected a value exception")
+            except ValueError:
+                pass
+            now = time.time()
+            # test should take at least 2 seconds
+            res = q.awaitTermination(2.6)
+            duration = time.time() - now
+            self.assertTrue(duration >= 2)
+            self.assertFalse(res)
+        finally:
+            q.stop()
+            shutil.rmtree(tmpPath)
+
+    def test_stream_exception(self):
+        sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+        sq = sdf.writeStream.format('memory').queryName('query_explain').start()
+        try:
+            sq.processAllAvailable()
+            self.assertEqual(sq.exception(), None)
+        finally:
+            sq.stop()
+
+        from pyspark.sql.functions import col, udf
+        from pyspark.sql.utils import StreamingQueryException
+        bad_udf = udf(lambda x: 1 / 0)
+        sq = sdf.select(bad_udf(col("value")))\
+            .writeStream\
+            .format('memory')\
+            .queryName('this_query')\
+            .start()
+        try:
+            # Process some data to fail the query
+            sq.processAllAvailable()
+            self.fail("bad udf should fail the query")
+        except StreamingQueryException as e:
+            # This is expected
+            self.assertTrue("ZeroDivisionError" in e.desc)
+        finally:
+            sq.stop()
+        self.assertTrue(type(sq.exception()) is StreamingQueryException)
+        self.assertTrue("ZeroDivisionError" in sq.exception().desc)
+
+    def test_query_manager_await_termination(self):
+        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+        for q in self.spark._wrapped.streams.active:
+            q.stop()
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.assertTrue(df.isStreaming)
+        out = os.path.join(tmpPath, 'out')
+        chk = os.path.join(tmpPath, 'chk')
+        q = df.writeStream\
+            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+        try:
+            self.assertTrue(q.isActive)
+            try:
+                self.spark._wrapped.streams.awaitAnyTermination("hello")
+                self.fail("Expected a value exception")
+            except ValueError:
+                pass
+            now = time.time()
+            # test should take at least 2 seconds
+            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
+            duration = time.time() - now
+            self.assertTrue(duration >= 2)
+            self.assertFalse(res)
+        finally:
+            q.stop()
+            shutil.rmtree(tmpPath)
+
+    class ForeachWriterTester:
+
+        def __init__(self, spark):
+            self.spark = spark
+
+        def write_open_event(self, partitionId, epochId):
+            self._write_event(
+                self.open_events_dir,
+                {'partition': partitionId, 'epoch': epochId})
+
+        def write_process_event(self, row):
+            self._write_event(self.process_events_dir, {'value': 'text'})
+
+        def write_close_event(self, error):
+            self._write_event(self.close_events_dir, {'error': str(error)})
+
+        def write_input_file(self):
+            self._write_event(self.input_dir, "text")
+
+        def open_events(self):
+            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+        def process_events(self):
+            return self._read_events(self.process_events_dir, 'value STRING')
+
+        def close_events(self):
+            return self._read_events(self.close_events_dir, 'error STRING')
+
+        def run_streaming_query_on_writer(self, writer, num_files):
+            self._reset()
+            try:
+                sdf = self.spark.readStream.format('text').load(self.input_dir)
+                sq = sdf.writeStream.foreach(writer).start()
+                for i in range(num_files):
+                    self.write_input_file()
+                    sq.processAllAvailable()
+            finally:
+                self.stop_all()
+
+        def assert_invalid_writer(self, writer, msg=None):
+            self._reset()
+            try:
+                sdf = self.spark.readStream.format('text').load(self.input_dir)
+                sq = sdf.writeStream.foreach(writer).start()
+                self.write_input_file()
+                sq.processAllAvailable()
+                self.fail("invalid writer %s did not fail the query" % 
str(writer))  # not expected
+            except Exception as e:
+                if msg:
+                    assert msg in str(e), "%s not in %s" % (msg, str(e))
+
+            finally:
+                self.stop_all()
+
+        def stop_all(self):
+            for q in self.spark._wrapped.streams.active:
+                q.stop()
+
+        def _reset(self):
+            self.input_dir = tempfile.mkdtemp()
+            self.open_events_dir = tempfile.mkdtemp()
+            self.process_events_dir = tempfile.mkdtemp()
+            self.close_events_dir = tempfile.mkdtemp()
+
+        def _read_events(self, dir, json):
+            rows = self.spark.read.schema(json).json(dir).collect()
+            dicts = [row.asDict() for row in rows]
+            return dicts
+
+        def _write_event(self, dir, event):
+            import uuid
+            with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
+                f.write("%s\n" % str(event))
+
+        def __getstate__(self):
+            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+        def __setstate__(self, state):
+            self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+    # These foreach tests fail on Python 3.6 and macOS High Sierra because of the rules described
+    # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
+    # To work around this, set OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
+    def test_streaming_foreach_with_simple_function(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        def foreach_func(row):
+            tester.write_process_event(row)
+
+        tester.run_streaming_query_on_writer(foreach_func, 2)
+        self.assertEqual(len(tester.process_events()), 2)
+
+    def test_streaming_foreach_with_basic_open_process_close(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return True
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        open_events = tester.open_events()
+        self.assertEqual(len(open_events), 2)
+        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+        self.assertEqual(len(tester.process_events()), 2)
+
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_with_open_returning_false(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partition_id, epoch_id):
+                tester.write_open_event(partition_id, epoch_id)
+                return False
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        self.assertEqual(len(tester.open_events()), 2)
+
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_without_open_method(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+        self.assertEqual(len(tester.open_events()), 0)  # no open events
+        self.assertEqual(len(tester.process_events()), 2)
+        self.assertEqual(len(tester.close_events()), 2)
+
+    def test_streaming_foreach_without_close_method(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partition_id, epoch_id):
+                tester.write_open_event(partition_id, epoch_id)
+                return True
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+        self.assertEqual(len(tester.open_events()), 2)  # open events are still written
+        self.assertEqual(len(tester.process_events()), 2)
+        self.assertEqual(len(tester.close_events()), 0)
+
+    def test_streaming_foreach_without_open_and_close_methods(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def process(self, row):
+                tester.write_process_event(row)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+        self.assertEqual(len(tester.open_events()), 0)  # no open events
+        self.assertEqual(len(tester.process_events()), 2)
+        self.assertEqual(len(tester.close_events()), 0)
+
+    def test_streaming_foreach_with_process_throwing_error(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def process(self, row):
+                raise Exception("test error")
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        try:
+            tester.run_streaming_query_on_writer(ForeachWriter(), 1)
+            self.fail("bad writer did not fail the query")  # this is not 
expected
+        except StreamingQueryException as e:
+            # TODO: Verify whether original error message is inside the 
exception
+            pass
+
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 1)
+        # TODO: Verify whether original error message is inside the exception
+
+    def test_streaming_foreach_with_invalid_writers(self):
+
+        tester = self.ForeachWriterTester(self.spark)
+
+        def func_with_iterator_input(iter):
+            for x in iter:
+                print(x)
+
+        tester.assert_invalid_writer(func_with_iterator_input)
+
+        class WriterWithoutProcess:
+            def open(self, partition):
+                pass
+
+        tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 
'process'")
+
+        class WriterWithNonCallableProcess():
+            process = True
+
+        tester.assert_invalid_writer(WriterWithNonCallableProcess(),
+                                     "'process' in provided object is not 
callable")
+
+        class WriterWithNoParamProcess():
+            def process(self):
+                pass
+
+        tester.assert_invalid_writer(WriterWithNoParamProcess())
+
+        # Abstract class for tests below
+        class WithProcess():
+            def process(self, row):
+                pass
+
+        class WriterWithNonCallableOpen(WithProcess):
+            open = True
+
+        tester.assert_invalid_writer(WriterWithNonCallableOpen(),
+                                     "'open' in provided object is not 
callable")
+
+        class WriterWithNoParamOpen(WithProcess):
+            def open(self):
+                pass
+
+        tester.assert_invalid_writer(WriterWithNoParamOpen())
+
+        class WriterWithNonCallableClose(WithProcess):
+            close = True
+
+        tester.assert_invalid_writer(WriterWithNonCallableClose(),
+                                     "'close' in provided object is not 
callable")
+
+    def test_streaming_foreachBatch(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            collected[batch_id] = batch_df.collect()
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
+    def test_streaming_foreachBatch_propagates_python_errors(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        q = None
+
+        def collectBatch(df, id):
+            raise Exception("this should fail the query")
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.fail("Expected a failure")
+        except StreamingQueryException as e:
+            self.assertTrue("this should fail" in str(e))
+        finally:
+            if q:
+                q.stop()
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_streaming import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
new file mode 100644
index 0000000..3b32c58
--- /dev/null
+++ b/python/pyspark/sql/tests/test_types.py
@@ -0,0 +1,944 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import array
+import ctypes
+import datetime
+import os
+import pickle
+import sys
+import unittest
+
+from pyspark.sql import Row
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import *
+from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \
+    _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
+from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \
+    ExamplePoint, PythonOnlyPoint, MyObject
+
+
+class TypesTests(ReusedSQLTestCase):
+
+    def test_apply_schema_to_row(self):
+        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
+        self.assertEqual(df.collect(), df2.collect())
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+        df3 = self.spark.createDataFrame(rdd, df.schema)
+        self.assertEqual(10, df3.count())
+
+    def test_infer_schema_to_local(self):
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        df = self.spark.createDataFrame(input)
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
+        self.assertEqual(df.schema, df2.schema)
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+        df3 = self.spark.createDataFrame(rdd, df.schema)
+        self.assertEqual(10, df3.count())
+
+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        for verify in [False, True]:
+            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(df.schema, df2.schema)
+
+            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(10, df3.count())
+            input = [Row(a=x, b=str(x)) for x in range(10)]
+            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            self.assertEqual(10, df4.count())
+
+    def test_create_dataframe_schema_mismatch(self):
+        input = [Row(a=1)]
+        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
+        schema = StructType([StructField("a", IntegerType()), StructField("b", 
StringType())])
+        df = self.spark.createDataFrame(rdd, schema)
+        self.assertRaises(Exception, lambda: df.show())
+
+    def test_infer_schema(self):
+        d = [Row(l=[], d={}, s=None),
+             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+        rdd = self.sc.parallelize(d)
+        df = self.spark.createDataFrame(rdd)
+        self.assertEqual([], df.rdd.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
+
+        with self.tempView("test"):
+            df.createOrReplaceTempView("test")
+            result = self.spark.sql("SELECT l[0].a from test where d['key'].d 
= '2'")
+            self.assertEqual(1, result.head()[0])
+
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
+        self.assertEqual(df.schema, df2.schema)
+        self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
+
+        with self.tempView("test2"):
+            df2.createOrReplaceTempView("test2")
+            result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d 
= '2'")
+            self.assertEqual(1, result.head()[0])
+
+    def test_infer_schema_specification(self):
+        from decimal import Decimal
+
+        class A(object):
+            def __init__(self):
+                self.a = 1
+
+        data = [
+            True,
+            1,
+            "a",
+            u"a",
+            datetime.date(1970, 1, 1),
+            datetime.datetime(1970, 1, 1, 0, 0),
+            1.0,
+            array.array("d", [1]),
+            [1],
+            (1, ),
+            {"a": 1},
+            bytearray(1),
+            Decimal(1),
+            Row(a=1),
+            Row("a")(1),
+            A(),
+        ]
+
+        df = self.spark.createDataFrame([data])
+        actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
+        expected = [
+            'boolean',
+            'bigint',
+            'string',
+            'string',
+            'date',
+            'timestamp',
+            'double',
+            'array<double>',
+            'array<bigint>',
+            'struct<_1:bigint>',
+            'map<string,bigint>',
+            'binary',
+            'decimal(38,18)',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+        ]
+        self.assertEqual(actual, expected)
+
+        actual = list(df.first())
+        expected = [
+            True,
+            1,
+            'a',
+            u"a",
+            datetime.date(1970, 1, 1),
+            datetime.datetime(1970, 1, 1, 0, 0),
+            1.0,
+            [1.0],
+            [1],
+            Row(_1=1),
+            {"a": 1},
+            bytearray(b'\x00'),
+            Decimal('1.000000000000000000'),
+            Row(a=1),
+            Row(a=1),
+            Row(a=1),
+        ]
+        self.assertEqual(actual, expected)
+
+    def test_infer_schema_not_enough_names(self):
+        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
+        self.assertEqual(df.columns, ['col1', '_2'])
+
+    def test_infer_schema_fails(self):
+        with self.assertRaisesRegexp(TypeError, 'field a'):
+            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
+                                       schema=["a", "b"], samplingRatio=0.99)
+
+    def test_infer_nested_schema(self):
+        NestedRow = Row("f1", "f2")
+        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+                                          NestedRow([2, 3], {"row2": 2.0})])
+        df = self.spark.createDataFrame(nestedRdd1)
+        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
+        df = self.spark.createDataFrame(nestedRdd2)
+        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+        from collections import namedtuple
+        CustomRow = namedtuple('CustomRow', 'field1 field2')
+        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+                                   CustomRow(field1=2, field2="row2"),
+                                   CustomRow(field1=3, field2="row3")])
+        df = self.spark.createDataFrame(rdd)
+        self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
+
+    def test_create_dataframe_from_dict_respects_schema(self):
+        df = self.spark.createDataFrame([{'a': 1}], ["b"])
+        self.assertEqual(df.columns, ['b'])
+
+    def test_create_dataframe_from_objects(self):
+        data = [MyObject(1, "1"), MyObject(2, "2")]
+        df = self.spark.createDataFrame(data)
+        self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
+        self.assertEqual(df.first(), Row(key=1, value="1"))
+
+    def test_apply_schema(self):
+        from datetime import date, datetime
+        rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
+                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+                                    {"a": 1}, (2,), [1, 2, 3], None)])
+        schema = StructType([
+            StructField("byte1", ByteType(), False),
+            StructField("byte2", ByteType(), False),
+            StructField("short1", ShortType(), False),
+            StructField("short2", ShortType(), False),
+            StructField("int1", IntegerType(), False),
+            StructField("float1", FloatType(), False),
+            StructField("date1", DateType(), False),
+            StructField("time1", TimestampType(), False),
+            StructField("map1", MapType(StringType(), IntegerType(), False), 
False),
+            StructField("struct1", StructType([StructField("b", ShortType(), 
False)]), False),
+            StructField("list1", ArrayType(ByteType(), False), False),
+            StructField("null1", DoubleType(), True)])
+        df = self.spark.createDataFrame(rdd, schema)
+        results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
+                             x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+        self.assertEqual(r, results.first())
+
+        with self.tempView("table2"):
+            df.createOrReplaceTempView("table2")
+            r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, 
" +
+                               "short1 + 1 AS short1, short2 - 1 AS short2, 
int1 - 1 AS int1, " +
+                               "float1 + 1.5 as float1 FROM table2").first()
+
+            self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), 
tuple(r))
+
+    def test_convert_row_to_dict(self):
+        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+        self.assertEqual(1, row.asDict()['l'][0].a)
+        df = self.sc.parallelize([row]).toDF()
+
+        with self.tempView("test"):
+            df.createOrReplaceTempView("test")
+            row = self.spark.sql("select l, d from test").head()
+            self.assertEqual(1, row.asDict()["l"][0].a)
+            self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+    def test_udt(self):
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
+
+        def check_datatype(datatype):
+            pickled = pickle.loads(pickle.dumps(datatype))
+            assert datatype == pickled
+            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
+            python_datatype = _parse_datatype_json_string(scala_datatype.json())
+            assert datatype == python_datatype
+
+        check_datatype(ExamplePointUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), 
False),
+                                          StructField("point", 
ExamplePointUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = ExamplePoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), ExamplePointUDT())
+        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
+
+        check_datatype(PythonOnlyUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), 
False),
+                                          StructField("point", 
PythonOnlyUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = PythonOnlyPoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), PythonOnlyUDT())
+        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+        self.assertRaises(
+            ValueError,
+            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
+
+    def test_simple_udt_in_df(self):
+        schema = StructType().add("key", LongType()).add("val", 
PythonOnlyUDT())
+        df = self.spark.createDataFrame(
+            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+    def test_nested_udt_in_df(self):
+        schema = StructType().add("key", LongType()).add("val", 
ArrayType(PythonOnlyUDT()))
+        df = self.spark.createDataFrame(
+            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+        schema = StructType().add("key", LongType()).add("val",
+                                                         MapType(LongType(), PythonOnlyUDT()))
+        df = self.spark.createDataFrame(
+            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+    def test_complex_nested_udt_in_df(self):
+        from pyspark.sql.functions import udf
+
+        schema = StructType().add("key", LongType()).add("val", 
PythonOnlyUDT())
+        df = self.spark.createDataFrame(
+            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+        gd = df.groupby("key").agg({"val": "collect_list"})
+        gd.collect()
+        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+        gd.select(udf(*gd)).collect()
+
+    def test_udt_with_none(self):
+        df = self.spark.range(0, 10, 1, 1)
+
+        def myudf(x):
+            if x > 0:
+                return PythonOnlyPoint(float(x), float(x))
+
+        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
+    def test_infer_schema_with_udt(self):
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM 
labeled_point").head().point
+            self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM 
labeled_point").head().point
+            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.spark.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", PythonOnlyUDT(), False)])
+        df = self.spark.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_udf_with_udt(self):
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+    def test_parquet_with_udt(self):
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df0 = self.spark.createDataFrame([row])
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.write.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df0 = self.spark.createDataFrame([row])
+        df0.write.parquet(output_dir, mode='overwrite')
+        df1 = self.spark.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_union_with_udt(self):
+        row1 = (1.0, ExamplePoint(1.0, 2.0))
+        row2 = (2.0, ExamplePoint(3.0, 4.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df1 = self.spark.createDataFrame([row1], schema)
+        df2 = self.spark.createDataFrame([row2], schema)
+
+        result = df1.union(df2).orderBy("label").collect()
+        self.assertEqual(
+            result,
+            [
+                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
+                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
+            ]
+        )
+
+    def test_cast_to_string_with_udt(self):
+        from pyspark.sql.functions import col
+        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
+        schema = StructType([StructField("point", ExamplePointUDT(), False),
+                             StructField("pypoint", PythonOnlyUDT(), False)])
+        df = self.spark.createDataFrame([row], schema)
+
+        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
+        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
+
+    def test_struct_type(self):
+        struct1 = StructType().add("f1", StringType(), True).add("f2", 
StringType(), True, None)
+        struct2 = StructType([StructField("f1", StringType(), True),
+                              StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
+        self.assertEqual(struct1, struct2)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", 
StringType(), True, None)
+        struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
+        self.assertNotEqual(struct1, struct2)
+
+        struct1 = (StructType().add(StructField("f1", StringType(), True))
+                   .add(StructField("f2", StringType(), True, None)))
+        struct2 = StructType([StructField("f1", StringType(), True),
+                              StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
+        self.assertEqual(struct1, struct2)
+
+        struct1 = (StructType().add(StructField("f1", StringType(), True))
+                   .add(StructField("f2", StringType(), True, None)))
+        struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
+        self.assertNotEqual(struct1, struct2)
+
+        # Catch exception raised during improper construction
+        self.assertRaises(ValueError, lambda: StructType().add("name"))
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", 
StringType(), True, None)
+        for field in struct1:
+            self.assertIsInstance(field, StructField)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", 
StringType(), True, None)
+        self.assertEqual(len(struct1), 2)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", 
StringType(), True, None)
+        self.assertIs(struct1["f1"], struct1.fields[0])
+        self.assertIs(struct1[0], struct1.fields[0])
+        self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
+        self.assertRaises(KeyError, lambda: struct1["f9"])
+        self.assertRaises(IndexError, lambda: struct1[9])
+        self.assertRaises(TypeError, lambda: struct1[9.9])
+
+    def test_parse_datatype_string(self):
+        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+        for k, t in _all_atomic_types.items():
+            if t != NullType:
+                self.assertEqual(t(), _parse_datatype_string(k))
+        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1  ,1)"))
+        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+        self.assertEqual(
+            ArrayType(IntegerType()),
+            _parse_datatype_string("array<int >"))
+        self.assertEqual(
+            MapType(IntegerType(), DoubleType()),
+            _parse_datatype_string("map< int, double  >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", 
DoubleType())]),
+            _parse_datatype_string("struct<a:int, c:double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", 
DoubleType())]),
+            _parse_datatype_string("a:int, c:double"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", 
DoubleType())]),
+            _parse_datatype_string("a INT, c DOUBLE"))
+
+    def test_metadata_null(self):
+        schema = StructType([StructField("f1", StringType(), True, None),
+                             StructField("f2", StringType(), True, {'a': 
None})])
+        rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
+        self.spark.createDataFrame(rdd, schema)
+
+    def test_access_nested_types(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+        self.assertEqual(1, df.select(df.r.a).first()[0])
+        self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+        self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        df = self.sc.parallelize(longrow).toDF()
+        self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+        # Saving this as Parquet caused issues as well.
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        df.write.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
+        self.assertEqual('a', df1.first().f1)
+        self.assertEqual(100000000000000, df1.first().f2)
+
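+        # _infer_type should map Python ints, from small values up to ones that
+        # exceed 64 bits, to LongType.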
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
+    def test_merge_type(self):
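+        # Merging with NullType yields the non-null side; otherwise the two types
+        # must agree, and a mismatch raises TypeError naming the offending field.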
+        self.assertEqual(_merge_type(LongType(), NullType()), LongType())
+        self.assertEqual(_merge_type(NullType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(LongType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(
+            ArrayType(LongType()),
+            ArrayType(LongType())
+        ), ArrayType(LongType()))
+        with self.assertRaisesRegexp(TypeError, 'element in array'):
+            _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+
+        self.assertEqual(_merge_type(
+            MapType(StringType(), LongType()),
+            MapType(StringType(), LongType())
+        ), MapType(StringType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'key of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(DoubleType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'value of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(StringType(), DoubleType()))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
+            StructType([StructField("f1", LongType()), StructField("f2", 
StringType())])
+        ), StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'field f1'):
+            _merge_type(
+                StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
+                StructType([StructField("f1", DoubleType()), StructField("f2", 
StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]),
+            StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))])
+        ), StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]))
+        with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
+            _merge_type(
+                StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]),
+                StructType([StructField("f1", StructType([StructField("f2", 
StringType())]))]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]),
+            StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())])
+        ), StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", ArrayType(LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", ArrayType(DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())]),
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())])
+        ), StructType([
+            StructField("f1", MapType(StringType(), LongType())),
+            StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", MapType(StringType(), LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", MapType(StringType(), DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]),
+            StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))])
+        ), StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]))
+        with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
+            _merge_type(
+                StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]),
+                StructType([StructField("f1", ArrayType(MapType(DoubleType(), 
LongType())))])
+            )
+
+    # test for SPARK-16542
+    def test_array_types(self):
+        # This test needs to make sure that the Scala type selected is at least
+        # as large as the Python type. This is necessary because Python's array
+        # types depend on the C implementation on the machine, so there is no
+        # machine-independent correspondence between Python's array types and
+        # Scala types.
+        # See: https://docs.python.org/2/library/array.html
+
+        def assertCollectSuccess(typecode, value):
+            row = Row(myarray=array.array(typecode, [value]))
+            df = self.spark.createDataFrame([row])
+            self.assertEqual(df.first()["myarray"][0], value)
+
+        # supported string types
+        #
+        # String types in Python's array module are "u" for Py_UNICODE and "c" for char.
+        # "u" will be removed in Python 4, and "c" is not supported in Python 3.
+        supported_string_types = []
+        if sys.version_info[0] < 4:
+            supported_string_types += ['u']
+            # test unicode
+            assertCollectSuccess('u', u'a')
+        if sys.version_info[0] < 3:
+            supported_string_types += ['c']
+            # test string
+            assertCollectSuccess('c', 'a')
+
+        # supported float and double
+        #
+        # Test max, min, and precision for float and double, assuming IEEE 754
+        # floating-point format.
+        supported_fractional_types = ['f', 'd']
+        assertCollectSuccess('f', ctypes.c_float(1e+38).value)
+        assertCollectSuccess('f', ctypes.c_float(1e-38).value)
+        assertCollectSuccess('f', ctypes.c_float(1.123456).value)
+        assertCollectSuccess('d', sys.float_info.max)
+        assertCollectSuccess('d', sys.float_info.min)
+        assertCollectSuccess('d', sys.float_info.epsilon)
+
+        # supported signed int types
+        #
+        # The size of C types varies with the implementation, so we need to make
+        # sure that there is no overflow error on the platform running this test.
+        supported_signed_int_types = list(
+            set(_array_signed_int_typecode_ctype_mappings.keys())
+            .intersection(set(_array_type_mappings.keys())))
+        for t in supported_signed_int_types:
+            ctype = _array_signed_int_typecode_ctype_mappings[t]
+            max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
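+            # the representable range for this signed C type is [-max_val, max_val - 1]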
+            assertCollectSuccess(t, max_val - 1)
+            assertCollectSuccess(t, -max_val)
+
+        # supported unsigned int types
+        #
+        # JVM does not have unsigned types. We need to be very careful to make
+        # sure that there is no overflow error.
+        supported_unsigned_int_types = list(
+            set(_array_unsigned_int_typecode_ctype_mappings.keys())
+            .intersection(set(_array_type_mappings.keys())))
+        for t in supported_unsigned_int_types:
+            ctype = _array_unsigned_int_typecode_ctype_mappings[t]
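+            # 2 ** bits - 1 is the largest value this unsigned C type can hold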
+            assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)
+
+        # all supported types
+        #
+        # Make sure the types tested above:
+        # 1. are all supported types
+        # 2. cover all supported types
+        supported_types = (supported_string_types +
+                           supported_fractional_types +
+                           supported_signed_int_types +
+                           supported_unsigned_int_types)
+        self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
+
+        # all unsupported types
+        #
+        # The keys of _array_type_mappings form a complete list of all supported
+        # types; any type not in _array_type_mappings is considered unsupported.
+        # `array.typecodes` is not available in Python 2.
+        if sys.version_info[0] < 3:
+            all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
+        else:
+            all_types = set(array.typecodes)
+        unsupported_types = all_types - set(supported_types)
+        # test unsupported types
+        for t in unsupported_types:
+            with self.assertRaises(TypeError):
+                a = array.array(t)
+                self.spark.createDataFrame([Row(myarray=a)]).collect()
+
+
+class DataTypeTests(unittest.TestCase):
+    # regression test for SPARK-6055
+    def test_data_type_eq(self):
+        lt = LongType()
+        lt2 = pickle.loads(pickle.dumps(LongType()))
+        self.assertEqual(lt, lt2)
+
+    # regression test for SPARK-7978
+    def test_decimal_type(self):
+        t1 = DecimalType()
+        t2 = DecimalType(10, 2)
+        self.assertTrue(t2 is not t1)
+        self.assertNotEqual(t1, t2)
+        t3 = DecimalType(8)
+        self.assertNotEqual(t2, t3)
+
+    # regression test for SPARK-10392
+    def test_datetype_equal_zero(self):
+        dt = DateType()
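+        # DateType is stored internally as days since the Unix epoch, so the
+        # internal value 0 corresponds to 1970-01-01.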
+        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
+
+    # regression test for SPARK-17035
+    def test_timestamp_microsecond(self):
+        tst = TimestampType()
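+        # toInternal returns microseconds since the epoch; the 999999-microsecond
+        # component of datetime.max must be preserved by the conversion.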
+        self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)
+
+    def test_empty_row(self):
+        row = Row()
+        self.assertEqual(len(row), 0)
+
+    def test_struct_field_type_name(self):
+        struct_field = StructField("a", IntegerType())
+        self.assertRaises(TypeError, struct_field.typeName)
+
+    def test_invalid_create_row(self):
+        row_class = Row("c1", "c2")
+        self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
+
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+    def test_verify_type_exception_msg(self):
+        self.assertRaisesRegexp(
+            ValueError,
+            "test_name",
+            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        self.assertRaisesRegexp(
+            TypeError,
+            "field b in field a",
+            lambda: _make_type_verifier(schema)([["data"]]))
+
+    def test_verify_type_ok_nullable(self):
+        obj = None
+        types = [IntegerType(), FloatType(), StringType(), StructType([])]
+        for data_type in types:
+            try:
+                _make_type_verifier(data_type, nullable=True)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=True)" % (obj, 
data_type))
+
+    def test_verify_type_not_nullable(self):
+        import array
+        import datetime
+        import decimal
+
+        schema = StructType([
+            StructField('s', StringType(), nullable=False),
+            StructField('i', IntegerType(), nullable=True)])
+
+        class MyObj:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        # obj, data_type
+        success_spec = [
+            # String
+            ("", StringType()),
+            (u"", StringType()),
+            (1, StringType()),
+            (1.0, StringType()),
+            ([], StringType()),
+            ({}, StringType()),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+            # Boolean
+            (True, BooleanType()),
+
+            # Byte
+            (-(2**7), ByteType()),
+            (2**7 - 1, ByteType()),
+
+            # Short
+            (-(2**15), ShortType()),
+            (2**15 - 1, ShortType()),
+
+            # Integer
+            (-(2**31), IntegerType()),
+            (2**31 - 1, IntegerType()),
+
+            # Long
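+            # (the verifier does not range-check longs, so 2**64 is accepted)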
+            (2**64, LongType()),
+
+            # Float & Double
+            (1.0, FloatType()),
+            (1.0, DoubleType()),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType()),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType()),
+
+            # Date/Timestamp
+            (datetime.date(2000, 1, 2), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+            # Array
+            ([], ArrayType(IntegerType())),
+            (["1", None], ArrayType(StringType(), containsNull=True)),
+            ([1, 2], ArrayType(IntegerType())),
+            ((1, 2), ArrayType(IntegerType())),
+            (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType())),
+            ({"a": 1}, MapType(StringType(), IntegerType())),
+            ({"a": None}, MapType(StringType(), IntegerType(), 
valueContainsNull=True)),
+
+            # Struct
+            ({"s": "a", "i": 1}, schema),
+            ({"s": "a", "i": None}, schema),
+            ({"s": "a"}, schema),
+            ({"s": "a", "f": 1.0}, schema),
+            (Row(s="a", i=1), schema),
+            (Row(s="a", i=None), schema),
+            (Row(s="a", i=1, f=1.0), schema),
+            (["a", 1], schema),
+            (["a", None], schema),
+            (("a", 1), schema),
+            (MyObj(s="a", i=1), schema),
+            (MyObj(s="a", i=None), schema),
+            (MyObj(s="a"), schema),
+        ]
+
+        # obj, data_type, exception class
+        failure_spec = [
+            # String (match anything but None)
+            (None, StringType(), ValueError),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+            # Boolean
+            (1, BooleanType(), TypeError),
+            ("True", BooleanType(), TypeError),
+            ([1], BooleanType(), TypeError),
+
+            # Byte
+            (-(2**7) - 1, ByteType(), ValueError),
+            (2**7, ByteType(), ValueError),
+            ("1", ByteType(), TypeError),
+            (1.0, ByteType(), TypeError),
+
+            # Short
+            (-(2**15) - 1, ShortType(), ValueError),
+            (2**15, ShortType(), ValueError),
+
+            # Integer
+            (-(2**31) - 1, IntegerType(), ValueError),
+            (2**31, IntegerType(), ValueError),
+
+            # Float & Double
+            (1, FloatType(), TypeError),
+            (1, DoubleType(), TypeError),
+
+            # Decimal
+            (1.0, DecimalType(), TypeError),
+            (1, DecimalType(), TypeError),
+            ("1.0", DecimalType(), TypeError),
+
+            # Binary
+            (1, BinaryType(), TypeError),
+
+            # Date/Timestamp
+            ("2000-01-02", DateType(), TypeError),
+            (946811040, TimestampType(), TypeError),
+
+            # Array
+            (["1", None], ArrayType(StringType(), containsNull=False), 
ValueError),
+            ([1, "2"], ArrayType(IntegerType()), TypeError),
+
+            # Map
+            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+            ({"a": None}, MapType(StringType(), IntegerType(), 
valueContainsNull=False),
+             ValueError),
+
+            # Struct
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a"), schema, ValueError),     # Row can't have missing 
field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
+        ]
+
+        # Check success cases
+        for obj, data_type in success_spec:
+            try:
+                _make_type_verifier(data_type, nullable=False)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=False)" % (obj, 
data_type))
+
+        # Check failure cases
+        for obj, data_type, exp in failure_spec:
+            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, 
data_type, exp)
+            with self.assertRaises(exp, msg=msg):
+                _make_type_verifier(data_type, nullable=False)(obj)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_types import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)

