This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c99463d [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well c99463d is described below commit c99463d4cfd5c70a28fdf89414207955f60c4789 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Wed Mar 20 08:06:10 2019 +0900 [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/23882 to handle binary math/string functions. For instance, see the cases below: **Before:** ```python >>> from pyspark.sql.functions import lit, ascii >>> spark.range(1).select(lit('a').alias("value")).select(ascii("value")) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/.../spark/python/pyspark/sql/functions.py", line 51, in _ jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1286, in __call__ File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/protocol.py", line 332, in get_return_value py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.ascii. 
Trace: py4j.Py4JException: Method ascii([class java.lang.String]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339) at py4j.Gateway.invoke(Gateway.java:276) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:238) at java.lang.Thread.run(Thread.java:748) ``` ```python >>> from pyspark.sql.functions import atan2 >>> spark.range(1).select(atan2("id", "id")) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/.../spark/python/pyspark/sql/functions.py", line 78, in _ jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1), ValueError: could not convert string to float: id ``` **After:** ```python >>> from pyspark.sql.functions import lit, ascii >>> spark.range(1).select(lit('a').alias("value")).select(ascii("value")) DataFrame[ascii(value): int] ``` ```python >>> from pyspark.sql.functions import atan2 >>> spark.range(1).select(atan2("id", "id")) DataFrame[ATAN2(id, id): double] ``` Note that: - This PR causes a slight behaviour change for math functions. For instance, numbers as strings (e.g., `"1"`) were previously accepted as arguments of binary math functions and converted to floats. After this PR, such strings are interpreted as column names. - I intentionally didn't document this behaviour change, since we're heading towards Spark 3.0 and I don't think numbers as strings make much sense in math functions. - There is another exception to treating strings as column names: `when`, which takes strings as literal values, as shown below. This PR doesn't fix this ambiguity.
```python >>> spark.range(1).select(when(lit(True), col("id"))).show() ``` ``` +--------------------------+ |CASE WHEN true THEN id END| +--------------------------+ | 0| +--------------------------+ ``` ```python >>> spark.range(1).select(when(lit(True), "id")).show() ``` ``` +--------------------------+ |CASE WHEN true THEN id END| +--------------------------+ | id| +--------------------------+ ``` This PR also fixes as below: https://github.com/apache/spark/pull/23882 fixed it to: - Rename `_create_function` to `_create_name_function` - Define new `_create_function` to take strings as column names. This PR, I proposes to: - Revert `_create_name_function` name to `_create_function`. - Define new `_create_function_over_column` to take strings as column names. ## How was this patch tested? Some unit tests were added for binary math / string functions. Closes #24121 from HyukjinKwon/SPARK-26979. Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/functions.py | 79 +++++++++++++++++++----------- python/pyspark/sql/tests/test_functions.py | 14 +++++- 2 files changed, 64 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3ee485c..0326613 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -30,15 +30,22 @@ if sys.version >= '3': from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal +from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \ + _create_column_from_name from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf +# Note to 
developers: all of PySpark functions here take string as column names whenever possible. +# Namely, if columns are referred as arguments, they can be always both Column or string, +# even though there might be few exceptions for legacy or inevitable reasons. +# If you are fixing other language APIs together, also please note that Scala side is not the case +# since it requires to make every single overridden definition. -def _create_name_function(name, doc=""): - """ Create a function that takes a column name argument, by name""" + +def _create_function(name, doc=""): + """Create a PySpark function by its name""" def _(col): sc = SparkContext._active_spark_context jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) @@ -48,8 +55,11 @@ def _create_name_function(name, doc=""): return _ -def _create_function(name, doc=""): - """ Create a function that takes a Column object, by name""" +def _create_function_over_column(name, doc=""): + """Similar with `_create_function` but creates a PySpark function that takes a column + (as string as well). This is mainly for PySpark functions to take strings as + column names. + """ def _(col): sc = SparkContext._active_spark_context jc = getattr(sc._jvm.functions, name)(_to_java_column(col)) @@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""): """ Create a binary mathfunction by name""" def _(col1, col2): sc = SparkContext._active_spark_context - # users might write ints for simplicity. This would throw an error on the JVM side. - jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1), - col2._jc if isinstance(col2, Column) else float(col2)) + # For legacy reasons, the arguments here can be implicitly converted into floats, + # if they are not columns or strings. 
+ if isinstance(col1, Column): + arg1 = col1._jc + elif isinstance(col1, basestring): + arg1 = _create_column_from_name(col1) + else: + arg1 = float(col1) + + if isinstance(col2, Column): + arg2 = col2._jc + elif isinstance(col2, basestring): + arg2 = _create_column_from_name(col2) + else: + arg2 = float(col2) + + jc = getattr(sc._jvm.functions, name)(arg1, arg2) return Column(jc) _.__name__ = name _.__doc__ = doc @@ -96,8 +120,7 @@ _lit_doc = """ >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1) [Row(height=5, spark_user=True)] """ -_name_functions = { - # name functions take a column name as their argument +_functions = { 'lit': _lit_doc, 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', @@ -105,9 +128,7 @@ _name_functions = { 'desc': 'Returns a sort expression based on the descending order of the given column name.', } -_functions = { - 'upper': 'Converts a string expression to upper case.', - 'lower': 'Converts a string expression to upper case.', +_functions_over_column = { 'sqrt': 'Computes the square root of the specified float value.', 'abs': 'Computes the absolute value.', @@ -120,7 +141,7 @@ _functions = { 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', } -_functions_1_4 = { +_functions_1_4_over_column = { # unary math functions 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', @@ -155,7 +176,7 @@ _functions_1_4 = { 'bitwiseNOT': 'Computes bitwise not.', } -_name_functions_2_4 = { +_functions_2_4 = { 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + ' column name, and null values return before non-null values.', 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + @@ -186,7 +207,7 @@ 
_collect_set_doc = """ >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] """ -_functions_1_6 = { +_functions_1_6_over_column = { # unary math functions 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + ' the expression in a group.', @@ -203,7 +224,7 @@ _functions_1_6 = { 'collect_set': _collect_set_doc } -_functions_2_1 = { +_functions_2_1_over_column = { # unary math functions 'degrees': """ Converts an angle measured in radians to an approximately equivalent angle @@ -268,24 +289,24 @@ _window_functions = { _functions_deprecated = { } -for _name, _doc in _name_functions.items(): - globals()[_name] = since(1.3)(_create_name_function(_name, _doc)) for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) -for _name, _doc in _functions_1_4.items(): - globals()[_name] = since(1.4)(_create_function(_name, _doc)) +for _name, _doc in _functions_over_column.items(): + globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc)) +for _name, _doc in _functions_1_4_over_column.items(): + globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc)) for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) -for _name, _doc in _functions_1_6.items(): - globals()[_name] = since(1.6)(_create_function(_name, _doc)) -for _name, _doc in _functions_2_1.items(): - globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _doc in _functions_1_6_over_column.items(): + globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc)) +for _name, _doc in _functions_2_1_over_column.items(): + globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = 
_wrap_deprecated_function(globals()[_name], _message) -for _name, _doc in _name_functions_2_4.items(): - globals()[_name] = since(2.4)(_create_name_function(_name, _doc)) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc @@ -1450,6 +1471,8 @@ def hash(*cols): # ---------------------- String/Binary functions ------------------------------ _string_functions = { + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to lower case.', 'ascii': 'Computes the numeric value of the first character of the string column.', 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.', 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.', @@ -1460,7 +1483,7 @@ _string_functions = { for _name, _doc in _string_functions.items(): - globals()[_name] = since(1.5)(_create_function(_name, _doc)) + globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index fe66602..b777573 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -129,6 +129,12 @@ class FunctionsTests(ReusedSQLTestCase): df.select(functions.pow(df.a, 2.0)).collect()) assert_close([math.hypot(i, 2 * i) for i in range(10)], df.select(functions.hypot(df.a, df.b)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot("a", u"b")).collect()) + assert_close([math.hypot(i, 2) for i in range(10)], + df.select(functions.hypot("a", 2)).collect()) + assert_close([math.hypot(i, 2) for i in range(10)], + df.select(functions.hypot(df.a, 2)).collect()) def test_rand_functions(self): df = self.df @@ -151,7 +157,8 @@ class FunctionsTests(ReusedSQLTestCase): self.assertEqual(sorted(rndn1), sorted(rndn2)) def 
test_string_functions(self): - from pyspark.sql.functions import col, lit + from pyspark.sql import functions + from pyspark.sql.functions import col, lit, _string_functions df = self.spark.createDataFrame([['nick']], schema=['name']) self.assertRaisesRegexp( TypeError, @@ -162,6 +169,11 @@ class FunctionsTests(ReusedSQLTestCase): TypeError, lambda: df.select(col('name').substr(long(0), long(1)))) + for name in _string_functions.keys(): + self.assertEqual( + df.select(getattr(functions, name)("name")).first()[0], + df.select(getattr(functions, name)(col("name"))).first()[0]) + def test_array_contains_function(self): from pyspark.sql.functions import array_contains --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org