Repository: spark
Updated Branches:
  refs/heads/master f26cd1881 -> a7a331df6


http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
new file mode 100644
index 0000000..630b215
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -0,0 +1,654 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+import pydoc
+import shutil
+import tempfile
+import unittest
+
+from pyspark import SparkContext
+from pyspark.sql import SparkSession, Column, Row
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import *
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
+from pyspark.tests import QuietTest
+
+
+class UDFTests(ReusedSQLTestCase):
+
+    def test_udf_with_callable(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.spark.createDataFrame(rdd)
+
+        class PlusFour:
+            def __call__(self, col):
+                if col is not None:
+                    return col + 4
+
+        call = PlusFour()
+        pudf = UserDefinedFunction(call, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+    def test_udf_with_partial_function(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.spark.createDataFrame(rdd)
+
+        def some_func(col, param):
+            if col is not None:
+                return col + param
+
+        pfunc = functools.partial(some_func, param=4)
+        pudf = UserDefinedFunction(pfunc, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+    def test_udf(self):
+        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + 
y, IntegerType())
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+        # This is to check if a deprecated 'SQLContext.registerFunction' can 
call its alias.
+        sqlContext = self.spark._wrapped
+        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+        self.assertEqual(row[0], 4)
+
+    def test_udf2(self):
+        with self.tempView("test"):
+            self.spark.catalog.registerFunction("strlen", lambda string: 
len(string), IntegerType())
+            self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
+                .createOrReplaceTempView("test")
+            [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) 
> 1").collect()
+            self.assertEqual(4, res[0])
+
+    def test_udf3(self):
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
+        self.assertEqual(two_args.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], u'5')
+
+    def test_udf_registration_return_type_none(self):
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, 
"integer"), None)
+        self.assertEqual(two_args.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_udf_registration_return_type_not_none(self):
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
+                self.spark.catalog.registerFunction(
+                    "f", UserDefinedFunction(lambda x, y: len(x) + y, 
StringType()), StringType())
+
+    def test_nondeterministic_udf(self):
+        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+        from pyspark.sql.functions import udf
+        import random
+        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+        self.assertEqual(udf_random_col.deterministic, False)
+        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+        self.assertEqual(row[0] + 10, row[1])
+
+    def test_nondeterministic_udf2(self):
+        import random
+        from pyspark.sql.functions import udf
+        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+        self.assertEqual(random_udf.deterministic, False)
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
+        self.assertEqual(random_udf1.deterministic, False)
+        [row] = self.spark.sql("SELECT randInt()").collect()
+        self.assertEqual(row[0], 6)
+        [row] = self.spark.range(1).select(random_udf1()).collect()
+        self.assertEqual(row[0], 6)
+        [row] = self.spark.range(1).select(random_udf()).collect()
+        self.assertEqual(row[0], 6)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+        pydoc.render_doc(random_udf)
+        pydoc.render_doc(random_udf1)
+        pydoc.render_doc(udf(lambda x: x).asNondeterministic)
+
+    def test_nondeterministic_udf3(self):
+        # regression test for SPARK-23233
+        from pyspark.sql.functions import udf
+        f = udf(lambda x: x)
+        # Here we cache the JVM UDF instance.
+        self.spark.range(1).select(f("id"))
+        # This should reset the cache to set the deterministic status correctly.
+        f = f.asNondeterministic()
+        # Check the deterministic status of udf.
+        df = self.spark.range(1).select(f("id"))
+        deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
+        self.assertFalse(deterministic)
+
+    def test_nondeterministic_udf_in_aggregate(self):
+        from pyspark.sql.functions import udf, sum
+        import random
+        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
+        df = self.spark.range(10)
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+                df.groupby('id').agg(sum(udf_random_col())).collect()
+            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+                df.agg(sum(udf_random_col())).collect()
+
+    def test_chained_udf(self):
+        self.spark.catalog.registerFunction("double", lambda x: x + x, 
IntegerType())
+        [row] = self.spark.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.spark.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
+    def test_multiple_udfs(self):
+        self.spark.catalog.registerFunction("double", lambda x: x * 2, 
IntegerType())
+        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
+        self.assertEqual(tuple(row), (2, 4))
+        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 
2)").collect()
+        self.assertEqual(tuple(row), (4, 12))
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
+        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 
1)").collect()
+        self.assertEqual(tuple(row), (6, 5))
+
+    def test_udf_in_filter_on_top_of_outer_join(self):
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(a=1)])
+        df = left.join(right, on='a', how='left_outer')
+        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
+        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
+
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both udf and common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+    def test_udf_and_common_filter_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both udf and common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_not_supported_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test that Python UDFs are not supported in join types other than left_semi and inner join.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+
+        def runWithJoinType(join_type, type_string):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'Using PythonUDF.*%s is not supported.' % type_string):
+                left.join(right, [f("a", "b"), left.a1 == right.b1], 
join_type).collect()
+        runWithJoinType("full", "FullOuter")
+        runWithJoinType("left", "LeftOuter")
+        runWithJoinType("right", "RightOuter")
+        runWithJoinType("leftanti", "LeftAnti")
+
+    def test_udf_without_arguments(self):
+        self.spark.catalog.registerFunction("foo", lambda: "bar")
+        [row] = self.spark.sql("SELECT foo()").collect()
+        self.assertEqual(row[0], "bar")
+
+    def test_udf_with_array_type(self):
+        with self.tempView("test"):
+            d = [Row(l=list(range(3)), d={"key": list(range(5))})]
+            rdd = self.sc.parallelize(d)
+            self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
+            self.spark.catalog.registerFunction(
+                "copylist", lambda l: list(l), ArrayType(IntegerType()))
+            self.spark.catalog.registerFunction("maplen", lambda d: len(d), 
IntegerType())
+            [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from 
test").collect()
+            self.assertEqual(list(range(3)), l1)
+            self.assertEqual(1, l2)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if 
x else '')
+        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.spark.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+    def test_udf_with_filter_function(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a < 2, BooleanType())
+        sel = df.select(col("key"), 
col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
+        self.assertEqual(sel.collect(), [Row(key=1, value='1')])
+
+    def test_udf_with_aggregate_function(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
+
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
+    def test_udf_in_generate(self):
+        from pyspark.sql.functions import udf, explode
+        df = self.spark.range(5)
+        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
+        row = df.select(explode(f(*df))).groupBy().sum().first()
+        self.assertEqual(row[0], 10)
+
+        df = self.spark.range(3)
+        res = df.select("id", explode(f(df.id))).collect()
+        self.assertEqual(res[0][0], 1)
+        self.assertEqual(res[0][1], 0)
+        self.assertEqual(res[1][0], 2)
+        self.assertEqual(res[1][1], 0)
+        self.assertEqual(res[2][0], 2)
+        self.assertEqual(res[2][1], 1)
+
+        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
+        res = df.select("id", explode(range_udf(df.id))).collect()
+        self.assertEqual(res[0][0], 0)
+        self.assertEqual(res[0][1], -1)
+        self.assertEqual(res[1][0], 0)
+        self.assertEqual(res[1][1], 0)
+        self.assertEqual(res[2][0], 1)
+        self.assertEqual(res[2][1], 0)
+        self.assertEqual(res[3][0], 1)
+        self.assertEqual(res[3][1], 1)
+
+    def test_udf_with_order_by_and_limit(self):
+        from pyspark.sql.functions import udf
+        my_copy = udf(lambda x: x, IntegerType())
+        df = self.spark.range(10).orderBy("id")
+        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+        res.explain(True)
+        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, 
IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect()
+        )
+
+        # This is to check if a 'SQLContext.udf' can call its alias.
+        sqlContext = self.spark._wrapped
+        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, 
IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_four(id) AS plus_four").collect(),
+            df.select(add_four("id").alias("plus_four")).collect()
+        )
+
+    def test_non_existed_udf(self):
+        spark = self.spark
+        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udf",
+                                lambda: spark.udf.registerJavaFunction("udf1", 
"non_existed_udf"))
+
+        # This is to check if a deprecated 'SQLContext.registerJavaFunction' 
can call its alias.
+        sqlContext = spark._wrapped
+        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udf",
+                                lambda: 
sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
+
+    def test_non_existed_udaf(self):
+        spark = self.spark
+        self.assertRaisesRegexp(AnalysisException, "Can not load class 
non_existed_udaf",
+                                lambda: spark.udf.registerJavaUDAF("udaf1", 
"non_existed_udaf"))
+
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertTrue(row[0].find("people1.json") != -1)
+
+    def test_udf_with_input_file_name_for_hadooprdd(self):
+        from pyspark.sql.functions import udf, input_file_name
+
+        def filename(path):
+            return path
+
+        sameText = udf(filename, StringType())
+
+        rdd = self.sc.textFile('python/test_support/sql/people.json')
+        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+        row = df.select(sameText(df['file'])).first()
+        self.assertTrue(row[0].find("people.json") != -1)
+
+        rdd2 = self.sc.newAPIHadoopFile(
+            'python/test_support/sql/people.json',
+            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+            'org.apache.hadoop.io.LongWritable',
+            'org.apache.hadoop.io.Text')
+
+        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+        row2 = df2.select(sameText(df2['file'])).first()
+        self.assertTrue(row2[0].find("people.json") != -1)
+
+    def test_udf_defers_judf_initialization(self):
+        # This is kept separate from UDFInitializationTests
+        # to avoid context initialization
+        # when the UDF is called.
+
+        from pyspark.sql.functions import UserDefinedFunction
+
+        f = UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            f._judf_placeholder,
+            "judf should not be initialized before the first call."
+        )
+
+        self.assertIsInstance(f("foo"), Column, "UDF call should return a 
Column.")
+
+        self.assertIsNotNone(
+            f._judf_placeholder,
+            "judf should be initialized after UDF has been called."
+        )
+
+    def test_udf_with_string_return_type(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
+        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
+        make_array = UserDefinedFunction(
+            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
+
+        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
+        actual = (self.spark.range(1, 2).toDF("x")
+                  .select(add_one("x"), make_pair("x"), make_array("x"))
+                  .first())
+
+        self.assertTupleEqual(expected, actual)
+
+    def test_udf_shouldnt_accept_noncallable_object(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        non_callable = None
+        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
+
+    def test_udf_with_decorator(self):
+        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.types import IntegerType, DoubleType
+
+        @udf(IntegerType())
+        def add_one(x):
+            if x is not None:
+                return x + 1
+
+        @udf(returnType=DoubleType())
+        def add_two(x):
+            if x is not None:
+                return float(x + 2)
+
+        @udf
+        def to_upper(x):
+            if x is not None:
+                return x.upper()
+
+        @udf()
+        def to_lower(x):
+            if x is not None:
+                return x.lower()
+
+        @udf
+        def substr(x, start, end):
+            if x is not None:
+                return x[start:end]
+
+        @udf("long")
+        def trunc(x):
+            return int(x)
+
+        @udf(returnType="double")
+        def as_double(x):
+            return float(x)
+
+        df = (
+            self.spark
+                .createDataFrame(
+                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", 
"float"))
+                .select(
+                    add_one("one"), add_two("one"),
+                    to_upper("Foo"), to_lower("Foo"),
+                    substr("foobar", lit(0), lit(3)),
+                    trunc("float"), as_double("one")))
+
+        self.assertListEqual(
+            [tpe for _, tpe in df.dtypes],
+            ["int", "double", "string", "string", "string", "bigint", "double"]
+        )
+
+        self.assertListEqual(
+            list(df.first()),
+            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+        )
+
+    def test_udf_wrapper(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import IntegerType
+
+        def f(x):
+            """Identity"""
+            return x
+
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
+        class F(object):
+            """Identity"""
+            def __call__(self, x):
+                return x
+
+        f = F()
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
+        f = functools.partial(f, x=1)
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
+    def test_nonparam_udf_with_aggregate(self):
+        import pyspark.sql.functions as f
+
+        df = self.spark.createDataFrame([(1, 2), (1, 2)])
+        f_udf = f.udf(lambda: "const_str")
+        rows = df.distinct().withColumn("a", f_udf()).collect()
+        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
+
+    # SPARK-24721
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_datasource_with_udf(self):
+        from pyspark.sql.functions import udf, lit, col
+
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
+            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
+            datasource_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.SimpleScanSource") \
+                .option('from', 0).option('to', 1).load().toDF('i')
+            datasource_v2_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load().toDF('i', 'j')
+
+            c1 = udf(lambda x: x + 1, 'int')(lit(1))
+            c2 = udf(lambda x: x + 1, 'int')(col('i'))
+
+            f1 = udf(lambda x: False, 'boolean')(lit(1))
+            f2 = udf(lambda x: False, 'boolean')(col('i'))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c1)
+                expected = df.withColumn('c', lit(2))
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c2)
+                expected = df.withColumn('c', col('i') + 1)
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [f1, f2]:
+                    result = df.filter(f)
+                    self.assertEquals(0, result.count())
+        finally:
+            shutil.rmtree(path)
+
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
+
+class UDFInitializationTests(unittest.TestCase):
+    def tearDown(self):
+        if SparkSession._instantiatedSession is not None:
+            SparkSession._instantiatedSession.stop()
+
+        if SparkContext._active_spark_context is not None:
+            SparkContext._active_spark_context.stop()
+
+    def test_udf_init_shouldnt_initialize_context(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            SparkContext._active_spark_context,
+            "SparkContext shouldn't be initialized when UserDefinedFunction is 
created."
+        )
+        self.assertIsNone(
+            SparkSession._instantiatedSession,
+            "SparkSession shouldn't be initialized when UserDefinedFunction is 
created."
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_udf import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
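
As a side note for readers of this new module: the UDF registration and nondeterminism APIs that
these tests exercise can be driven roughly as in the sketch below. This is an illustrative example
only, not part of the patch; the local session setup, names, and values are made up for demonstration.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType
    import random

    # Illustrative local session; not part of the patch.
    spark = SparkSession.builder.master("local[2]").appName("udf-sketch").getOrCreate()

    # A deterministic UDF with an explicit return type.
    plus_one = udf(lambda x: x + 1, IntegerType())

    # Registering the UDF for SQL use; the returned callable also works in the DataFrame API.
    plus_one_sql = spark.udf.register("plus_one", plus_one)
    spark.sql("SELECT plus_one(41)").show()
    spark.range(3).select(plus_one_sql("id")).show()

    # asNondeterministic() tells the optimizer the result may differ per call,
    # which is what the test_nondeterministic_udf* cases above assert.
    rand_udf = udf(lambda: random.randint(0, 99), IntegerType()).asNondeterministic()
    spark.range(3).select(rand_udf().alias("r")).show()

    spark.stop()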

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/sql/tests/test_utils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
new file mode 100644
index 0000000..63a8614
--- /dev/null
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.functions import sha2
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UtilsTests(ReusedSQLTestCase):
+
+    def test_capture_analysis_exception(self):
+        self.assertRaises(AnalysisException, lambda: self.spark.sql("select 
abc"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + 
b"))
+
+    def test_capture_parse_exception(self):
+        self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
+
+    def test_capture_illegalargument_exception(self):
+        self.assertRaisesRegexp(IllegalArgumentException, "Setting negative 
mapred.reduce.tasks",
+                                lambda: self.spark.sql("SET 
mapred.reduce.tasks=-1"))
+        df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
+        self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the 
permitted values",
+                                lambda: df.select(sha2(df.a, 1024)).collect())
+        try:
+            df.select(sha2(df.a, 1024)).collect()
+        except IllegalArgumentException as e:
+            self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
+            self.assertRegexpMatches(e.stackTrace,
+                                     "org.apache.spark.sql.functions")
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_utils import *
+
+    try:
+        import xmlrunner
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    except ImportError:
+        unittest.main(verbosity=2)
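
For context, the exception wrappers covered by these tests can be caught like ordinary Python
exceptions and expose the captured JVM details. A minimal sketch, assuming an already-created
SparkSession named `spark` (not part of the patch):

    from pyspark.sql.utils import AnalysisException, ParseException

    try:
        spark.sql("select nonexistent_column")   # analysis fails: unresolved column
    except AnalysisException as e:
        # 'desc' holds the JVM error message and 'stackTrace' the JVM stack trace,
        # as asserted in test_capture_illegalargument_exception above.
        print(e.desc)

    try:
        spark.sql("this is not SQL")             # parsing fails
    except ParseException as e:
        print(e.desc)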

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/testing/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
new file mode 100644
index 0000000..12bdf0d
--- /dev/null
+++ b/python/pyspark/testing/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/pyspark/testing/sqlutils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
new file mode 100644
index 0000000..3951776
--- /dev/null
+++ b/python/pyspark/testing/sqlutils.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+
+from pyspark.sql import SparkSession
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.tests import ReusedPySparkTestCase
+from pyspark.util import _exception_message
+
+
+pandas_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+except ImportError as e:
+    # If Pandas version requirement is not satisfied, skip related tests.
+    pandas_requirement_message = _exception_message(e)
+
+pyarrow_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pyarrow_version
+    require_minimum_pyarrow_version()
+except ImportError as e:
+    # If Arrow version requirement is not satisfied, skip related tests.
+    pyarrow_requirement_message = _exception_message(e)
+
+test_not_compiled_message = None
+try:
+    from pyspark.sql.utils import require_test_compiled
+    require_test_compiled()
+except Exception as e:
+    test_not_compiled_message = _exception_message(e)
+
+have_pandas = pandas_requirement_message is None
+have_pyarrow = pyarrow_requirement_message is None
+test_compiled = test_not_compiled_message is None
+
+
+class UTCOffsetTimezone(datetime.tzinfo):
+    """
+    Specifies timezone in UTC offset
+    """
+
+    def __init__(self, offset=0):
+        self.ZERO = datetime.timedelta(hours=offset)
+
+    def utcoffset(self, dt):
+        return self.ZERO
+
+    def dst(self, dt):
+        return self.ZERO
+
+
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.sql.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__) and \
+            other.x == self.x and other.y == self.y
+
+
+class PythonOnlyUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for PythonOnlyPoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return '__main__'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return PythonOnlyPoint(datum[0], datum[1])
+
+    @staticmethod
+    def foo():
+        pass
+
+    @property
+    def props(self):
+        return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+    """
+    An example class to demonstrate a UDT implemented in Python only.
+    """
+    __UDT__ = PythonOnlyUDT()
+
+
+class MyObject(object):
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this class has a 'spark' attribute holding a SparkSession.
+    It is usually mixed into 'ReusedSQLTestCase', but it can be used by any class that is
+    sure to provide a 'spark' attribute.
+    """
+
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test configuration-specific logic. It sets
+        each configuration `key` to `value` and restores the original values on exit.
+        """
+        assert isinstance(pairs, dict), "pairs should be a dictionary."
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        keys = pairs.keys()
+        new_values = pairs.values()
+        old_values = [self.spark.conf.get(key, None) for key in keys]
+        for key, new_value in zip(keys, new_values):
+            self.spark.conf.set(key, new_value)
+        try:
+            yield
+        finally:
+            for key, old_value in zip(keys, old_values):
+                if old_value is None:
+                    self.spark.conf.unset(key)
+                else:
+                    self.spark.conf.set(key, old_value)
+
+    @contextmanager
+    def database(self, *databases):
+        """
+        A convenient context manager to test with specific databases. It drops the given
+        databases if they exist and sets the current database to "default" on exit.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        try:
+            yield
+        finally:
+            for db in databases:
+                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
+            self.spark.catalog.setCurrentDatabase("default")
+
+    @contextmanager
+    def table(self, *tables):
+        """
+        A convenient context manager to test with specific tables. It drops the given
+        tables if they exist.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        try:
+            yield
+        finally:
+            for t in tables:
+                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
+
+    @contextmanager
+    def tempView(self, *views):
+        """
+        A convenient context manager to test with specific temporary views. It drops the
+        given views if they exist.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        try:
+            yield
+        finally:
+            for v in views:
+                self.spark.catalog.dropTempView(v)
+
+    @contextmanager
+    def function(self, *functions):
+        """
+        A convenient context manager to test with specific functions. It drops the given
+        functions if they exist.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        try:
+            yield
+        finally:
+            for f in functions:
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        super(ReusedSQLTestCase, cls).setUpClass()
+        cls.spark = SparkSession(cls.sc)
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        cls.df = cls.spark.createDataFrame(cls.testData)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ReusedSQLTestCase, cls).tearDownClass()
+        cls.spark.stop()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)
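
To illustrate how the helpers above are meant to be consumed, a test case would subclass
ReusedSQLTestCase and use the context managers it mixes in. A hypothetical minimal example
(the test and view names are made up; it is not part of the patch):

    import unittest

    from pyspark.sql import Row
    from pyspark.testing.sqlutils import ReusedSQLTestCase


    class ExampleSQLTest(ReusedSQLTestCase):

        def test_count_with_temp_view_and_conf(self):
            with self.tempView("people"):
                self.spark.createDataFrame([Row(name="a"), Row(name="b")]) \
                    .createOrReplaceTempView("people")
                # sql_conf restores (or unsets) the previous value on exit.
                with self.sql_conf({"spark.sql.shuffle.partitions": 4}):
                    count = self.spark.sql("SELECT count(*) FROM people").first()[0]
                    self.assertEqual(count, 2)


    if __name__ == "__main__":
        unittest.main(verbosity=2)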

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/python/run-tests.py
----------------------------------------------------------------------
diff --git a/python/run-tests.py b/python/run-tests.py
index ccbdfac..4430574 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -251,8 +251,9 @@ def main():
         for module in modules_to_test:
             if python_implementation not in module.blacklisted_python_implementations:
                 for test_goal in module.python_test_goals:
-                    if test_goal in ('pyspark.streaming.tests', 'pyspark.mllib.tests',
-                                     'pyspark.tests', 'pyspark.sql.tests'):
+                    heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+                                   'pyspark.tests', 'pyspark.sql.tests']
+                    if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
                         priority = 0
                     else:
                         priority = 100
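
The intent of the change above is that the newly split heavy modules (for example
pyspark.sql.tests.test_udf) keep their high scheduling priority, which the old exact-name check
would no longer give them. A standalone sketch of the matching logic, with illustrative goal names:

    heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
                   'pyspark.tests', 'pyspark.sql.tests']

    def priority(test_goal):
        # Prefix matching keeps split modules such as pyspark.sql.tests.test_udf
        # in the high-priority bucket.
        return 0 if any(test_goal.startswith(prefix) for prefix in heavy_tests) else 100

    print(priority('pyspark.sql.tests.test_udf'))   # 0
    print(priority('pyspark.sql.functions'))        # 100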

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a331df/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 8bab7e1..7beac16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -45,7 +45,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def sqlType: DataType = ArrayType(DoubleType, false)
 
-  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+  override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointUDT"
 
   override def serialize(p: ExamplePoint): GenericArrayData = {
     val output = new Array[Any](2)

