http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
deleted file mode 100644
index ea02691..0000000
--- a/python/pyspark/sql/tests.py
+++ /dev/null
@@ -1,7079 +0,0 @@
-# -*- encoding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Unit tests for pyspark.sql; additional tests are implemented as doctests in
-individual modules.
-"""
-import os
-import sys
-import subprocess
-import pydoc
-import shutil
-import tempfile
-import threading
-import pickle
-import functools
-import time
-import datetime
-import array
-import ctypes
-import warnings
-import py4j
-from contextlib import contextmanager
-
-try:
-    import xmlrunner
-except ImportError:
-    xmlrunner = None
-
-if sys.version_info[:2] <= (2, 6):
-    try:
-        import unittest2 as unittest
-    except ImportError:
-        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
-        sys.exit(1)
-else:
-    import unittest
-
-from pyspark.util import _exception_message
-
-_pandas_requirement_message = None
-try:
-    from pyspark.sql.utils import require_minimum_pandas_version
-    require_minimum_pandas_version()
-except ImportError as e:
-    # If Pandas version requirement is not satisfied, skip related tests.
-    _pandas_requirement_message = _exception_message(e)
-
-_pyarrow_requirement_message = None
-try:
-    from pyspark.sql.utils import require_minimum_pyarrow_version
-    require_minimum_pyarrow_version()
-except ImportError as e:
-    # If Arrow version requirement is not satisfied, skip related tests.
-    _pyarrow_requirement_message = _exception_message(e)
-
-_test_not_compiled_message = None
-try:
-    from pyspark.sql.utils import require_test_compiled
-    require_test_compiled()
-except Exception as e:
-    _test_not_compiled_message = _exception_message(e)
-
-_have_pandas = _pandas_requirement_message is None
-_have_pyarrow = _pyarrow_requirement_message is None
-_test_compiled = _test_not_compiled_message is None
-
-from pyspark import SparkConf, SparkContext
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
-from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
-from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
-from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
-from pyspark.sql.types import _merge_type
-from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit
-from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
-
-
-class UTCOffsetTimezone(datetime.tzinfo):
-    """
-    Specifies timezone in UTC offset
-    """
-
-    def __init__(self, offset=0):
-        self.ZERO = datetime.timedelta(hours=offset)
-
-    def utcoffset(self, dt):
-        return self.ZERO
-
-    def dst(self, dt):
-        return self.ZERO
-
-
-class ExamplePointUDT(UserDefinedType):
-    """
-    User-defined type (UDT) for ExamplePoint.
-    """
-
-    @classmethod
-    def sqlType(self):
-        return ArrayType(DoubleType(), False)
-
-    @classmethod
-    def module(cls):
-        return 'pyspark.sql.tests'
-
-    @classmethod
-    def scalaUDT(cls):
-        return 'org.apache.spark.sql.test.ExamplePointUDT'
-
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
-    """
-    An example class to demonstrate UDT in Scala, Java, and Python.
-    """
-
-    __UDT__ = ExamplePointUDT()
-
-    def __init__(self, x, y):
-        self.x = x
-        self.y = y
-
-    def __repr__(self):
-        return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
-    def __str__(self):
-        return "(%s,%s)" % (self.x, self.y)
-
-    def __eq__(self, other):
-        return isinstance(other, self.__class__) and \
-            other.x == self.x and other.y == self.y
-
-
-class PythonOnlyUDT(UserDefinedType):
-    """
-    User-defined type (UDT) for ExamplePoint.
-    """
-
-    @classmethod
-    def sqlType(self):
-        return ArrayType(DoubleType(), False)
-
-    @classmethod
-    def module(cls):
-        return '__main__'
-
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return PythonOnlyPoint(datum[0], datum[1])
-
-    @staticmethod
-    def foo():
-        pass
-
-    @property
-    def props(self):
-        return {}
-
-
-class PythonOnlyPoint(ExamplePoint):
-    """
-    An example class to demonstrate UDT in only Python
-    """
-    __UDT__ = PythonOnlyUDT()
-
-
-class MyObject(object):
-    def __init__(self, key, value):
-        self.key = key
-        self.value = value
-
-
-class SQLTestUtils(object):
-    """
-    This util assumes the instance of this to have 'spark' attribute, having a spark session.
-    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure
-    the implementation of this class has 'spark' attribute.
-    """
-
-    @contextmanager
-    def sql_conf(self, pairs):
-        """
-        A convenient context manager to test some configuration specific logic. This sets
-        `value` to the configuration `key` and then restores it back when it exits.
- """ - assert isinstance(pairs, dict), "pairs should be a dictionary." - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - keys = pairs.keys() - new_values = pairs.values() - old_values = [self.spark.conf.get(key, None) for key in keys] - for key, new_value in zip(keys, new_values): - self.spark.conf.set(key, new_value) - try: - yield - finally: - for key, old_value in zip(keys, old_values): - if old_value is None: - self.spark.conf.unset(key) - else: - self.spark.conf.set(key, old_value) - - @contextmanager - def database(self, *databases): - """ - A convenient context manager to test with some specific databases. This drops the given - databases if exist and sets current database to "default" when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for db in databases: - self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) - self.spark.catalog.setCurrentDatabase("default") - - @contextmanager - def table(self, *tables): - """ - A convenient context manager to test with some specific tables. This drops the given tables - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for t in tables: - self.spark.sql("DROP TABLE IF EXISTS %s" % t) - - @contextmanager - def tempView(self, *views): - """ - A convenient context manager to test with some specific views. This drops the given views - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for v in views: - self.spark.catalog.dropTempView(v) - - @contextmanager - def function(self, *functions): - """ - A convenient context manager to test with some specific functions. This drops the given - functions if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." 
- - try: - yield - finally: - for f in functions: - self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) - - -class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): - @classmethod - def setUpClass(cls): - super(ReusedSQLTestCase, cls).setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - super(ReusedSQLTestCase, cls).tearDownClass() - cls.spark.stop() - - def assertPandasEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + - "\n\nResult:\n%s\n%s" % (result, result.dtypes)) - self.assertTrue(expected.equals(result), msg=msg) - - -class DataTypeTests(unittest.TestCase): - # regression test for SPARK-6055 - def test_data_type_eq(self): - lt = LongType() - lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEqual(lt, lt2) - - # regression test for SPARK-7978 - def test_decimal_type(self): - t1 = DecimalType() - t2 = DecimalType(10, 2) - self.assertTrue(t2 is not t1) - self.assertNotEqual(t1, t2) - t3 = DecimalType(8) - self.assertNotEqual(t2, t3) - - # regression test for SPARK-10392 - def test_datetype_equal_zero(self): - dt = DateType() - self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) - - # regression test for SPARK-17035 - def test_timestamp_microsecond(self): - tst = TimestampType() - self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999) - - def test_empty_row(self): - row = Row() - self.assertEqual(len(row), 0) - - def test_struct_field_type_name(self): - struct_field = StructField("a", IntegerType()) - self.assertRaises(TypeError, struct_field.typeName) - - def test_invalid_create_row(self): - row_class = Row("c1", "c2") - self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) - - -class SparkSessionBuilderTests(unittest.TestCase): - - def test_create_spark_context_first_then_spark_session(self): - sc = None - session = None - try: - conf = SparkConf().set("key1", "value1") - sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) - session = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session.conf.get("key1"), "value1") - self.assertEqual(session.conf.get("key2"), "value2") - self.assertEqual(session.sparkContext, sc) - - self.assertFalse(sc.getConf().contains("key2")) - self.assertEqual(sc.getConf().get("key1"), "value1") - finally: - if session is not None: - session.stop() - if sc is not None: - sc.stop() - - def test_another_spark_session(self): - session1 = None - session2 = None - try: - session1 = SparkSession.builder.config("key1", "value1").getOrCreate() - session2 = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session1.conf.get("key1"), "value1") - self.assertEqual(session2.conf.get("key1"), "value1") - self.assertEqual(session1.conf.get("key2"), "value2") - self.assertEqual(session2.conf.get("key2"), "value2") - self.assertEqual(session1.sparkContext, session2.sparkContext) - - self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") - self.assertFalse(session1.sparkContext.getConf().contains("key2")) - finally: - if session1 is not None: - session1.stop() - if session2 is not None: - session2.stop() - - -class SQLTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - cls.spark.catalog._reset() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = 
cls.spark.createDataFrame(cls.testData) - - @classmethod - def tearDownClass(cls): - ReusedSQLTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_sqlcontext_reuses_sparksession(self): - sqlContext1 = SQLContext(self.sc) - sqlContext2 = SQLContext(self.sc) - self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) - - def test_row_should_be_read_only(self): - row = Row(a=1, b=2) - self.assertEqual(1, row.a) - - def foo(): - row.a = 3 - self.assertRaises(Exception, foo) - - row2 = self.spark.range(10).first() - self.assertEqual(0, row2.id) - - def foo2(): - row2.id = 2 - self.assertRaises(Exception, foo2) - - def test_range(self): - self.assertEqual(self.spark.range(1, 1).count(), 0) - self.assertEqual(self.spark.range(1, 0, -1).count(), 1) - self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.spark.range(-2).count(), 0) - self.assertEqual(self.spark.range(3).count(), 3) - - def test_duplicated_column_names(self): - df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) - row = df.select('*').first() - self.assertEqual(1, row[0]) - self.assertEqual(2, row[1]) - self.assertEqual("Row(c=1, c=2)", str(row)) - # Cannot access columns - self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) - self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) - self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) - - def test_column_name_encoding(self): - """Ensure that created columns has `str` type consistently.""" - columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns - self.assertEqual(columns, ['name', 'age']) - self.assertTrue(isinstance(columns[0], str)) - self.assertTrue(isinstance(columns[1], str)) - - def test_explode(self): - from pyspark.sql.functions import explode, explode_outer, posexplode_outer - d = [ - Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), - Row(a=1, intlist=[], mapfield={}), - Row(a=1, intlist=None, mapfield=None), - ] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - result = data.select(explode(data.intlist).alias("a")).select("a").collect() - self.assertEqual(result[0][0], 1) - self.assertEqual(result[1][0], 2) - self.assertEqual(result[2][0], 3) - - result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() - self.assertEqual(result[0][0], "a") - self.assertEqual(result[0][1], "b") - - result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] - self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) - - result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] - self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) - - result = [x[0] for x in data.select(explode_outer("intlist")).collect()] - self.assertEqual(result, [1, 2, 3, None, None]) - - result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] - self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) - - def test_and_in_expression(self): - self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) - self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) - self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) - self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") - self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) - 
self.assertRaises(ValueError, lambda: not self.df.key == 1) - - def test_udf_with_callable(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - class PlusFour: - def __call__(self, col): - if col is not None: - return col + 4 - - call = PlusFour() - pudf = UserDefinedFunction(call, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf_with_partial_function(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - def some_func(col, param): - if col is not None: - return col + param - - pfunc = functools.partial(some_func, param=4) - pudf = UserDefinedFunction(pfunc, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf(self): - self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. - sqlContext = self.spark._wrapped - sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) - [row] = sqlContext.sql("SELECT oneArg('test')").collect() - self.assertEqual(row[0], 4) - - def test_udf2(self): - with self.tempView("test"): - self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ - .createOrReplaceTempView("test") - [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf3(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], u'5') - - def test_udf_registration_return_type_none(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf_registration_return_type_not_none(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, "Invalid returnType"): - self.spark.catalog.registerFunction( - "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) - - def test_nondeterministic_udf(self): - # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - self.assertEqual(udf_random_col.deterministic, False) - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - - def test_nondeterministic_udf2(self): - import random - from pyspark.sql.functions import udf - random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() - self.assertEqual(random_udf.deterministic, False) - 
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) - self.assertEqual(random_udf1.deterministic, False) - [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf()).collect() - self.assertEqual(row[0], 6) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) - pydoc.render_doc(random_udf) - pydoc.render_doc(random_udf1) - pydoc.render_doc(udf(lambda x: x).asNondeterministic) - - def test_nondeterministic_udf3(self): - # regression test for SPARK-23233 - from pyspark.sql.functions import udf - f = udf(lambda x: x) - # Here we cache the JVM UDF instance. - self.spark.range(1).select(f("id")) - # This should reset the cache to set the deterministic status correctly. - f = f.asNondeterministic() - # Check the deterministic status of udf. - df = self.spark.range(1).select(f("id")) - deterministic = df._jdf.logicalPlan().projectList().head().deterministic() - self.assertFalse(deterministic) - - def test_nondeterministic_udf_in_aggregate(self): - from pyspark.sql.functions import udf, sum - import random - udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() - df = self.spark.range(10) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.groupby('id').agg(sum(udf_random_col())).collect() - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.agg(sum(udf_random_col())).collect() - - def test_chained_udf(self): - self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) - [row] = self.spark.sql("SELECT double(1)").collect() - self.assertEqual(row[0], 2) - [row] = self.spark.sql("SELECT double(double(1))").collect() - self.assertEqual(row[0], 4) - [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() - self.assertEqual(row[0], 6) - - def test_single_udf_with_repeated_argument(self): - # regression test for SPARK-20685 - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - row = self.spark.sql("SELECT add(1, 1)").first() - self.assertEqual(tuple(row), (2, )) - - def test_multiple_udfs(self): - self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) - [row] = self.spark.sql("SELECT double(1), double(2)").collect() - self.assertEqual(tuple(row), (2, 4)) - [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() - self.assertEqual(tuple(row), (4, 12)) - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() - self.assertEqual(tuple(row), (6, 5)) - - def test_udf_in_filter_on_top_of_outer_join(self): - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(a=1)]) - df = left.join(right, on='a', how='left_outer') - df = df.withColumn('b', udf(lambda x: 'x')(df.a)) - self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) - - def test_udf_in_filter_on_top_of_join(self): - # regression test for SPARK-18589 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.crossJoin(right).filter(f("a", "b")) - 
self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b")) - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b"), "leftsemi") - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_and_common_filter_in_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1]) - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) - - def test_udf_and_common_filter_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_not_supported_in_join_condition(self): - # regression test for SPARK-25314 - # test python udf is not supported in join type besides left_semi and inner join. - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - - def runWithJoinType(join_type, type_string): - with self.assertRaisesRegexp( - AnalysisException, - 'Using PythonUDF.*%s is not supported.' 
% type_string): - left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() - runWithJoinType("full", "FullOuter") - runWithJoinType("left", "LeftOuter") - runWithJoinType("right", "RightOuter") - runWithJoinType("leftanti", "LeftAnti") - - def test_udf_without_arguments(self): - self.spark.catalog.registerFunction("foo", lambda: "bar") - [row] = self.spark.sql("SELECT foo()").collect() - self.assertEqual(row[0], "bar") - - def test_udf_with_array_type(self): - with self.tempView("test"): - d = [Row(l=list(range(3)), d={"key": list(range(5))})] - rdd = self.sc.parallelize(d) - self.spark.createDataFrame(rdd).createOrReplaceTempView("test") - self.spark.catalog.registerFunction( - "copylist", lambda l: list(l), ArrayType(IntegerType())) - self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(list(range(3)), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.spark.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.spark.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_udf_with_filter_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a < 2, BooleanType()) - sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) - self.assertEqual(sel.collect(), [Row(key=1, value='1')]) - - def test_udf_with_aggregate_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col, sum - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a == 1, BooleanType()) - sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) - self.assertEqual(sel.collect(), [Row(key=1)]) - - my_copy = udf(lambda x: x, IntegerType()) - my_add = udf(lambda a, b: int(a + b), IntegerType()) - my_strlen = udf(lambda x: len(x), IntegerType()) - sel = df.groupBy(my_copy(col("key")).alias("k"))\ - .agg(sum(my_strlen(col("value"))).alias("s"))\ - .select(my_add(col("k"), col("s")).alias("t")) - self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) - - def test_udf_in_generate(self): - from pyspark.sql.functions import udf, explode - df = self.spark.range(5) - f = udf(lambda x: list(range(x)), ArrayType(LongType())) - row = df.select(explode(f(*df))).groupBy().sum().first() - self.assertEqual(row[0], 10) - - df = self.spark.range(3) - res = df.select("id", explode(f(df.id))).collect() - self.assertEqual(res[0][0], 1) - self.assertEqual(res[0][1], 0) - self.assertEqual(res[1][0], 2) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 2) - self.assertEqual(res[2][1], 1) - - range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) - res = df.select("id", explode(range_udf(df.id))).collect() - self.assertEqual(res[0][0], 0) - self.assertEqual(res[0][1], -1) - self.assertEqual(res[1][0], 0) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 1) - self.assertEqual(res[2][1], 0) - self.assertEqual(res[3][0], 1) - self.assertEqual(res[3][1], 1) - - def 
test_udf_with_order_by_and_limit(self): - from pyspark.sql.functions import udf - my_copy = udf(lambda x: x, IntegerType()) - df = self.spark.range(10).orderBy("id") - res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) - res.explain(True) - self.assertEqual(res.collect(), [Row(id=0, copy=0)]) - - def test_udf_registration_returns_udf(self): - df = self.spark.range(10) - add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_three(id) AS plus_three").collect(), - df.select(add_three("id").alias("plus_three")).collect() - ) - - # This is to check if a 'SQLContext.udf' can call its alias. - sqlContext = self.spark._wrapped - add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_four(id) AS plus_four").collect(), - df.select(add_four("id").alias("plus_four")).collect() - ) - - def test_non_existed_udf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) - - # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. - sqlContext = spark._wrapped - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) - - def test_non_existed_udaf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", - lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - - def test_linesep_text(self): - df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") - expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), - Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), - Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), - Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df.write.text(tpath, lineSep="!") - expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), - Row(value=u'Tom!30!"My name is Tom"'), - Row(value=u'Hyukjin!25!"I am Hyukjin'), - Row(value=u''), Row(value=u'I love Spark!"'), - Row(value=u'!')] - readback = self.spark.read.text(tpath) - self.assertEqual(readback.collect(), expected) - finally: - shutil.rmtree(tpath) - - def test_multiline_json(self): - people1 = self.spark.read.json("python/test_support/sql/people.json") - people_array = self.spark.read.json("python/test_support/sql/people_array.json", - multiLine=True) - self.assertEqual(people1.collect(), people_array.collect()) - - def test_encoding_json(self): - people_array = self.spark.read\ - .json("python/test_support/sql/people_array_utf16le.json", - multiLine=True, encoding="UTF-16LE") - expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] - self.assertEqual(people_array.collect(), expected) - - def test_linesep_json(self): - df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") - expected = [Row(_corrupt_record=None, name=u'Michael'), - Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), - Row(_corrupt_record=u' "age":19}\n', name=None)] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df = self.spark.read.json("python/test_support/sql/people.json") - df.write.json(tpath, lineSep="!!") - readback = 
self.spark.read.json(tpath, lineSep="!!") - self.assertEqual(readback.collect(), df.collect()) - finally: - shutil.rmtree(tpath) - - def test_multiline_csv(self): - ages_newlines = self.spark.read.csv( - "python/test_support/sql/ages_newlines.csv", multiLine=True) - expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), - Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), - Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] - self.assertEqual(ages_newlines.collect(), expected) - - def test_ignorewhitespace_csv(self): - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( - tmpPath, - ignoreLeadingWhiteSpace=False, - ignoreTrailingWhiteSpace=False) - - expected = [Row(value=u' a,b , c ')] - readback = self.spark.read.text(tmpPath) - self.assertEqual(readback.collect(), expected) - shutil.rmtree(tmpPath) - - def test_read_multiple_orc_file(self): - df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", - "python/test_support/sql/orc_partitioned/b=1/c=1"]) - self.assertEqual(2, df.count()) - - def test_udf_with_input_file_name(self): - from pyspark.sql.functions import udf, input_file_name - sourceFile = udf(lambda path: path, StringType()) - filePath = "python/test_support/sql/people1.json" - row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() - self.assertTrue(row[0].find("people1.json") != -1) - - def test_udf_with_input_file_name_for_hadooprdd(self): - from pyspark.sql.functions import udf, input_file_name - - def filename(path): - return path - - sameText = udf(filename, StringType()) - - rdd = self.sc.textFile('python/test_support/sql/people.json') - df = self.spark.read.json(rdd).select(input_file_name().alias('file')) - row = df.select(sameText(df['file'])).first() - self.assertTrue(row[0].find("people.json") != -1) - - rdd2 = self.sc.newAPIHadoopFile( - 'python/test_support/sql/people.json', - 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat', - 'org.apache.hadoop.io.LongWritable', - 'org.apache.hadoop.io.Text') - - df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file')) - row2 = df2.select(sameText(df2['file'])).first() - self.assertTrue(row2[0].find("people.json") != -1) - - def test_udf_defers_judf_initialization(self): - # This is separate of UDFInitializationTests - # to avoid context initialization - # when udf is called - - from pyspark.sql.functions import UserDefinedFunction - - f = UserDefinedFunction(lambda x: x, StringType()) - - self.assertIsNone( - f._judf_placeholder, - "judf should not be initialized before the first call." - ) - - self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.") - - self.assertIsNotNone( - f._judf_placeholder, - "judf should be initialized after UDF has been called." 
- ) - - def test_udf_with_string_return_type(self): - from pyspark.sql.functions import UserDefinedFunction - - add_one = UserDefinedFunction(lambda x: x + 1, "integer") - make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>") - make_array = UserDefinedFunction( - lambda x: [float(x) for x in range(x, x + 3)], "array<double>") - - expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) - actual = (self.spark.range(1, 2).toDF("x") - .select(add_one("x"), make_pair("x"), make_array("x")) - .first()) - - self.assertTupleEqual(expected, actual) - - def test_udf_shouldnt_accept_noncallable_object(self): - from pyspark.sql.functions import UserDefinedFunction - - non_callable = None - self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) - - def test_udf_with_decorator(self): - from pyspark.sql.functions import lit, udf - from pyspark.sql.types import IntegerType, DoubleType - - @udf(IntegerType()) - def add_one(x): - if x is not None: - return x + 1 - - @udf(returnType=DoubleType()) - def add_two(x): - if x is not None: - return float(x + 2) - - @udf - def to_upper(x): - if x is not None: - return x.upper() - - @udf() - def to_lower(x): - if x is not None: - return x.lower() - - @udf - def substr(x, start, end): - if x is not None: - return x[start:end] - - @udf("long") - def trunc(x): - return int(x) - - @udf(returnType="double") - def as_double(x): - return float(x) - - df = ( - self.spark - .createDataFrame( - [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float")) - .select( - add_one("one"), add_two("one"), - to_upper("Foo"), to_lower("Foo"), - substr("foobar", lit(0), lit(3)), - trunc("float"), as_double("one"))) - - self.assertListEqual( - [tpe for _, tpe in df.dtypes], - ["int", "double", "string", "string", "string", "bigint", "double"] - ) - - self.assertListEqual( - list(df.first()), - [2, 3.0, "FOO", "foo", "foo", 3, 1.0] - ) - - def test_udf_wrapper(self): - from pyspark.sql.functions import udf - from pyspark.sql.types import IntegerType - - def f(x): - """Identity""" - return x - - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - class F(object): - """Identity""" - def __call__(self, x): - return x - - f = F() - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - f = functools.partial(f, x=1) - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - def test_validate_column_types(self): - from pyspark.sql.functions import udf, to_json - from pyspark.sql.column import _to_java_column - - self.assertTrue("Column" in _to_java_column("a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: _to_java_column(1)) - - class A(): - pass - - self.assertRaises(TypeError, lambda: _to_java_column(A())) - self.assertRaises(TypeError, lambda: _to_java_column([])) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: udf(lambda x: x)(None)) - self.assertRaises(TypeError, lambda: 
to_json(1)) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.spark.read.json(rdd) - df.count() - df.collect() - df.schema - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist(True) - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - with self.tempView("temp"): - df.createOrReplaceTempView("temp") - df = self.spark.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_infer_schema_to_local(self): - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - df = self.spark.createDataFrame(input) - df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_apply_schema_to_dict_and_rows(self): - schema = StructType().add("b", StringType()).add("a", IntegerType()) - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - for verify in [False, True]: - df = self.spark.createDataFrame(input, schema, verifySchema=verify) - df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(10, df3.count()) - input = [Row(a=x, b=str(x)) for x in range(10)] - df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) - self.assertEqual(10, df4.count()) - - def test_create_dataframe_schema_mismatch(self): - input = [Row(a=1)] - rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) - schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) - df = self.spark.createDataFrame(rdd, schema) - self.assertRaises(Exception, lambda: df.show()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.rdd.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.rdd.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.rdd.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}, s=None), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - self.assertEqual([], df.rdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test"): - df.createOrReplaceTempView("test") - result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.spark.createDataFrame(rdd, 
samplingRatio=1.0) - self.assertEqual(df.schema, df2.schema) - self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test2"): - df2.createOrReplaceTempView("test2") - result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_infer_schema_specification(self): - from decimal import Decimal - - class A(object): - def __init__(self): - self.a = 1 - - data = [ - True, - 1, - "a", - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - array.array("d", [1]), - [1], - (1, ), - {"a": 1}, - bytearray(1), - Decimal(1), - Row(a=1), - Row("a")(1), - A(), - ] - - df = self.spark.createDataFrame([data]) - actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) - expected = [ - 'boolean', - 'bigint', - 'string', - 'string', - 'date', - 'timestamp', - 'double', - 'array<double>', - 'array<bigint>', - 'struct<_1:bigint>', - 'map<string,bigint>', - 'binary', - 'decimal(38,18)', - 'struct<a:bigint>', - 'struct<a:bigint>', - 'struct<a:bigint>', - ] - self.assertEqual(actual, expected) - - actual = list(df.first()) - expected = [ - True, - 1, - 'a', - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - [1.0], - [1], - Row(_1=1), - {"a": 1}, - bytearray(b'\x00'), - Decimal('1.000000000000000000'), - Row(a=1), - Row(a=1), - Row(a=1), - ] - self.assertEqual(actual, expected) - - def test_infer_schema_not_enough_names(self): - df = self.spark.createDataFrame([["a", "b"]], ["col1"]) - self.assertEqual(df.columns, ['col1', '_2']) - - def test_infer_schema_fails(self): - with self.assertRaisesRegexp(TypeError, 'field a'): - self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), - schema=["a", "b"], samplingRatio=0.99) - - def test_infer_nested_schema(self): - NestedRow = Row("f1", "f2") - nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), - NestedRow([2, 3], {"row2": 2.0})]) - df = self.spark.createDataFrame(nestedRdd1) - self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) - - nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), - NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.spark.createDataFrame(nestedRdd2) - self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) - - from collections import namedtuple - CustomRow = namedtuple('CustomRow', 'field1 field2') - rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), - CustomRow(field1=2, field2="row2"), - CustomRow(field1=3, field2="row3")]) - df = self.spark.createDataFrame(rdd) - self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) - - def test_create_dataframe_from_dict_respects_schema(self): - df = self.spark.createDataFrame([{'a': 1}], ["b"]) - self.assertEqual(df.columns, ['b']) - - def test_create_dataframe_from_objects(self): - data = [MyObject(1, "1"), MyObject(2, "2")] - df = self.spark.createDataFrame(data) - self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) - self.assertEqual(df.first(), Row(key=1, value="1")) - - def test_select_null_literal(self): - df = self.spark.sql("select null as col") - self.assertEqual(Row(col=None), df.first()) - - def test_apply_schema(self): - from datetime import date, datetime - rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, - date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), - {"a": 1}, (2,), [1, 2, 3], None)]) - schema = StructType([ - 
StructField("byte1", ByteType(), False), - StructField("byte2", ByteType(), False), - StructField("short1", ShortType(), False), - StructField("short2", ShortType(), False), - StructField("int1", IntegerType(), False), - StructField("float1", FloatType(), False), - StructField("date1", DateType(), False), - StructField("time1", TimestampType(), False), - StructField("map1", MapType(StringType(), IntegerType(), False), False), - StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), - StructField("list1", ArrayType(ByteType(), False), False), - StructField("null1", DoubleType(), True)]) - df = self.spark.createDataFrame(rdd, schema) - results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, - x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) - r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), - datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - self.assertEqual(r, results.first()) - - with self.tempView("table2"): - df.createOrReplaceTempView("table2") - r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - "float1 + 1.5 as float1 FROM table2").first() - - self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - df = self.sc.parallelize(d).toDF() - k, v = list(df.head().m.items())[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - df = self.sc.parallelize([row]).toDF() - - with self.tempView("test"): - df.createOrReplaceTempView("test") - row = self.spark.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - - def check_datatype(datatype): - pickled = pickle.loads(pickle.dumps(datatype)) - assert datatype == pickled - scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json()) - python_datatype = _parse_datatype_json_string(scala_datatype.json()) - assert datatype == python_datatype - - check_datatype(ExamplePointUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - check_datatype(structtype_with_udt) - p = ExamplePoint(1.0, 2.0) - self.assertEqual(_infer_type(p), ExamplePointUDT()) - _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) - self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) - - check_datatype(PythonOnlyUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - check_datatype(structtype_with_udt) - p = PythonOnlyPoint(1.0, 2.0) - self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) - self.assertRaises( - ValueError, - lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) - - def test_simple_udt_in_df(self): - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() 
- - def test_nested_udt_in_df(self): - schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], - schema=schema) - df.collect() - - schema = StructType().add("key", LongType()).add("val", - MapType(LongType(), PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], - schema=schema) - df.collect() - - def test_complex_nested_udt_in_df(self): - from pyspark.sql.functions import udf - - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() - - gd = df.groupby("key").agg({"val": "collect_list"}) - gd.collect() - udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) - gd.select(udf(*gd)).collect() - - def test_udt_with_none(self): - df = self.spark.range(0, 10, 1, 1) - - def myudf(x): - if x > 0: - return PythonOnlyPoint(float(x), float(x)) - - self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) - rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] - self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) - - def test_nonparam_udf_with_aggregate(self): - import pyspark.sql.functions as f - - df = self.spark.createDataFrame([(1, 2), (1, 2)]) - f_udf = f.udf(lambda: "const_str") - rows = df.distinct().withColumn("a", f_udf()).collect() - self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) - - def test_infer_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), PythonOnlyUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = (1.0, PythonOnlyPoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_udf_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = 
UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) - self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) - self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.write.parquet(output_dir) - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - df0.write.parquet(output_dir, mode='overwrite') - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_union_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row1 = (1.0, ExamplePoint(1.0, 2.0)) - row2 = (2.0, ExamplePoint(3.0, 4.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df1 = self.spark.createDataFrame([row1], schema) - df2 = self.spark.createDataFrame([row2], schema) - - result = df1.union(df2).orderBy("label").collect() - self.assertEqual( - result, - [ - Row(label=1.0, point=ExamplePoint(1.0, 2.0)), - Row(label=2.0, point=ExamplePoint(3.0, 4.0)) - ] - ) - - def test_cast_to_string_with_udt(self): - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - from pyspark.sql.functions import col - row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) - schema = StructType([StructField("point", ExamplePointUDT(), False), - StructField("pypoint", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - - result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() - self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) - - def test_column_operators(self): - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbool = (ci & ci), (ci | ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - self.assertRaisesRegexp(ValueError, - "Cannot apply 'in' operator against a column", - lambda: 1 in cs) - - def test_column_getitem(self): - from pyspark.sql.functions import col - - self.assertIsInstance(col("foo")[1:3], Column) 
-        self.assertIsInstance(col("foo")[0], Column)
-        self.assertIsInstance(col("foo")["bar"], Column)
-        self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
-
-    def test_column_select(self):
-        df = self.df
-        self.assertEqual(self.testData, df.select("*").collect())
-        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
-        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
-
-    def test_freqItems(self):
-        vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
-        df = self.sc.parallelize(vals).toDF()
-        items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
-        self.assertTrue(1 in items[0])
-        self.assertTrue(-2.0 in items[1])
-
-    def test_aggregator(self):
-        df = self.df
-        g = df.groupBy()
-        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
-        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-
-        from pyspark.sql import functions
-        self.assertEqual((0, u'99'),
-                         tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
-        self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
-        self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
-
-    def test_first_last_ignorenulls(self):
-        from pyspark.sql import functions
-        df = self.spark.range(0, 100)
-        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
-        df3 = df2.select(functions.first(df2.id, False).alias('a'),
-                         functions.first(df2.id, True).alias('b'),
-                         functions.last(df2.id, False).alias('c'),
-                         functions.last(df2.id, True).alias('d'))
-        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
-
-    def test_approxQuantile(self):
-        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
-        for f in ["a", u"a"]:
-            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
-            self.assertTrue(isinstance(aq, list))
-            self.assertEqual(len(aq), 3)
-            self.assertTrue(all(isinstance(q, float) for q in aq))
-        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aqs, list))
-        self.assertEqual(len(aqs), 2)
-        self.assertTrue(isinstance(aqs[0], list))
-        self.assertEqual(len(aqs[0]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
-        self.assertTrue(isinstance(aqs[1], list))
-        self.assertEqual(len(aqs[1]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
-        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aqt, list))
-        self.assertEqual(len(aqt), 2)
-        self.assertTrue(isinstance(aqt[0], list))
-        self.assertEqual(len(aqt[0]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
-        self.assertTrue(isinstance(aqt[1], list))
-        self.assertEqual(len(aqt[1]), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
-        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
-
-    def test_corr(self):
-        import math
-        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
-        corr = df.stat.corr(u"a", "b")
-        self.assertTrue(abs(corr - 0.95734012) < 1e-6)
-
-    def test_sampleby(self):
-        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
-        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
-        self.assertTrue(sampled.count() == 3)
-
-    def test_cov(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        cov = df.stat.cov(u"a", "b")
-        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
-
-    def test_crosstab(self):
-        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
-        ct = df.stat.crosstab(u"a", "b").collect()
-        ct = sorted(ct, key=lambda x: x[0])
-        for i, row in enumerate(ct):
-            self.assertEqual(row[0], str(i))
-            self.assertTrue(row[1], 1)
-            self.assertTrue(row[2], 1)
-
-    def test_math_functions(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        from pyspark.sql import functions
-        import math
-
-        def get_values(l):
-            return [j[0] for j in l]
-
-        def assert_close(a, b):
-            c = get_values(b)
-            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
-            return sum(diff) == len(a)
-        assert_close([math.cos(i) for i in range(10)],
-                     df.select(functions.cos(df.a)).collect())
-        assert_close([math.cos(i) for i in range(10)],
-                     df.select(functions.cos("a")).collect())
-        assert_close([math.sin(i) for i in range(10)],
-                     df.select(functions.sin(df.a)).collect())
-        assert_close([math.sin(i) for i in range(10)],
-                     df.select(functions.sin(df['a'])).collect())
-        assert_close([math.pow(i, 2 * i) for i in range(10)],
-                     df.select(functions.pow(df.a, df.b)).collect())
-        assert_close([math.pow(i, 2) for i in range(10)],
-                     df.select(functions.pow(df.a, 2)).collect())
-        assert_close([math.pow(i, 2) for i in range(10)],
-                     df.select(functions.pow(df.a, 2.0)).collect())
-        assert_close([math.hypot(i, 2 * i) for i in range(10)],
-                     df.select(functions.hypot(df.a, df.b)).collect())
-
-    def test_rand_functions(self):
-        df = self.df
-        from pyspark.sql import functions
-        rnd = df.select('key', functions.rand()).collect()
-        for row in rnd:
-            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
-        rndn = df.select('key', functions.randn(5)).collect()
-        for row in rndn:
-            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
-
-        # If the specified seed is 0, we should use it.
-        # https://issues.apache.org/jira/browse/SPARK-9691
-        rnd1 = df.select('key', functions.rand(0)).collect()
-        rnd2 = df.select('key', functions.rand(0)).collect()
-        self.assertEqual(sorted(rnd1), sorted(rnd2))
-
-        rndn1 = df.select('key', functions.randn(0)).collect()
-        rndn2 = df.select('key', functions.randn(0)).collect()
-        self.assertEqual(sorted(rndn1), sorted(rndn2))
-
-    def test_string_functions(self):
-        from pyspark.sql.functions import col, lit
-        df = self.spark.createDataFrame([['nick']], schema=['name'])
-        self.assertRaisesRegexp(
-            TypeError,
-            "must be the same type",
-            lambda: df.select(col('name').substr(0, lit(1))))
-        if sys.version_info.major == 2:
-            self.assertRaises(
-                TypeError,
-                lambda: df.select(col('name').substr(long(0), long(1))))
-
-    def test_array_contains_function(self):
-        from pyspark.sql.functions import array_contains
-
-        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
-        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
-        self.assertEqual([Row(b=True), Row(b=False)], actual)
-
-    def test_between_function(self):
-        df = self.sc.parallelize([
-            Row(a=1, b=2, c=3),
-            Row(a=2, b=1, c=3),
-            Row(a=4, b=1, c=4)]).toDF()
-        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
-                         df.filter(df.a.between(df.b, df.c)).collect())
-
-    def test_struct_type(self):
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        struct2 = StructType([StructField("f1", StringType(), True),
-                              StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), struct2.names)
-        self.assertEqual(struct1, struct2)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), struct2.names)
-        self.assertNotEqual(struct1, struct2)
-
-        struct1 = (StructType().add(StructField("f1", StringType(), True))
-                   .add(StructField("f2", StringType(), True, None)))
-        struct2 = StructType([StructField("f1", StringType(), True),
-                              StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), struct2.names)
-        self.assertEqual(struct1, struct2)
-
-        struct1 = (StructType().add(StructField("f1", StringType(), True))
-                   .add(StructField("f2", StringType(), True, None)))
-        struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), struct2.names)
-        self.assertNotEqual(struct1, struct2)
-
-        # Catch exception raised during improper construction
-        self.assertRaises(ValueError, lambda: StructType().add("name"))
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        for field in struct1:
-            self.assertIsInstance(field, StructField)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        self.assertEqual(len(struct1), 2)
-
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
-        self.assertIs(struct1["f1"], struct1.fields[0])
-        self.assertIs(struct1[0], struct1.fields[0])
-        self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        self.assertRaises(KeyError, lambda: struct1["f9"])
-        self.assertRaises(IndexError, lambda: struct1[9])
-        self.assertRaises(TypeError, lambda: struct1[9.9])
-
-    def test_parse_datatype_string(self):
-        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
-        for k, t in _all_atomic_types.items():
-            if t != NullType:
-                self.assertEqual(t(), _parse_datatype_string(k))
-        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
-        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
-        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
-        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
-        self.assertEqual(
-            ArrayType(IntegerType()),
-            _parse_datatype_string("array<int >"))
-        self.assertEqual(
-            MapType(IntegerType(), DoubleType()),
-            _parse_datatype_string("map< int, double >"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("struct<a:int, c:double >"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("a:int, c:double"))
-        self.assertEqual(
-            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
-            _parse_datatype_string("a INT, c DOUBLE"))
-
-    def test_metadata_null(self):
-        schema = StructType([StructField("f1", StringType(), True, None),
-                             StructField("f2", StringType(), True, {'a': None})])
-        rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
-        self.spark.createDataFrame(rdd, schema)
-
-    def test_save_and_load(self):
-        df = self.df
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.json(tmpPath, "overwrite")
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.save(format="json", mode="overwrite", path=tmpPath,
-                      noUse="this options will not be used in save.")
-        actual = self.spark.read.load(format="json", path=tmpPath,
-                                      noUse="this options will not be used in load.")
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        csvpath = os.path.join(tempfile.mkdtemp(), 'data')
-        df.write.option('quote', None).format('csv').save(csvpath)
-
-        shutil.rmtree(tmpPath)
-
-    def test_save_and_load_builder(self):
-        df = self.df
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
-            .option("noUse", "this option will not be used in save.")\
-            .format("json").save(path=tmpPath)
-        actual =\
-            self.spark.read.format("json")\
-            .load(path=tmpPath, noUse="this options will not be used in load.")
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        shutil.rmtree(tmpPath)
-
-    def test_stream_trigger(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-
-        # Should take at least one arg
-        try:
-            df.writeStream.trigger()
-        except ValueError:
-            pass
-
-        # Should not take multiple args
-        try:
-            df.writeStream.trigger(once=True, processingTime='5 seconds')
-        except ValueError:
-            pass
-
-        # Should not take multiple args
-        try:
-            df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
-        except ValueError:
-            pass
-
-        # Should take only keyword args
-        try:
-            df.writeStream.trigger('5 seconds')
-            self.fail("Should have thrown an exception")
-        except TypeError:
-            pass
-
-    def test_stream_read_options(self):
-        schema = StructType([StructField("data", StringType(), False)])
-        df = self.spark.readStream\
-            .format('text')\
-            .option('path', 'python/test_support/sql/streaming')\
-            .schema(schema)\
-            .load()
-        self.assertTrue(df.isStreaming)
-        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
-    def test_stream_read_options_overwrite(self):
-        bad_schema = StructType([StructField("test", IntegerType(), False)])
-        schema = StructType([StructField("data", StringType(), False)])
-        df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
-            .schema(bad_schema)\
-            .load(path='python/test_support/sql/streaming', schema=schema, format='text')
-        self.assertTrue(df.isStreaming)
-        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
-    def test_stream_save_options(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
-            .withColumn('id', lit(1))
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
-            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
-        try:
-            self.assertEqual(q.name, 'this_query')
-            self.assertTrue(q.isActive)
-            q.processAllAvailable()
-            output_files = []
-            for _, _, files in os.walk(out):
-                output_files.extend([f for f in files if not f.startswith('.')])
-            self.assertTrue(len(output_files) > 0)
-            self.assertTrue(len(os.listdir(chk)) > 0)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_save_options_overwrite(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        fake1 = os.path.join(tmpPath, 'fake1')
-        fake2 = os.path.join(tmpPath, 'fake2')
-        q = df.writeStream.option('checkpointLocation', fake1)\
-            .format('memory').option('path', fake2) \
-            .queryName('fake_query').outputMode('append') \
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-
-        try:
-            self.assertEqual(q.name, 'this_query')
-            self.assertTrue(q.isActive)
-            q.processAllAvailable()
-            output_files = []
-            for _, _, files in os.walk(out):
-                output_files.extend([f for f in files if not f.startswith('.')])
-            self.assertTrue(len(output_files) > 0)
-            self.assertTrue(len(os.listdir(chk)) > 0)
-            self.assertFalse(os.path.isdir(fake1))  # should not have been created
-            self.assertFalse(os.path.isdir(fake2))  # should not have been created
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_status_and_progress(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-
-        def func(x):
-            time.sleep(1)
-            return x
-
-        from pyspark.sql.functions import col, udf
-        sleep_udf = udf(func)
-
-        # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
-        # were no updates.
-        q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-        try:
-            # "lastProgress" will return None in most cases. However, as it may be flaky when
-            # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
-            # may throw error with a high chance and make this test flaky, so we should still be
-            # able to detect broken codes.
-            q.lastProgress
-
-            q.processAllAvailable()
-            lastProgress = q.lastProgress
-            recentProgress = q.recentProgress
-            status = q.status
-            self.assertEqual(lastProgress['name'], q.name)
-            self.assertEqual(lastProgress['id'], q.id)
-            self.assertTrue(any(p == lastProgress for p in recentProgress))
-            self.assertTrue(
-                "message" in status and
-                "isDataAvailable" in status and
-                "isTriggerActive" in status)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_await_termination(self):
-        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        for q in self.spark._wrapped.streams.active:
-            q.stop()
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        self.assertTrue(df.isStreaming)
-        out = os.path.join(tmpPath, 'out')
-        chk = os.path.join(tmpPath, 'chk')
-        q = df.writeStream\
-            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-        try:
-            self.assertTrue(q.isActive)
-            try:
-                q.awaitTermination("hello")
-                self.fail("Expected a value exception")
-            except ValueError:
-                pass
-            now = time.time()
-            # test should take at least 2 seconds
-            res = q.awaitTermination(2.6)
-            duration = time.time() - now
-            self.assertTrue(duration >= 2)
-            self.assertFalse(res)
-        finally:
-            q.stop()
-            shutil.rmtree(tmpPath)
-
-    def test_stream_exception(self):
-        sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-        sq = sdf.writeStream.format('memory'
<TRUNCATED>