This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a9da92498f0 [SPARK-40538][CONNECT] Improve built-in function support for Python client a9da92498f0 is described below commit a9da92498f0968eab21590845abbf1987ee9f1cd Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Tue Oct 18 20:08:36 2022 +0900 [SPARK-40538][CONNECT] Improve built-in function support for Python client ### What changes were proposed in this pull request? This patch changes the way simple scalar built-in functions are resolved in the Python Spark Connect client. Previously, it was trying to manually load specific functions. With the changes in this patch, the trivial binary operators like `<`, `+`, ... are mapped to their name equivalents in Spark so that the dynamic function lookup works. In addition, it cleans up the Scala planner side to remove the now unnecessary code translating the trivial binary expressions into their equivalent functions. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT, E2E Closes #38270 from grundprinzip/spark-40538. 
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/sql/connect/dsl/package.scala | 38 ++++++++++ .../sql/connect/planner/SparkConnectPlanner.scala | 25 ++----- .../connect/planner/SparkConnectPlannerSuite.scala | 2 +- .../connect/planner/SparkConnectProtoSuite.scala | 28 ++++++++ python/pyspark/sql/connect/column.py | 80 +++++++++++----------- .../sql/tests/connect/test_connect_basic.py | 12 ++++ .../connect/test_connect_column_expressions.py | 29 ++++++++ .../sql/tests/connect/test_connect_plan_only.py | 2 +- 8 files changed, 156 insertions(+), 60 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 579f190156f..0c392130562 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -92,6 +92,44 @@ package object dsl { .build() } + /** + * Create an unresolved function from name parts. + * + * @param nameParts + * @param args + * @return + * Expression wrapping the unresolved function. + */ + def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = { + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addAllParts(nameParts.asJava) + .addAllArguments(args.asJava)) + .build() + } + + /** + * Creates an UnresolvedFunction from a single identifier. + * + * @param name + * @param args + * @return + * Expression wrapping the unresolved function. 
+ */ + def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = { + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addParts(name) + .addAllArguments(args.asJava)) + .build() + } + implicit def intToLiteral(i: Int): proto.Expression = proto.Expression .newBuilder() diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 61352c17a23..7ffce908221 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -197,10 +197,6 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { limitExpr = expressions.Literal(limit.getLimit, IntegerType)) } - private def lookupFunction(name: String, args: Seq[Expression]): Expression = { - UnresolvedFunction(Seq(name), args, isDistinct = false) - } - /** * Translates a scalar function from proto to the Catalyst expression. 
* @@ -211,21 +207,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { - val funName = fun.getPartsList.asScala.mkString(".") - funName match { - case "gt" => - assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.") - expressions.GreaterThan( - transformExpression(fun.getArguments(0)), - transformExpression(fun.getArguments(1))) - case "eq" => - assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.") - expressions.EqualTo( - transformExpression(fun.getArguments(0)), - transformExpression(fun.getArguments(1))) - case _ => - lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq) + if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) { + throw new IllegalArgumentException( + "Function identifier must be passed as sequence of name parts.") } + UnresolvedFunction( + fun.getPartsList.asScala.toSeq, + fun.getArgumentsList.asScala.map(transformExpression).toSeq, + isDistinct = false) } private def transformAlias(alias: proto.Expression.Alias): Expression = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 67518f3bdb1..74788ce5593 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -197,7 +197,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction( proto.Expression.UnresolvedFunction.newBuilder - .addAllParts(Seq("eq").asJava) + .addAllParts(Seq("==").asJava) .addArguments(unresolvedAttribute) 
.addArguments(unresolvedAttribute) .build()) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index b13e74c2125..ef4b358798e 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } + test("UnresolvedFunction resolution.") { + { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + assertThrows[IllegalArgumentException] { + transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr)))) + } + } + + val connectPlan = { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + transform( + connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr)))) + } + + assertThrows[UnsupportedOperationException] { + connectPlan.analyze + } + + val validPlan = { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr)))) + } + assert(validPlan.analyze != null) + } + test("Basic filter") { val connectPlan = { import org.apache.spark.sql.connect.dsl.expressions._ diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 55b8f176e05..b291c5fb211 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -26,11 +26,51 @@ if TYPE_CHECKING: import pyspark.sql.connect.proto as proto +def _bin_op( + name: str, doc: str = "binary function", reverse: bool = False +) 
-> Callable[["ColumnRef", Any], "Expression"]: + def _(self: "ColumnRef", other: Any) -> "Expression": + if isinstance(other, get_args(PrimitiveType)): + other = LiteralExpression(other) + if not reverse: + return ScalarFunctionExpression(name, self, other) + else: + return ScalarFunctionExpression(name, other, self) + + return _ + + class Expression(object): """ Expression base class. """ + __gt__ = _bin_op(">") + __lt__ = _bin_op("<") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __truediv__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+", reverse=True) + __rsub__ = _bin_op("-", reverse=True) + __rmul__ = _bin_op("*", reverse=True) + __rdiv__ = _bin_op("/", reverse=True) + __rtruediv__ = _bin_op("/", reverse=True) + __pow__ = _bin_op("pow") + __rpow__ = _bin_op("pow", reverse=True) + __ge__ = _bin_op(">=") + __le__ = _bin_op("<=") + + def __eq__(self, other: Any) -> "Expression": # type: ignore[override] + """Returns a binary expression with the current column as the left + side and the other expression as the right side. + """ + if isinstance(other, get_args(PrimitiveType)): + other = LiteralExpression(other) + return ScalarFunctionExpression("==", self, other) + def __init__(self) -> None: pass @@ -73,20 +113,6 @@ class LiteralExpression(Expression): return f"Literal({self._value})" -def _bin_op( - name: str, doc: str = "binary function", reverse: bool = False -) -> Callable[["ColumnRef", Any], Expression]: - def _(self: "ColumnRef", other: Any) -> Expression: - if isinstance(other, get_args(PrimitiveType)): - other = LiteralExpression(other) - if not reverse: - return ScalarFunctionExpression(name, self, other) - else: - return ScalarFunctionExpression(name, other, self) - - return _ - - class ColumnRef(Expression): """Represents a column reference. There is no guarantee that this column actually exists.
In the context of this project, we refer by its name and @@ -105,32 +131,6 @@ class ColumnRef(Expression): """Returns the qualified name of the column reference.""" return ".".join(self._parts) - __gt__ = _bin_op("gt") - __lt__ = _bin_op("lt") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __truediv__ = _bin_op("divide") - __mod__ = _bin_op("modulo") - __radd__ = _bin_op("plus", reverse=True) - __rsub__ = _bin_op("minus", reverse=True) - __rmul__ = _bin_op("multiply", reverse=True) - __rdiv__ = _bin_op("divide", reverse=True) - __rtruediv__ = _bin_op("divide", reverse=True) - __pow__ = _bin_op("pow") - __rpow__ = _bin_op("pow", reverse=True) - __ge__ = _bin_op("greterEquals") - __le__ = _bin_op("lessEquals") - - def __eq__(self, other: Any) -> Expression: # type: ignore[override] - """Returns a binary expression with the current column as the left - side and the other expression as the right side. - """ - if isinstance(other, get_args(PrimitiveType)): - other = LiteralExpression(other) - return ScalarFunctionExpression("eq", self, other) - def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression: """Returns the Proto representation of the expression.""" expr = proto.Expression() diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index adc6f38f997..95173de347e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -19,9 +19,12 @@ import uuid import unittest import tempfile +import pandas + from pyspark.sql import SparkSession, Row from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf +from pyspark.sql.connect.functions import lit from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import 
ReusedPySparkTestCase @@ -79,6 +82,15 @@ class SparkConnectTests(SparkConnectSQLTestCase): result = df.explain() self.assertGreater(len(result), 0) + def test_simple_binary_expressions(self): + """Test complex expression""" + df = self.connect.read.table(self.tbl_name) + pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas() + self.assertEqual(len(pd.index), 4) + + res = pandas.DataFrame(data={"id": [0, 30, 60, 90]}) + self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}") + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 2aa686bbc38..74f5343a9c1 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -16,6 +16,7 @@ # from pyspark.testing.connectutils import PlanOnlyTestFixture +from pyspark.sql.connect.proto import Expression as ProtoExpression import pyspark.sql.connect as c import pyspark.sql.connect.plan as p import pyspark.sql.connect.column as col @@ -51,6 +52,34 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): plan = fun.lit(10).to_plan(None) self.assertIs(plan.literal.i32, 10) + def test_column_expressions(self): + """Test a more complex combination of expressions and their translation into + the protobuf structure.""" + df = c.DataFrame.withPlan(p.Read("table")) + + expr = df.id % fun.lit(10) == fun.lit(10) + expr_plan = expr.to_plan(None) + self.assertIsNotNone(expr_plan.unresolved_function) + self.assertEqual(expr_plan.unresolved_function.parts[0], "==") + + lit_fun = expr_plan.unresolved_function.arguments[1] + self.assertIsInstance(lit_fun, ProtoExpression) + self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal) + self.assertEqual(lit_fun.literal.i32, 10) + + mod_fun = 
expr_plan.unresolved_function.arguments[0] + self.assertIsInstance(mod_fun, ProtoExpression) + self.assertIsInstance(mod_fun.unresolved_function, ProtoExpression.UnresolvedFunction) + self.assertEqual(len(mod_fun.unresolved_function.arguments), 2) + self.assertIsInstance(mod_fun.unresolved_function.arguments[0], ProtoExpression) + self.assertIsInstance( + mod_fun.unresolved_function.arguments[0].unresolved_attribute, + ProtoExpression.UnresolvedAttribute, + ) + self.assertEqual( + mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"] + ) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 8fb33beb367..03cedd56de5 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -40,7 +40,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction ) ) - self.assertEqual(plan.root.filter.condition.unresolved_function.parts, ["gt"]) + self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"]) self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2) def test_relation_alias(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org