http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_session.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py new file mode 100644 index 0000000..b811047 --- /dev/null +++ b/python/pyspark/sql/tests/test_session.py @@ -0,0 +1,320 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import unittest + +from pyspark import SparkConf, SparkContext +from pyspark.sql import SparkSession, SQLContext, Row +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.tests import PySparkTestCase + + +class SparkSessionTests(ReusedSQLTestCase): + def test_sqlcontext_reuses_sparksession(self): + sqlContext1 = SQLContext(self.sc) + sqlContext2 = SQLContext(self.sc) + self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) + + +class SparkSessionTests1(ReusedSQLTestCase): + + # We can't include this test into SQLTests because we will stop class's SparkContext and cause + # other tests failed. + def test_sparksession_with_stopped_sparkcontext(self): + self.sc.stop() + sc = SparkContext('local[4]', self.sc.appName) + spark = SparkSession.builder.getOrCreate() + try: + df = spark.createDataFrame([(1, 2)], ["c", "c"]) + df.collect() + finally: + spark.stop() + sc.stop() + + +class SparkSessionTests2(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. 
+ self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + +class SparkSessionTests3(unittest.TestCase): + + def test_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + activeSession = SparkSession.getActiveSession() + df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) + self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) + finally: + spark.stop() + + def test_get_active_session_when_no_active_session(self): + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + active = SparkSession.getActiveSession() + self.assertEqual(active, spark) + spark.stop() + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + + def test_SparkSession(self): + spark = SparkSession.builder \ + .master("local") \ + .config("some-config", "v2") \ + .getOrCreate() + try: + self.assertEqual(spark.conf.get("some-config"), "v2") + self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") + self.assertEqual(spark.version, spark.sparkContext.version) + spark.sql("CREATE DATABASE test_db") + spark.catalog.setCurrentDatabase("test_db") + self.assertEqual(spark.catalog.currentDatabase(), "test_db") + spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") + self.assertEqual(spark.table("table1").columns, ['name', 'age']) + self.assertEqual(spark.range(3).count(), 3) + finally: + spark.stop() + + def test_global_default_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertEqual(SparkSession.builder.getOrCreate(), spark) + finally: + spark.stop() + + def test_default_and_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + activeSession = spark._jvm.SparkSession.getActiveSession() + defaultSession = spark._jvm.SparkSession.getDefaultSession() + try: + self.assertEqual(activeSession, defaultSession) + finally: + spark.stop() + + def test_config_option_propagated_to_existing_session(self): + session1 = SparkSession.builder \ + .master("local") \ + .config("spark-config1", "a") \ + .getOrCreate() + self.assertEqual(session1.conf.get("spark-config1"), "a") + session2 = SparkSession.builder \ + .config("spark-config1", "b") \ + .getOrCreate() + try: + self.assertEqual(session1, session2) + self.assertEqual(session1.conf.get("spark-config1"), "b") + finally: + session1.stop() + + def test_new_session(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + newSession = session.newSession() + try: + self.assertNotEqual(session, newSession) + finally: + session.stop() + newSession.stop() + + def test_create_new_session_if_old_session_stopped(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + session.stop() + newSession = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertNotEqual(session, newSession) + finally: + newSession.stop() + + def test_active_session_with_None_and_not_None_context(self): + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + sc = None + session = None + try: + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = 
sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + +class SparkSessionTests4(ReusedSQLTestCase): + + def test_get_active_session_after_create_dataframe(self): + session2 = None + try: + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + finally: + if session2 is not None: + session2.stop() + + +class SparkSessionBuilderTests(unittest.TestCase): + + def test_create_spark_context_first_then_spark_session(self): + sc = None + session = None + try: + conf = SparkConf().set("key1", "value1") + sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) + session = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session.conf.get("key1"), "value1") + self.assertEqual(session.conf.get("key2"), "value2") + self.assertEqual(session.sparkContext, sc) + + self.assertFalse(sc.getConf().contains("key2")) + self.assertEqual(sc.getConf().get("key1"), "value1") + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + def test_another_spark_session(self): + session1 = None + session2 = None + try: + session1 = SparkSession.builder.config("key1", "value1").getOrCreate() + session2 = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session1.conf.get("key1"), "value1") + self.assertEqual(session2.conf.get("key1"), "value1") + self.assertEqual(session1.conf.get("key2"), "value2") + self.assertEqual(session2.conf.get("key2"), "value2") + self.assertEqual(session1.sparkContext, session2.sparkContext) + + self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") + self.assertFalse(session1.sparkContext.getConf().contains("key2")) + finally: + if session1 is not None: + session1.stop() + if session2 is not None: + session2.stop() + + +class SparkExtensionsTest(unittest.TestCase): + # These tests are separate because it uses 'spark.sql.extensions' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "SparkSessionExtensionSuite.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.SparkSessionExtensionSuite' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.extensions' is a static immutable configuration. 
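+            # It therefore has to be supplied through the builder below, before the
+            # session is created; it cannot be changed later on an existing session.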
+ cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.extensions", + "org.apache.spark.sql.MyExtensions") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def test_use_custom_class_for_extensions(self): + self.assertTrue( + self.spark._jsparkSession.sessionState().planner().strategies().contains( + self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)), + "MySparkStrategy not found in active planner strategies") + self.assertTrue( + self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains( + self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)), + "MyRule not found in extended resolution rules") + + +if __name__ == "__main__": + from pyspark.sql.tests.test_session import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2)
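
A minimal sketch of the session lifecycle exercised by the tests above, runnable against a local PySpark installation; the app name and the "spark-config1" key are illustrative only:

    from pyspark.sql import SparkSession

    # getOrCreate() registers the new session as both the default and the active session.
    spark = SparkSession.builder.master("local").appName("session-demo").getOrCreate()
    assert SparkSession.getActiveSession() == spark

    # A second getOrCreate() returns the existing session, and options passed to the
    # builder are propagated to it instead of creating a new session.
    same = SparkSession.builder.config("spark-config1", "b").getOrCreate()
    assert same == spark and spark.conf.get("spark-config1") == "b"

    # stop() clears the default and active session, so a later getOrCreate() builds a new one.
    spark.stop()
    assert SparkSession.getActiveSession() is None
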
http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_streaming.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py new file mode 100644 index 0000000..cc0cab4 --- /dev/null +++ b/python/pyspark/sql/tests/test_streaming.py @@ -0,0 +1,566 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import tempfile +import time + +from pyspark.sql.functions import lit +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTests(ReusedSQLTestCase): + + def test_stream_trigger(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + + # Should take at least one arg + try: + df.writeStream.trigger() + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(once=True, processingTime='5 seconds') + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + + # Should take only keyword args + try: + df.writeStream.trigger('5 seconds') + self.fail("Should have thrown an exception") + except TypeError: + pass + + def test_stream_read_options(self): + schema = StructType([StructField("data", StringType(), False)]) + df = self.spark.readStream\ + .format('text')\ + .option('path', 'python/test_support/sql/streaming')\ + .schema(schema)\ + .load() + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_read_options_overwrite(self): + bad_schema = StructType([StructField("test", IntegerType(), False)]) + schema = StructType([StructField("data", StringType(), False)]) + df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema)\ + .load(path='python/test_support/sql/streaming', schema=schema, format='text') + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_save_options(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \ + .withColumn('id', lit(1)) + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ + .format('parquet').partitionBy('id').outputMode('append').option('path', out).start() + try: + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + 
q.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_save_options_overwrite(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + fake1 = os.path.join(tmpPath, 'fake1') + fake2 = os.path.join(tmpPath, 'fake2') + q = df.writeStream.option('checkpointLocation', fake1)\ + .format('memory').option('path', fake2) \ + .queryName('fake_query').outputMode('append') \ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + + try: + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + q.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + self.assertFalse(os.path.isdir(fake1)) # should not have been created + self.assertFalse(os.path.isdir(fake2)) # should not have been created + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_status_and_progress(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + + def func(x): + time.sleep(1) + return x + + from pyspark.sql.functions import col, udf + sleep_udf = udf(func) + + # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there + # were no updates. + q = df.select(sleep_udf(col("value")).alias('value')).writeStream \ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + # "lastProgress" will return None in most cases. However, as it may be flaky when + # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress" + # may throw error with a high chance and make this test flaky, so we should still be + # able to detect broken codes. 
+ q.lastProgress + + q.processAllAvailable() + lastProgress = q.lastProgress + recentProgress = q.recentProgress + status = q.status + self.assertEqual(lastProgress['name'], q.name) + self.assertEqual(lastProgress['id'], q.id) + self.assertTrue(any(p == lastProgress for p in recentProgress)) + self.assertTrue( + "message" in status and + "isDataAvailable" in status and + "isTriggerActive" in status) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_await_termination(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + self.assertTrue(q.isActive) + try: + q.awaitTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = q.awaitTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_exception(self): + sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + sq = sdf.writeStream.format('memory').queryName('query_explain').start() + try: + sq.processAllAvailable() + self.assertEqual(sq.exception(), None) + finally: + sq.stop() + + from pyspark.sql.functions import col, udf + from pyspark.sql.utils import StreamingQueryException + bad_udf = udf(lambda x: 1 / 0) + sq = sdf.select(bad_udf(col("value")))\ + .writeStream\ + .format('memory')\ + .queryName('this_query')\ + .start() + try: + # Process some data to fail the query + sq.processAllAvailable() + self.fail("bad udf should fail the query") + except StreamingQueryException as e: + # This is expected + self.assertTrue("ZeroDivisionError" in e.desc) + finally: + sq.stop() + self.assertTrue(type(sq.exception()) is StreamingQueryException) + self.assertTrue("ZeroDivisionError" in sq.exception().desc) + + def test_query_manager_await_termination(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + self.assertTrue(q.isActive) + try: + self.spark._wrapped.streams.awaitAnyTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = self.spark._wrapped.streams.awaitAnyTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + q.stop() + shutil.rmtree(tmpPath) + + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + 
def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules + # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html + # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
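+    # (For local runs on macOS this can be done by exporting
+    # OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES in the shell that launches the tests.)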
+ def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_streaming import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_types.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py new file mode 100644 index 0000000..3b32c58 --- /dev/null +++ b/python/pyspark/sql/tests/test_types.py @@ -0,0 +1,944 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import array +import ctypes +import datetime +import os +import pickle +import sys +import unittest + +from pyspark.sql import Row +from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.types import * +from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \ + _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type +from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \ + ExamplePoint, PythonOnlyPoint, MyObject + + +class TypesTests(ReusedSQLTestCase): + + def test_apply_schema_to_row(self): + df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.spark.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.spark.createDataFrame(input) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_apply_schema_to_dict_and_rows(self): + schema = StructType().add("b", StringType()).add("a", IntegerType()) + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + for verify in [False, True]: + df = self.spark.createDataFrame(input, schema, verifySchema=verify) + df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(10, df3.count()) + input = [Row(a=x, b=str(x)) for x in range(10)] + df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) + self.assertEqual(10, df4.count()) + + def test_create_dataframe_schema_mismatch(self): + input = [Row(a=1)] + rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) + schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) + df = self.spark.createDataFrame(rdd, schema) + self.assertRaises(Exception, lambda: df.show()) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.spark.createDataFrame(rdd) + self.assertEqual([], df.rdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) + + with self.tempView("test"): + df.createOrReplaceTempView("test") + result = 
self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) + + with self.tempView("test2"): + df2.createOrReplaceTempView("test2") + result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_schema_specification(self): + from decimal import Decimal + + class A(object): + def __init__(self): + self.a = 1 + + data = [ + True, + 1, + "a", + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + array.array("d", [1]), + [1], + (1, ), + {"a": 1}, + bytearray(1), + Decimal(1), + Row(a=1), + Row("a")(1), + A(), + ] + + df = self.spark.createDataFrame([data]) + actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) + expected = [ + 'boolean', + 'bigint', + 'string', + 'string', + 'date', + 'timestamp', + 'double', + 'array<double>', + 'array<bigint>', + 'struct<_1:bigint>', + 'map<string,bigint>', + 'binary', + 'decimal(38,18)', + 'struct<a:bigint>', + 'struct<a:bigint>', + 'struct<a:bigint>', + ] + self.assertEqual(actual, expected) + + actual = list(df.first()) + expected = [ + True, + 1, + 'a', + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + [1.0], + [1], + Row(_1=1), + {"a": 1}, + bytearray(b'\x00'), + Decimal('1.000000000000000000'), + Row(a=1), + Row(a=1), + Row(a=1), + ] + self.assertEqual(actual, expected) + + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.spark.createDataFrame(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.spark.createDataFrame(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.spark.createDataFrame(rdd) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.spark.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = 
StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.spark.createDataFrame(rdd, schema) + results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, + x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + with self.tempView("table2"): + df.createOrReplaceTempView("table2") + r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + + with self.tempView("test"): + df.createOrReplaceTempView("test") + row = self.spark.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) + self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) + self.assertRaises( + ValueError, + lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) + + def test_simple_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + def test_nested_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + 
df.collect() + + schema = StructType().add("key", LongType()).add("val", + MapType(LongType(), PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) + df.collect() + + def test_complex_nested_udt_in_df(self): + from pyspark.sql.functions import udf + + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + gd = df.groupby("key").agg({"val": "collect_list"}) + gd.collect() + udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + gd.select(udf(*gd)).collect() + + def test_udt_with_none(self): + df = self.spark.range(0, 10, 1, 1) + + def myudf(x): + if x > 0: + return PythonOnlyPoint(float(x), float(x)) + + self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) + rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] + self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + + def test_infer_schema_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + + with self.tempView("labeled_point"): + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + + with self.tempView("labeled_point"): + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + row = (1.0, ExamplePoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + point = df.head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + point = df.head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_udf_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) + self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + def 
test_parquet_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.spark.createDataFrame([row]) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.write.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) + point = df1.head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.spark.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.spark.read.parquet(output_dir) + point = df1.head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_union_with_udt(self): + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.spark.createDataFrame([row1], schema) + df2 = self.spark.createDataFrame([row2], schema) + + result = df1.union(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + + def test_cast_to_string_with_udt(self): + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + + def test_struct_type(self): + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + self.assertRaises(ValueError, lambda: StructType().add("name")) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + for field in struct1: + self.assertIsInstance(field, StructField) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(len(struct1), 2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertIs(struct1["f1"], struct1.fields[0]) + self.assertIs(struct1[0], struct1.fields[0]) + self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) + self.assertRaises(KeyError, lambda: struct1["f9"]) + 
self.assertRaises(IndexError, lambda: struct1[9]) + self.assertRaises(TypeError, lambda: struct1[9.9]) + + def test_parse_datatype_string(self): + from pyspark.sql.types import _all_atomic_types, _parse_datatype_string + for k, t in _all_atomic_types.items(): + if t != NullType: + self.assertEqual(t(), _parse_datatype_string(k)) + self.assertEqual(IntegerType(), _parse_datatype_string("int")) + self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)")) + self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )")) + self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)")) + self.assertEqual( + ArrayType(IntegerType()), + _parse_datatype_string("array<int >")) + self.assertEqual( + MapType(IntegerType(), DoubleType()), + _parse_datatype_string("map< int, double >")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("struct<a:int, c:double >")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a:int, c:double")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a INT, c DOUBLE")) + + def test_metadata_null(self): + schema = StructType([StructField("f1", StringType(), True, None), + StructField("f2", StringType(), True, {'a': None})]) + rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) + self.spark.createDataFrame(rdd, schema) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. 
+ output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.write.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): 
+ _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + + # test for SPARK-16542 + def test_array_types(self): + # This test need to make sure that the Scala type selected is at least + # as large as the python's types. This is necessary because python's + # array types depend on C implementation on the machine. Therefore there + # is no machine independent correspondence between python's array types + # and Scala types. + # See: https://docs.python.org/2/library/array.html + + def assertCollectSuccess(typecode, value): + row = Row(myarray=array.array(typecode, [value])) + df = self.spark.createDataFrame([row]) + self.assertEqual(df.first()["myarray"][0], value) + + # supported string types + # + # String types in python's array are "u" for Py_UNICODE and "c" for char. + # "u" will be removed in python 4, and "c" is not supported in python 3. + supported_string_types = [] + if sys.version_info[0] < 4: + supported_string_types += ['u'] + # test unicode + assertCollectSuccess('u', u'a') + if sys.version_info[0] < 3: + supported_string_types += ['c'] + # test string + assertCollectSuccess('c', 'a') + + # supported float and double + # + # Test max, min, and precision for float and double, assuming IEEE 754 + # floating-point format. + supported_fractional_types = ['f', 'd'] + assertCollectSuccess('f', ctypes.c_float(1e+38).value) + assertCollectSuccess('f', ctypes.c_float(1e-38).value) + assertCollectSuccess('f', ctypes.c_float(1.123456).value) + assertCollectSuccess('d', sys.float_info.max) + assertCollectSuccess('d', sys.float_info.min) + assertCollectSuccess('d', sys.float_info.epsilon) + + # supported signed int types + # + # The size of C types changes with implementation, we need to make sure + # that there is no overflow error on the platform running this test. + supported_signed_int_types = list( + set(_array_signed_int_typecode_ctype_mappings.keys()) + .intersection(set(_array_type_mappings.keys()))) + for t in supported_signed_int_types: + ctype = _array_signed_int_typecode_ctype_mappings[t] + max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1) + assertCollectSuccess(t, max_val - 1) + assertCollectSuccess(t, -max_val) + + # supported unsigned int types + # + # JVM does not have unsigned types. We need to be very careful to make + # sure that there is no overflow error. + supported_unsigned_int_types = list( + set(_array_unsigned_int_typecode_ctype_mappings.keys()) + .intersection(set(_array_type_mappings.keys()))) + for t in supported_unsigned_int_types: + ctype = _array_unsigned_int_typecode_ctype_mappings[t] + assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1) + + # all supported types + # + # Make sure the types tested above: + # 1. are all supported types + # 2. 
cover all supported types + supported_types = (supported_string_types + + supported_fractional_types + + supported_signed_int_types + + supported_unsigned_int_types) + self.assertEqual(set(supported_types), set(_array_type_mappings.keys())) + + # all unsupported types + # + # Keys in _array_type_mappings is a complete list of all supported types, + # and types not in _array_type_mappings are considered unsupported. + # `array.typecodes` are not supported in python 2. + if sys.version_info[0] < 3: + all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd']) + else: + all_types = set(array.typecodes) + unsupported_types = all_types - set(supported_types) + # test unsupported types + for t in unsupported_types: + with self.assertRaises(TypeError): + a = array.array(t) + self.spark.createDataFrame([Row(myarray=a)]).collect() + + +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEqual(lt, lt2) + + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + + # regression test for SPARK-17035 + def test_timestamp_microsecond(self): + tst = TimestampType() + self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999) + + def test_empty_row(self): + row = Row() + self.assertEqual(len(row), 0) + + def test_struct_field_type_name(self): + struct_field = StructField("a", IntegerType()) + self.assertRaises(TypeError, struct_field.typeName) + + def test_invalid_create_row(self): + row_class = Row("c1", "c2") + self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) + + +class DataTypeVerificationTests(unittest.TestCase): + + def test_verify_type_exception_msg(self): + self.assertRaisesRegexp( + ValueError, + "test_name", + lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) + + schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) + self.assertRaisesRegexp( + TypeError, + "field b in field a", + lambda: _make_type_verifier(schema)([["data"]])) + + def test_verify_type_ok_nullable(self): + obj = None + types = [IntegerType(), FloatType(), StringType(), StructType([])] + for data_type in types: + try: + _make_type_verifier(data_type, nullable=True)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + schema = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + # obj, data_type + success_spec = [ + # String + ("", StringType()), + (u"", StringType()), + (1, StringType()), + (1.0, StringType()), + ([], StringType()), + ({}, StringType()), + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT()), + + # Boolean + (True, BooleanType()), + + # Byte + (-(2**7), ByteType()), + (2**7 - 1, ByteType()), + + # Short + (-(2**15), ShortType()), + (2**15 - 1, ShortType()), + + # Integer + (-(2**31), IntegerType()), + (2**31 - 1, IntegerType()), + + # Long + (2**64, 
+
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+    def test_verify_type_exception_msg(self):
+        self.assertRaisesRegexp(
+            ValueError,
+            "test_name",
+            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        self.assertRaisesRegexp(
+            TypeError,
+            "field b in field a",
+            lambda: _make_type_verifier(schema)([["data"]]))
+
+    def test_verify_type_ok_nullable(self):
+        obj = None
+        types = [IntegerType(), FloatType(), StringType(), StructType([])]
+        for data_type in types:
+            try:
+                _make_type_verifier(data_type, nullable=True)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
+
+    def test_verify_type_not_nullable(self):
+        import array
+        import datetime
+        import decimal
+
+        schema = StructType([
+            StructField('s', StringType(), nullable=False),
+            StructField('i', IntegerType(), nullable=True)])
+
+        class MyObj:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        # obj, data_type
+        success_spec = [
+            # String
+            ("", StringType()),
+            (u"", StringType()),
+            (1, StringType()),
+            (1.0, StringType()),
+            ([], StringType()),
+            ({}, StringType()),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+            # Boolean
+            (True, BooleanType()),
+
+            # Byte
+            (-(2**7), ByteType()),
+            (2**7 - 1, ByteType()),
+
+            # Short
+            (-(2**15), ShortType()),
+            (2**15 - 1, ShortType()),
+
+            # Integer
+            (-(2**31), IntegerType()),
+            (2**31 - 1, IntegerType()),
+
+            # Long
+            (2**64, LongType()),
+
+            # Float & Double
+            (1.0, FloatType()),
+            (1.0, DoubleType()),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType()),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType()),
+
+            # Date/Timestamp
+            (datetime.date(2000, 1, 2), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+            # Array
+            ([], ArrayType(IntegerType())),
+            (["1", None], ArrayType(StringType(), containsNull=True)),
+            ([1, 2], ArrayType(IntegerType())),
+            ((1, 2), ArrayType(IntegerType())),
+            (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType())),
+            ({"a": 1}, MapType(StringType(), IntegerType())),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+            # Struct
+            ({"s": "a", "i": 1}, schema),
+            ({"s": "a", "i": None}, schema),
+            ({"s": "a"}, schema),
+            ({"s": "a", "f": 1.0}, schema),
+            (Row(s="a", i=1), schema),
+            (Row(s="a", i=None), schema),
+            (Row(s="a", i=1, f=1.0), schema),
+            (["a", 1], schema),
+            (["a", None], schema),
+            (("a", 1), schema),
+            (MyObj(s="a", i=1), schema),
+            (MyObj(s="a", i=None), schema),
+            (MyObj(s="a"), schema),
+        ]
+
+        # obj, data_type, exception class
+        failure_spec = [
+            # String (match anything but None)
+            (None, StringType(), ValueError),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+            # Boolean
+            (1, BooleanType(), TypeError),
+            ("True", BooleanType(), TypeError),
+            ([1], BooleanType(), TypeError),
+
+            # Byte
+            (-(2**7) - 1, ByteType(), ValueError),
+            (2**7, ByteType(), ValueError),
+            ("1", ByteType(), TypeError),
+            (1.0, ByteType(), TypeError),
+
+            # Short
+            (-(2**15) - 1, ShortType(), ValueError),
+            (2**15, ShortType(), ValueError),
+
+            # Integer
+            (-(2**31) - 1, IntegerType(), ValueError),
+            (2**31, IntegerType(), ValueError),
+
+            # Float & Double
+            (1, FloatType(), TypeError),
+            (1, DoubleType(), TypeError),
+
+            # Decimal
+            (1.0, DecimalType(), TypeError),
+            (1, DecimalType(), TypeError),
+            ("1.0", DecimalType(), TypeError),
+
+            # Binary
+            (1, BinaryType(), TypeError),
+
+            # Date/Timestamp
+            ("2000-01-02", DateType(), TypeError),
+            (946811040, TimestampType(), TypeError),
+
+            # Array
+            (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
+            ([1, "2"], ArrayType(IntegerType()), TypeError),
+
+            # Map
+            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
+             ValueError),
+
+            # Struct
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a"), schema, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
+        ]
+
+        # Check success cases
+        for obj, data_type in success_spec:
+            try:
+                _make_type_verifier(data_type, nullable=False)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
+
+        # Check failure cases
+        for obj, data_type, exp in failure_spec:
+            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+            with self.assertRaises(exp, msg=msg):
+                _make_type_verifier(data_type, nullable=False)(obj)
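One more aside (illustrative only, not part of the diff): the success and failure tables above exercise _make_type_verifier, which returns a callable that passes silently on acceptable values, raises TypeError when a value's Python type does not match the declared Spark type, and raises ValueError when the type matches but the value is invalid (out of range, unexpected null, missing field). A condensed usage sketch mirroring entries from those tables:

    from pyspark.sql.types import (ByteType, IntegerType, StringType, StructField,
                                   StructType, _make_type_verifier)

    schema = StructType([
        StructField("s", StringType(), nullable=False),
        StructField("i", IntegerType(), nullable=True)])

    verify_row = _make_type_verifier(schema, nullable=False)
    verify_row({"s": "a", "i": 1})     # passes silently
    verify_row({"s": "a", "i": None})  # passes: field "i" is nullable

    verify_byte = _make_type_verifier(ByteType(), nullable=False)
    # verify_byte(2 ** 7)  # would raise ValueError: right type, outside the byte range
    # verify_byte("1")     # would raise TypeError: a str is not acceptable for ByteType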
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_types import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)