http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/__init__.py b/python/pyspark/sql/tests/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/python/pyspark/sql/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
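The empty __init__.py above turns python/pyspark/sql/tests into a package, so the per-topic suites split out below can be discovered and run individually. A minimal sketch using plain unittest discovery; the working directory (python/ in a Spark checkout with pyspark importable) and the file pattern are assumptions for illustration, not part of the patch:

    import unittest

    # Discover only the new SQL test modules; start_dir is assumed to be
    # relative to the python/ directory of a Spark checkout.
    suite = unittest.defaultTestLoader.discover("pyspark/sql/tests", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)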
http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_appsubmit.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_appsubmit.py b/python/pyspark/sql/tests/test_appsubmit.py
new file mode 100644
index 0000000..3c71151
--- /dev/null
+++ b/python/pyspark/sql/tests/test_appsubmit.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import subprocess
+import tempfile
+
+import py4j
+
+from pyspark import SparkContext
+from pyspark.tests import SparkSubmitTests
+
+
+class HiveSparkSubmitTests(SparkSubmitTests):
+
+    @classmethod
+    def setUpClass(cls):
+        # get a SparkContext to check for availability of Hive
+        sc = SparkContext('local[4]', cls.__name__)
+        cls.hive_available = True
+        try:
+            sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+        except py4j.protocol.Py4JError:
+            cls.hive_available = False
+        except TypeError:
+            cls.hive_available = False
+        finally:
+            # we don't need this SparkContext for the test
+            sc.stop()
+
+    def setUp(self):
+        super(HiveSparkSubmitTests, self).setUp()
+        if not self.hive_available:
+            self.skipTest("Hive is not available.")
+
+    def test_hivecontext(self):
+        # This test checks that HiveContext is using Hive metastore (SPARK-16224).
+        # It sets a metastore url and checks if there is a derby dir created by
+        # Hive metastore. If this derby dir exists, HiveContext is using
+        # Hive metastore.
+        metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db")
+        metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true"
+        hive_site_dir = os.path.join(self.programDir, "conf")
+        hive_site_file = self.createTempFile("hive-site.xml", ("""
+            |<configuration>
+            |  <property>
+            |  <name>javax.jdo.option.ConnectionURL</name>
+            |  <value>%s</value>
+            |  </property>
+            |</configuration>
+            """ % metastore_URL).lstrip(), "conf")
+        script = self.createTempFile("test.py", """
+            |import os
+            |
+            |from pyspark.conf import SparkConf
+            |from pyspark.context import SparkContext
+            |from pyspark.sql import HiveContext
+            |
+            |conf = SparkConf()
+            |sc = SparkContext(conf=conf)
+            |hive_context = HiveContext(sc)
+            |print(hive_context.sql("show databases").collect())
+            """)
+        proc = subprocess.Popen(
+            self.sparkSubmit + ["--master", "local-cluster[1,1,1024]",
+                                "--driver-class-path", hive_site_dir, script],
+            stdout=subprocess.PIPE)
+        out, err = proc.communicate()
+        self.assertEqual(0, proc.returncode)
+        self.assertIn("default", out.decode('utf-8'))
+        self.assertTrue(os.path.exists(metastore_path))
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_appsubmit import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_arrow.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
new file mode 100644
index 0000000..44f7035
--- /dev/null
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -0,0 +1,399 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import datetime +import os +import threading +import time +import unittest +import warnings + +from pyspark.sql import Row +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.tests import QuietTest +from pyspark.util import _exception_message + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class ArrowTests(ReusedSQLTestCase): + + @classmethod + def setUpClass(cls): + from datetime import date, datetime + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + super(ArrowTests, cls).setUpClass() + cls.warnings_lock = threading.Lock() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.spark.conf.set("spark.sql.session.timeZone", tz) + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True), + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + super(ArrowTests, cls).tearDownClass() + + def create_pandas_data_frame(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + return pd.DataFrame(data=data_dict) + + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with self.warnings_lock: + with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") + pdf = df.toPandas() + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.warnings_lock: + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def _toPandas_arrow_toggle(self, df): + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): + pdf = df.toPandas() + + pdf_arrow = df.toPandas() + + return pdf, pdf_arrow + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) + + def test_toPandas_respect_session_timezone(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): + pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) + + self.assertFalse(pdf_ny.equals(pdf_la)) + + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + pdf_la_corrected = pdf_la.copy() + for field in self.schema: + if isinstance(field.dataType, TimestampType): + pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( + pdf_la_corrected[field.name], timezone) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) + + def test_pandas_round_trip(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertPandasEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + def _createDataFrame_toggle(self, pdf, schema=None): + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): + df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) + + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + + return df_no_arrow, df_arrow + + def 
test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) + self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + + def test_createDataFrame_respect_session_timezone(self): + from datetime import timedelta + pdf = self.create_pandas_data_frame() + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) + result_ny = df_no_arrow_ny.collect() + result_arrow_ny = df_arrow_ny.collect() + self.assertEqual(result_ny, result_arrow_ny) + + self.assertNotEqual(result_ny, result_la) + + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v + for k, v in row.asDict().items()}) + for row in result_la] + self.assertEqual(result_ny, result_la_corrected) + + def test_createDataFrame_with_schema(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(pdf, schema=self.schema) + self.assertEquals(self.schema, df.schema) + pdf_arrow = df.toPandas() + self.assertPandasEqual(pdf_arrow, pdf) + + def test_createDataFrame_with_incorrect_schema(self): + pdf = self.create_pandas_data_frame() + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + wrong_schema = StructType(fields) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): + self.spark.createDataFrame(pdf, schema=wrong_schema) + + def test_createDataFrame_with_names(self): + pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) + # Test that schema as a list of column names gets applied + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) + # Test that schema as tuple of column names gets applied + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) + + def test_createDataFrame_column_name_encoding(self): + import pandas as pd + pdf = pd.DataFrame({u'a': [1]}) + columns = self.spark.createDataFrame(pdf).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'a') + columns = self.spark.createDataFrame(pdf, [u'b']).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'b') + + def test_createDataFrame_with_single_data_type(self): + import pandas as pd + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): + self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") + + def test_createDataFrame_does_not_modify_input(self): + import pandas as pd + # Some series get converted for Spark to consume, this makes sure input is unchanged + pdf = self.create_pandas_data_frame() + # Use a nanosecond value to make sure it is not truncated + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) + # Integers with 
nulls will get NaNs filled with 0 and will be casted + pdf.ix[1, '2_int_t'] = None + pdf_copy = pdf.copy(deep=True) + self.spark.createDataFrame(pdf, schema=self.schema) + self.assertTrue(pdf.equals(pdf_copy)) + + def test_schema_conversion_roundtrip(self): + from pyspark.sql.types import from_arrow_schema, to_arrow_schema + arrow_schema = to_arrow_schema(self.schema) + schema_rt = from_arrow_schema(arrow_schema) + self.assertEquals(self.schema, schema_rt) + + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>") + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion + import pandas as pd + import pyarrow as pa + + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>") + + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + + +class EncryptionArrowTests(ArrowTests): + + @classmethod + def conf(cls): + return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") + + +if __name__ == "__main__": + from pyspark.sql.tests.test_arrow import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_catalog.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py new file mode 100644 index 0000000..23d2577 --- /dev/null +++ b/python/pyspark/sql/tests/test_catalog.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class CatalogTests(ReusedSQLTestCase): + + def test_current_database(self): + spark = self.spark + with self.database("some_db"): + self.assertEquals(spark.catalog.currentDatabase(), "default") + spark.sql("CREATE DATABASE some_db") + spark.catalog.setCurrentDatabase("some_db") + self.assertEquals(spark.catalog.currentDatabase(), "some_db") + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.setCurrentDatabase("does_not_exist")) + + def test_list_databases(self): + spark = self.spark + with self.database("some_db"): + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(databases, ["default"]) + spark.sql("CREATE DATABASE some_db") + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(sorted(databases), ["default", "some_db"]) + + def test_list_tables(self): + from pyspark.sql.catalog import Table + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + with self.table("tab1", "some_db.tab2"): + with self.tempView("temp_tab"): + self.assertEquals(spark.catalog.listTables(), []) + self.assertEquals(spark.catalog.listTables("some_db"), []) + spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab") + spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") + spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet") + tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) + tablesDefault = \ + sorted(spark.catalog.listTables("default"), key=lambda t: t.name) + tablesSomeDb = \ + sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) + self.assertEquals(tables, tablesDefault) + self.assertEquals(len(tables), 2) + self.assertEquals(len(tablesSomeDb), 2) + self.assertEquals(tables[0], Table( + name="tab1", + database="default", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tables[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertEquals(tablesSomeDb[0], Table( + name="tab2", + database="some_db", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tablesSomeDb[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listTables("does_not_exist")) + + def test_list_functions(self): + from pyspark.sql.catalog import Function + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + functions = dict((f.name, f) for f in spark.catalog.listFunctions()) + functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) + self.assertTrue(len(functions) > 200) + self.assertTrue("+" in functions) + self.assertTrue("like" in functions) + self.assertTrue("month" in functions) + self.assertTrue("to_date" in functions) + self.assertTrue("to_timestamp" in functions) + self.assertTrue("to_unix_timestamp" in functions) + self.assertTrue("current_database" in functions) + self.assertEquals(functions["+"], Function( + name="+", + description=None, + className="org.apache.spark.sql.catalyst.expressions.Add", + isTemporary=True)) + self.assertEquals(functions, functionsDefault) + + with self.function("func1", "some_db.func2"): + 
spark.catalog.registerFunction("temp_func", lambda x: str(x)) + spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'") + spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") + newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) + newFunctionsSomeDb = \ + dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) + self.assertTrue(set(functions).issubset(set(newFunctions))) + self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) + self.assertTrue("temp_func" in newFunctions) + self.assertTrue("func1" in newFunctions) + self.assertTrue("func2" not in newFunctions) + self.assertTrue("temp_func" in newFunctionsSomeDb) + self.assertTrue("func1" not in newFunctionsSomeDb) + self.assertTrue("func2" in newFunctionsSomeDb) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listFunctions("does_not_exist")) + + def test_list_columns(self): + from pyspark.sql.catalog import Column + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + with self.table("tab1", "some_db.tab2"): + spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") + spark.sql( + "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet") + columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) + columnsDefault = \ + sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) + self.assertEquals(columns, columnsDefault) + self.assertEquals(len(columns), 2) + self.assertEquals(columns[0], Column( + name="age", + description=None, + dataType="int", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns[1], Column( + name="name", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + columns2 = \ + sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) + self.assertEquals(len(columns2), 2) + self.assertEquals(columns2[0], Column( + name="nickname", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns2[1], Column( + name="tolerance", + description=None, + dataType="float", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertRaisesRegexp( + AnalysisException, + "tab2", + lambda: spark.catalog.listColumns("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listColumns("does_not_exist")) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_catalog import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_column.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py new file mode 100644 index 0000000..faadde9 --- /dev/null +++ b/python/pyspark/sql/tests/test_column.py @@ -0,0 +1,157 @@ +# -*- encoding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark.sql import Column, Row +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ColumnTests(ReusedSQLTestCase): + + def test_column_name_encoding(self): + """Ensure that created columns has `str` type consistently.""" + columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns + self.assertEqual(columns, ['name', 'age']) + self.assertTrue(isinstance(columns[0], str)) + self.assertTrue(isinstance(columns[1], str)) + + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + + def test_validate_column_types(self): + from pyspark.sql.functions import udf, to_json + from pyspark.sql.column import _to_java_column + + self.assertTrue("Column" in _to_java_column("a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: _to_java_column(1)) + + class A(): + pass + + self.assertRaises(TypeError, lambda: _to_java_column(A())) + self.assertRaises(TypeError, lambda: _to_java_column([])) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: udf(lambda x: x)(None)) + self.assertRaises(TypeError, lambda: to_json(1)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ + cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + self.assertRaisesRegexp(ValueError, + "Cannot apply 'in' operator against a column", + lambda: 1 in cs) + + def test_column_getitem(self): + from pyspark.sql.functions import col + + self.assertIsInstance(col("foo")[1:3], Column) + self.assertIsInstance(col("foo")[0], Column) + 
self.assertIsInstance(col("foo")["bar"], Column) + self.assertRaises(ValueError, lambda: col("foo")[0:10:2]) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_column_name_with_non_ascii(self): + if sys.version >= '3': + columnName = "æ°é" + self.assertTrue(isinstance(columnName, str)) + else: + columnName = unicode("æ°é", "utf-8") + self.assertTrue(isinstance(columnName, unicode)) + schema = StructType([StructField(columnName, LongType(), True)]) + df = self.spark.createDataFrame([(1,)], schema) + self.assertEqual(schema, df.schema) + self.assertEqual("DataFrame[æ°é: bigint]", str(df)) + self.assertEqual([("æ°é", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("æ°é").first()[0]) + self.assertEqual(1, df.select(df["æ°é"]).first()[0]) + + def test_field_accessor(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) + self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + + def test_bitwise_operations(self): + from pyspark.sql import functions + row = Row(a=170, b=75) + df = self.spark.createDataFrame([row]) + result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() + self.assertEqual(170 & 75, result['(a & b)']) + result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() + self.assertEqual(170 | 75, result['(a | b)']) + result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() + self.assertEqual(170 ^ 75, result['(a ^ b)']) + result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() + self.assertEqual(~75, result['~b']) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_column import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_conf.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py new file mode 100644 index 0000000..f5d68a8 --- /dev/null +++ b/python/pyspark/sql/tests/test_conf.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ConfTests(ReusedSQLTestCase): + + def test_conf(self): + spark = self.spark + spark.conf.set("bogo", "sipeo") + self.assertEqual(spark.conf.get("bogo"), "sipeo") + spark.conf.set("bogo", "ta") + self.assertEqual(spark.conf.get("bogo"), "ta") + self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") + self.assertEqual(spark.conf.get("not.set", "ta"), "ta") + self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set")) + spark.conf.unset("bogo") + self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_conf import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py new file mode 100644 index 0000000..d9d408a --- /dev/null +++ b/python/pyspark/sql/tests/test_context.py @@ -0,0 +1,263 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import shutil +import sys +import tempfile +import unittest + +import py4j + +from pyspark import HiveContext, Row +from pyspark.sql.types import * +from pyspark.sql.window import Window +from pyspark.tests import ReusedPySparkTestCase + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + os.unlink(cls.tempdir.name) + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) + actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json") + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE externalJsonTable") + + df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.spark.createExternalTable("externalJsonTable", source="json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.spark.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_window_functions(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + 
F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.row_number().over(w), + F.rank().over(w), + F.dense_rank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_without_partitionBy(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.row_number().over(w), + F.rank().over(w), + F.dense_rank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_cumulative_sum(self): + df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) + from pyspark.sql import functions as F + + # Test cumulative sum + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow + frame_end = Window.unboundedFollowing + 1 + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) + rs = sorted(sel.collect()) + expected = [("one", 3), ("two", 2)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_collect_functions(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + + def test_limit_and_take(self): + df = self.spark.range(1, 1000, numPartitions=10) + + def assert_runs_only_one_job_stage_and_task(job_group_name, f): + tracker = self.sc.statusTracker() + self.sc.setJobGroup(job_group_name, description="") + f() + jobs = tracker.getJobIdsForGroup(job_group_name) + self.assertEqual(1, len(jobs)) + stages = tracker.getJobInfo(jobs[0]).stageIds + self.assertEqual(1, len(stages)) + self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) + + # Regression test for SPARK-10731: take should 
delegate to Scala implementation
+        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
+        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+
+    def test_datetime_functions(self):
+        from pyspark.sql import functions
+        from datetime import date
+        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
+        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
+        self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])
+
+    @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking")
+    def test_unbounded_frames(self):
+        from unittest.mock import patch
+        from pyspark.sql import functions as F
+        from pyspark.sql import window
+        import importlib
+
+        df = self.spark.range(0, 3)
+
+        def rows_frame_match():
+            return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
+                F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize))
+            ).columns[0]
+
+        def range_frame_match():
+            return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
+                F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize))
+            ).columns[0]
+
+        with patch("sys.maxsize", 2 ** 31 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        with patch("sys.maxsize", 2 ** 63 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        with patch("sys.maxsize", 2 ** 127 - 1):
+            importlib.reload(window)
+            self.assertTrue(rows_frame_match())
+            self.assertTrue(range_frame_match())
+
+        importlib.reload(window)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_context import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
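For reference, a minimal sketch of the unbounded window frame that test_unbounded_frames above exercises, assuming a running SparkSession bound to the name `spark` (not part of the patch). Window.unboundedPreceding and Window.unboundedFollowing are the portable way to request UNBOUNDED frame bounds, rather than passing sys.maxsize directly:

    from pyspark.sql import Window, functions as F

    df = spark.range(0, 3)
    # Frame covering the whole partition, i.e. the
    # "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" frame the test asserts on.
    w = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    df.select(F.count("*").over(w).alias("n")).show()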