grundprinzip commented on code in PR #39585: URL: https://github.com/apache/spark/pull/39585#discussion_r1073203970
########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ##########
@@ -808,6 +808,59 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  /**
+   * Translates a user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   Expression.
+   */
+  private def transformScalarInlineUserDefinedFunction(
+      fun: proto.Expression.ScalarInlineUserDefinedFunction): Expression = {
+    fun.getFunctionCase match {
+      case proto.Expression.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        transformPythonUDF(fun)

Review Comment:
   this needs a `case _ => throw` for the case where the `oneof` is unknown.

########## python/pyspark/sql/connect/functions.py: ##########
@@ -2350,8 +2356,21 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
 unwrap_udt.__doc__ = pysparkfuncs.unwrap_udt.__doc__
 
 
-def udf(*args: Any, **kwargs: Any) -> None:
-    raise NotImplementedError("udf() is not implemented.")
+def udf(
+    f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
+    returnType: "DataTypeOrString" = StringType(),
+) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(
+            _create_udf, returnType=return_type, evalType=100  # PythonEvalType.SQL_BATCHED_UDF

Review Comment:
   I see that this class is defined in rdd.py. @HyukjinKwon can we create a PR to move this constant definition out of this file into a more generic place?
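   For illustration, the named constant could then replace the magic number along these lines (a sketch only; it assumes `PythonEvalType` stays importable from `pyspark.rdd` until it is moved somewhere more generic):
   ```
   from pyspark.rdd import PythonEvalType  # currently defined in rdd.py

   # The named constant behind the literal 100 used above.
   assert PythonEvalType.SQL_BATCHED_UDF == 100

   # The partial application would then read:
   # functools.partial(_create_udf, returnType=return_type,
   #                   evalType=PythonEvalType.SQL_BATCHED_UDF)
   ```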
########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ##########
@@ -808,6 +808,59 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  /**
+   * Translates a user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   Expression.
+   */
+  private def transformScalarInlineUserDefinedFunction(
+      fun: proto.Expression.ScalarInlineUserDefinedFunction): Expression = {
+    fun.getFunctionCase match {
+      case proto.Expression.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        transformPythonUDF(fun)
+    }
+  }
+
+  /**
+   * Translates a Python user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   PythonUDF.
+   */
+  private def transformPythonUDF(
+      fun: proto.Expression.ScalarInlineUserDefinedFunction): PythonUDF = {
+    val udf = fun.getPythonUDF
+    PythonUDF(
+      fun.getFunctionName,
+      transformPythonFunction(udf),
+      DataType.parseTypeWithFallback(
+        schema = udf.getOutputType,
+        parser = DataType.fromDDL,
+        fallbackParser = DataType.fromJson) match {
+        case s: DataType => s
+        case other => throw InvalidPlanInput(s"Invalid return type $other")
+      },
+      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      udf.getEvalType,
+      fun.getDeterministic)
+  }
+
+  private def transformPythonFunction(fun: proto.Expression.PythonUDF): SimplePythonFunction = {

Review Comment:
   I know this might be very nitty, but do you mind leaving a comment for when you're using default-constructed arguments? Something like:
   ```
   // Empty Environment variables.
   Maps.newHashMap(),
   // No imported Python libraries.
   Lists.newArrayList(),
   ```

########## python/pyspark/sql/connect/udf.py: ##########
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined function related classes and functions
+"""
+import functools
+from typing import Callable, Any, TYPE_CHECKING, Optional
+
+from pyspark import cloudpickle
+from pyspark.sql.connect.expressions import (
+    ColumnReference,
+    PythonUDF,
+    ScalarInlineUserDefinedFunction,
+)
+from pyspark.sql.connect.column import Column
+from pyspark.sql.types import DataType, StringType
+
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import (
+        ColumnOrName,
+        DataTypeOrString,
+        UserDefinedFunctionLike,
+    )
+    from pyspark.sql.types import StringType
+
+
+def _create_udf(
+    f: Callable[..., Any],
+    returnType: "DataTypeOrString",
+    evalType: int,
+    name: Optional[str] = None,
+    deterministic: bool = True,
+) -> "UserDefinedFunctionLike":
+    # Set the name of the UserDefinedFunction object to be the name of function f
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
+    )
+    return udf_obj._wrapped()
+
+
+class UserDefinedFunction:
+    """
+    User defined function in Python
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
+    to create this instance.
+ """ + + def __init__( + self, + func: Callable[..., Any], + returnType: "DataTypeOrString" = StringType(), + name: Optional[str] = None, + evalType: int = 100, + deterministic: bool = True, + ): + if not callable(func): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): " + "{0}".format(type(func)) + ) + + if not isinstance(returnType, (DataType, str)): + raise TypeError( + "Invalid return type: returnType should be DataType or str " + "but is {}".format(returnType) + ) + + if not isinstance(evalType, int): + raise TypeError( + "Invalid evaluation type: evalType should be an int but is {}".format(evalType) + ) + + self.func = func + self._returnType = returnType + self._name = name or ( + func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ + ) + self.evalType = evalType + self.deterministic = deterministic + + def __call__(self, *cols: "ColumnOrName") -> Column: + arg_cols = [ + col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols + ] + arg_exprs = [col._expr for col in arg_cols] + data_type_str = ( + self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType + ) + py_udf = PythonUDF( + output_type=data_type_str, + eval_type=self.evalType, + command=cloudpickle.dumps(self.func), + ) + return Column( + ScalarInlineUserDefinedFunction( + function_name=self._name, + deterministic=self.deterministic, + arguments=arg_exprs, + function=py_udf, + ) + ) + + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) + def _wrapped(self) -> "UserDefinedFunctionLike": + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__" and a != "__module__" + ) + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args: "ColumnOrName") -> Column: + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = ( + self.func.__module__ + if hasattr(self.func, "__module__") + else self.func.__class__.__module__ + ) + + wrapper.func = self.func # type: ignore[attr-defined] + wrapper.returnType = self._returnType # type: ignore[attr-defined] + wrapper.evalType = self.evalType # type: ignore[attr-defined] + wrapper.deterministic = self.deterministic # type: ignore[attr-defined] + wrapper.asNondeterministic = functools.wraps( # type: ignore[attr-defined] + self.asNondeterministic + )(lambda: self.asNondeterministic()._wrapped()) + wrapper._unwrapped = self # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + def asNondeterministic(self) -> "UserDefinedFunction": + """ + Updates UserDefinedFunction to nondeterministic. + + """ + # Here, we explicitly clean the cache to create a JVM UDF instance + # with 'deterministic' updated. See SPARK-23233. Review Comment: is this comment still true? -- This is an automated message from the Apache Git Service. 