Repository: spark
Updated Branches:
  refs/heads/master f26cd1881 -> a7a331df6
http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_udf.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py new file mode 100644 index 0000000..630b215 --- /dev/null +++ b/python/pyspark/sql/tests/test_udf.py @@ -0,0 +1,654 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import pydoc +import shutil +import tempfile +import unittest + +from pyspark import SparkContext +from pyspark.sql import SparkSession, Column, Row +from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message +from pyspark.tests import QuietTest + + +class UDFTests(ReusedSQLTestCase): + + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.spark.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.spark.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf(self): + self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. 
+ sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + + def test_udf2(self): + with self.tempView("test"): + self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ + .createOrReplaceTempView("test") + [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf3(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) + + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. 
+ df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + + def test_chained_udf(self): + self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.spark.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.spark.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + + def test_single_udf_with_repeated_argument(self): + # regression test for SPARK-20685 + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + row = self.spark.sql("SELECT add(1, 1)").first() + self.assertEqual(tuple(row), (2, )) + + def test_multiple_udfs(self): + self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.spark.sql("SELECT double(1), double(2)").collect() + self.assertEqual(tuple(row), (2, 4)) + [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + self.assertEqual(tuple(row), (4, 12)) + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.assertEqual(tuple(row), (6, 5)) + + def test_udf_in_filter_on_top_of_outer_join(self): + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(a=1)]) + df = left.join(right, on='a', how='left_outer') + df = df.withColumn('b', udf(lambda x: 'x')(df.a)) + self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + + def test_udf_in_filter_on_top_of_join(self): + # regression test for SPARK-18589 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.crossJoin(right).filter(f("a", "b")) + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b")) + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b"), "leftsemi") + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + 
df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_and_common_filter_in_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1]) + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) + + def test_udf_and_common_filter_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_not_supported_in_join_condition(self): + # regression test for SPARK-25314 + # test python udf is not supported in join type besides left_semi and inner join. + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + + def runWithJoinType(join_type, type_string): + with self.assertRaisesRegexp( + AnalysisException, + 'Using PythonUDF.*%s is not supported.' 
% type_string): + left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() + runWithJoinType("full", "FullOuter") + runWithJoinType("left", "LeftOuter") + runWithJoinType("right", "RightOuter") + runWithJoinType("leftanti", "LeftAnti") + + def test_udf_without_arguments(self): + self.spark.catalog.registerFunction("foo", lambda: "bar") + [row] = self.spark.sql("SELECT foo()").collect() + self.assertEqual(row[0], "bar") + + def test_udf_with_array_type(self): + with self.tempView("test"): + d = [Row(l=list(range(3)), d={"key": list(range(5))})] + rdd = self.sc.parallelize(d) + self.spark.createDataFrame(rdd).createOrReplaceTempView("test") + self.spark.catalog.registerFunction( + "copylist", lambda l: list(l), ArrayType(IntegerType())) + self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(list(range(3)), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.spark.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.spark.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_udf_with_filter_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a < 2, BooleanType()) + sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) + self.assertEqual(sel.collect(), [Row(key=1, value='1')]) + + def test_udf_with_aggregate_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col, sum + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a == 1, BooleanType()) + sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) + self.assertEqual(sel.collect(), [Row(key=1)]) + + my_copy = udf(lambda x: x, IntegerType()) + my_add = udf(lambda a, b: int(a + b), IntegerType()) + my_strlen = udf(lambda x: len(x), IntegerType()) + sel = df.groupBy(my_copy(col("key")).alias("k"))\ + .agg(sum(my_strlen(col("value"))).alias("s"))\ + .select(my_add(col("k"), col("s")).alias("t")) + self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) + + def test_udf_in_generate(self): + from pyspark.sql.functions import udf, explode + df = self.spark.range(5) + f = udf(lambda x: list(range(x)), ArrayType(LongType())) + row = df.select(explode(f(*df))).groupBy().sum().first() + self.assertEqual(row[0], 10) + + df = self.spark.range(3) + res = df.select("id", explode(f(df.id))).collect() + self.assertEqual(res[0][0], 1) + self.assertEqual(res[0][1], 0) + self.assertEqual(res[1][0], 2) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 2) + self.assertEqual(res[2][1], 1) + + range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) + res = df.select("id", explode(range_udf(df.id))).collect() + self.assertEqual(res[0][0], 0) + self.assertEqual(res[0][1], -1) + self.assertEqual(res[1][0], 0) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 1) + self.assertEqual(res[2][1], 0) + self.assertEqual(res[3][0], 1) + self.assertEqual(res[3][1], 1) + + def 
test_udf_with_order_by_and_limit(self): + from pyspark.sql.functions import udf + my_copy = udf(lambda x: x, IntegerType()) + df = self.spark.range(10).orderBy("id") + res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) + res.explain(True) + self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + + def test_udf_with_input_file_name(self): + from pyspark.sql.functions import udf, input_file_name + sourceFile = udf(lambda path: path, StringType()) + filePath = "python/test_support/sql/people1.json" + row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() + self.assertTrue(row[0].find("people1.json") != -1) + + def test_udf_with_input_file_name_for_hadooprdd(self): + from pyspark.sql.functions import udf, input_file_name + + def filename(path): + return path + + sameText = udf(filename, StringType()) + + rdd = self.sc.textFile('python/test_support/sql/people.json') + df = self.spark.read.json(rdd).select(input_file_name().alias('file')) + row = df.select(sameText(df['file'])).first() + self.assertTrue(row[0].find("people.json") != -1) + + rdd2 = self.sc.newAPIHadoopFile( + 'python/test_support/sql/people.json', + 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat', + 'org.apache.hadoop.io.LongWritable', + 'org.apache.hadoop.io.Text') + + df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file')) + row2 = df2.select(sameText(df2['file'])).first() + self.assertTrue(row2[0].find("people.json") != -1) + + def test_udf_defers_judf_initialization(self): + # This is separate of UDFInitializationTests + # to avoid context initialization + # when udf is called + + from pyspark.sql.functions import UserDefinedFunction + + f = UserDefinedFunction(lambda x: x, StringType()) + + self.assertIsNone( + f._judf_placeholder, + "judf should not be initialized before the first call." + ) + + self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.") + + self.assertIsNotNone( + f._judf_placeholder, + "judf should be initialized after UDF has been called." 
+ ) + + def test_udf_with_string_return_type(self): + from pyspark.sql.functions import UserDefinedFunction + + add_one = UserDefinedFunction(lambda x: x + 1, "integer") + make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>") + make_array = UserDefinedFunction( + lambda x: [float(x) for x in range(x, x + 3)], "array<double>") + + expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) + actual = (self.spark.range(1, 2).toDF("x") + .select(add_one("x"), make_pair("x"), make_array("x")) + .first()) + + self.assertTupleEqual(expected, actual) + + def test_udf_shouldnt_accept_noncallable_object(self): + from pyspark.sql.functions import UserDefinedFunction + + non_callable = None + self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) + + def test_udf_with_decorator(self): + from pyspark.sql.functions import lit, udf + from pyspark.sql.types import IntegerType, DoubleType + + @udf(IntegerType()) + def add_one(x): + if x is not None: + return x + 1 + + @udf(returnType=DoubleType()) + def add_two(x): + if x is not None: + return float(x + 2) + + @udf + def to_upper(x): + if x is not None: + return x.upper() + + @udf() + def to_lower(x): + if x is not None: + return x.lower() + + @udf + def substr(x, start, end): + if x is not None: + return x[start:end] + + @udf("long") + def trunc(x): + return int(x) + + @udf(returnType="double") + def as_double(x): + return float(x) + + df = ( + self.spark + .createDataFrame( + [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float")) + .select( + add_one("one"), add_two("one"), + to_upper("Foo"), to_lower("Foo"), + substr("foobar", lit(0), lit(3)), + trunc("float"), as_double("one"))) + + self.assertListEqual( + [tpe for _, tpe in df.dtypes], + ["int", "double", "string", "string", "string", "bigint", "double"] + ) + + self.assertListEqual( + list(df.first()), + [2, 3.0, "FOO", "foo", "foo", 3, 1.0] + ) + + def test_udf_wrapper(self): + from pyspark.sql.functions import udf + from pyspark.sql.types import IntegerType + + def f(x): + """Identity""" + return x + + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + class F(object): + """Identity""" + def __call__(self, x): + return x + + f = F() + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + f = functools.partial(f, x=1) + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + + # SPARK-24721 + @unittest.skipIf(not test_compiled, test_not_compiled_message) + def test_datasource_with_udf(self): + from pyspark.sql.functions import udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 
1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) + + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + + # SPARK-25591 + def test_same_accumulator_in_udfs(self): + from pyspark.sql.functions import udf + + data_schema = StructType([StructField("a", IntegerType(), True), + StructField("b", IntegerType(), True)]) + data = self.spark.createDataFrame([[1, 2]], schema=data_schema) + + test_accum = self.sc.accumulator(0) + + def first_udf(x): + test_accum.add(1) + return x + + def second_udf(x): + test_accum.add(100) + return x + + func_udf = udf(first_udf, IntegerType()) + func_udf2 = udf(second_udf, IntegerType()) + data = data.withColumn("out1", func_udf(data["a"])) + data = data.withColumn("out2", func_udf2(data["b"])) + data.collect() + self.assertEqual(test_accum.value, 101) + + +class UDFInitializationTests(unittest.TestCase): + def tearDown(self): + if SparkSession._instantiatedSession is not None: + SparkSession._instantiatedSession.stop() + + if SparkContext._active_spark_context is not None: + SparkContext._active_spark_context.stop() + + def test_udf_init_shouldnt_initialize_context(self): + from pyspark.sql.functions import UserDefinedFunction + + UserDefinedFunction(lambda x: x, StringType()) + + self.assertIsNone( + SparkContext._active_spark_context, + "SparkContext shouldn't be initialized when UserDefinedFunction is created." + ) + self.assertIsNone( + SparkSession._instantiatedSession, + "SparkSession shouldn't be initialized when UserDefinedFunction is created." + ) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_udf import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_utils.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py new file mode 100644 index 0000000..63a8614 --- /dev/null +++ b/python/pyspark/sql/tests/test_utils.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.functions import sha2
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UtilsTests(ReusedSQLTestCase):
+
+    def test_capture_analysis_exception(self):
+        self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+
+    def test_capture_parse_exception(self):
+        self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
+
+    def test_capture_illegalargument_exception(self):
+        self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
+                                lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+        df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
+        self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
+                                lambda: df.select(sha2(df.a, 1024)).collect())
+        try:
+            df.select(sha2(df.a, 1024)).collect()
+        except IllegalArgumentException as e:
+            self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
+            self.assertRegexpMatches(e.stackTrace,
+                                     "org.apache.spark.sql.functions")
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_utils import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/testing/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
new file mode 100644
index 0000000..12bdf0d
--- /dev/null
+++ b/python/pyspark/testing/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/testing/sqlutils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
new file mode 100644
index 0000000..3951776
--- /dev/null
+++ b/python/pyspark/testing/sqlutils.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+
+from pyspark.sql import SparkSession
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.tests import ReusedPySparkTestCase
+from pyspark.util import _exception_message
+
+
+pandas_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+except ImportError as e:
+    # If Pandas version requirement is not satisfied, skip related tests.
+    pandas_requirement_message = _exception_message(e)
+
+pyarrow_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pyarrow_version
+    require_minimum_pyarrow_version()
+except ImportError as e:
+    # If Arrow version requirement is not satisfied, skip related tests.
+    pyarrow_requirement_message = _exception_message(e)
+
+test_not_compiled_message = None
+try:
+    from pyspark.sql.utils import require_test_compiled
+    require_test_compiled()
+except Exception as e:
+    test_not_compiled_message = _exception_message(e)
+
+have_pandas = pandas_requirement_message is None
+have_pyarrow = pyarrow_requirement_message is None
+test_compiled = test_not_compiled_message is None
+
+
+class UTCOffsetTimezone(datetime.tzinfo):
+    """
+    Specifies timezone in UTC offset
+    """
+
+    def __init__(self, offset=0):
+        self.ZERO = datetime.timedelta(hours=offset)
+
+    def utcoffset(self, dt):
+        return self.ZERO
+
+    def dst(self, dt):
+        return self.ZERO
+
+
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.sql.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__) and \
+            other.x == self.x and other.y == self.y
+
+
+class PythonOnlyUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for PythonOnlyPoint.
+ """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ + + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + + @contextmanager + def database(self, *databases): + """ + A convenient context manager to test with some specific databases. This drops the given + databases if it exists and sets current database to "default" when it exits. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for db in databases: + self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) + self.spark.catalog.setCurrentDatabase("default") + + @contextmanager + def table(self, *tables): + """ + A convenient context manager to test with some specific tables. This drops the given tables + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for t in tables: + self.spark.sql("DROP TABLE IF EXISTS %s" % t) + + @contextmanager + def tempView(self, *views): + """ + A convenient context manager to test with some specific views. This drops the given views + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for v in views: + self.spark.catalog.dropTempView(v) + + @contextmanager + def function(self, *functions): + """ + A convenient context manager to test with some specific functions. This drops the given + functions if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." 
+
+        try:
+            yield
+        finally:
+            for f in functions:
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        super(ReusedSQLTestCase, cls).setUpClass()
+        cls.spark = SparkSession(cls.sc)
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        cls.df = cls.spark.createDataFrame(cls.testData)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ReusedSQLTestCase, cls).tearDownClass()
+        cls.spark.stop()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/run-tests.py
----------------------------------------------------------------------
diff --git a/python/run-tests.py b/python/run-tests.py
index ccbdfac..4430574 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -251,8 +251,9 @@ def main():
     for module in modules_to_test:
         if python_implementation not in module.blacklisted_python_implementations:
             for test_goal in module.python_test_goals:
-                if test_goal in ('pyspark.streaming.tests', 'pyspark.mllib.tests',
-                                 'pyspark.tests', 'pyspark.sql.tests'):
+                heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+                               'pyspark.tests', 'pyspark.sql.tests']
+                if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
                     priority = 0
                 else:
                     priority = 100

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 8bab7e1..7beac16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -45,7 +45,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def sqlType: DataType = ArrayType(DoubleType, false)
 
-  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+  override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointUDT"
 
   override def serialize(p: ExamplePoint): GenericArrayData = {
     val output = new Array[Any](2)