This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a9da92498f0 [SPARK-40538][CONNECT] Improve built-in function support 
for Python client
a9da92498f0 is described below

commit a9da92498f0968eab21590845abbf1987ee9f1cd
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Tue Oct 18 20:08:36 2022 +0900

    [SPARK-40538][CONNECT] Improve built-in function support for Python client
    
    ### What changes were proposed in this pull request?
    This patch changes the way simple scalar built-in functions are resolved in 
the Python Spark Connect client. Previously, it was trying to manually load 
specific functions. With the changes in this patch, the trivial binary 
operators like `<`, `+`, ... are mapped to their name equivalents in Spark so 
that the dynamic function lookup works.
    
    In addition, it cleans up the Scala planner side to remove the now 
unnecessary code translating the trivial binary expressions into their 
equivalent functions.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests and end-to-end tests.
    
    Closes #38270 from grundprinzip/spark-40538.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/connect/dsl/package.scala | 38 ++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 25 ++-----
 .../connect/planner/SparkConnectPlannerSuite.scala |  2 +-
 .../connect/planner/SparkConnectProtoSuite.scala   | 28 ++++++++
 python/pyspark/sql/connect/column.py               | 80 +++++++++++-----------
 .../sql/tests/connect/test_connect_basic.py        | 12 ++++
 .../connect/test_connect_column_expressions.py     | 29 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  2 +-
 8 files changed, 156 insertions(+), 60 deletions(-)

diff --git 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 579f190156f..0c392130562 100644
--- 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -92,6 +92,44 @@ package object dsl {
           .build()
     }
 
+    /**
+     * Create an unresolved function from name parts.
+     *
+     * @param nameParts
+     * @param args
+     * @return
+     *   Expression wrapping the unresolved function.
+     */
+    def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): 
proto.Expression = {
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .addAllParts(nameParts.asJava)
+            .addAllArguments(args.asJava))
+        .build()
+    }
+
+    /**
+     * Creates an UnresolvedFunction from a single identifier.
+     *
+     * @param name
+     * @param args
+     * @return
+     *   Expression wrapping the unresolved function.
+     */
+    def callFunction(name: String, args: Seq[proto.Expression]): 
proto.Expression = {
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .addParts(name)
+            .addAllArguments(args.asJava))
+        .build()
+    }
+
     implicit def intToLiteral(i: Int): proto.Expression =
       proto.Expression
         .newBuilder()
diff --git 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 61352c17a23..7ffce908221 100644
--- 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -197,10 +197,6 @@ class SparkConnectPlanner(plan: proto.Relation, session: 
SparkSession) {
       limitExpr = expressions.Literal(limit.getLimit, IntegerType))
   }
 
-  private def lookupFunction(name: String, args: Seq[Expression]): Expression 
= {
-    UnresolvedFunction(Seq(name), args, isDistinct = false)
-  }
-
   /**
    * Translates a scalar function from proto to the Catalyst expression.
    *
@@ -211,21 +207,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: 
SparkSession) {
    * @return
    */
   private def transformScalarFunction(fun: 
proto.Expression.UnresolvedFunction): Expression = {
-    val funName = fun.getPartsList.asScala.mkString(".")
-    funName match {
-      case "gt" =>
-        assert(fun.getArgumentsCount == 2, "`gt` function must have two 
arguments.")
-        expressions.GreaterThan(
-          transformExpression(fun.getArguments(0)),
-          transformExpression(fun.getArguments(1)))
-      case "eq" =>
-        assert(fun.getArgumentsCount == 2, "`eq` function must have two 
arguments.")
-        expressions.EqualTo(
-          transformExpression(fun.getArguments(0)),
-          transformExpression(fun.getArguments(1)))
-      case _ =>
-        lookupFunction(funName, 
fun.getArgumentsList.asScala.map(transformExpression).toSeq)
+    if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) {
+      throw new IllegalArgumentException(
+        "Function identifier must be passed as sequence of name parts.")
     }
+    UnresolvedFunction(
+      fun.getPartsList.asScala.toSeq,
+      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      isDistinct = false)
   }
 
   private def transformAlias(alias: proto.Expression.Alias): Expression = {
diff --git 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 67518f3bdb1..74788ce5593 100644
--- 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -197,7 +197,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with 
SparkConnectPlanTest {
 
     val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
       proto.Expression.UnresolvedFunction.newBuilder
-        .addAllParts(Seq("eq").asJava)
+        .addAllParts(Seq("==").asJava)
         .addArguments(unresolvedAttribute)
         .addArguments(unresolvedAttribute)
         .build())
diff --git 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index b13e74c2125..ef4b358798e 100644
--- 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with 
SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
+  test("UnresolvedFunction resolution.") {
+    {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      assertThrows[IllegalArgumentException] {
+        transform(connectTestRelation.select(callFunction("default.hex", 
Seq("id".protoAttr))))
+      }
+    }
+
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(
+        connectTestRelation.select(callFunction(Seq("default", "hex"), 
Seq("id".protoAttr))))
+    }
+
+    assertThrows[UnsupportedOperationException] {
+      connectPlan.analyze
+    }
+
+    val validPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.select(callFunction(Seq("hex"), 
Seq("id".protoAttr))))
+    }
+    assert(validPlan.analyze != null)
+  }
+
   test("Basic filter") {
     val connectPlan = {
       import org.apache.spark.sql.connect.dsl.expressions._
diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index 55b8f176e05..b291c5fb211 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -26,11 +26,51 @@ if TYPE_CHECKING:
     import pyspark.sql.connect.proto as proto
 
 
+def _bin_op(
+    name: str, doc: str = "binary function", reverse: bool = False
+) -> Callable[["ColumnRef", Any], "Expression"]:
+    def _(self: "ColumnRef", other: Any) -> "Expression":
+        if isinstance(other, get_args(PrimitiveType)):
+            other = LiteralExpression(other)
+        if not reverse:
+            return ScalarFunctionExpression(name, self, other)
+        else:
+            return ScalarFunctionExpression(name, other, self)
+
+    return _
+
+
 class Expression(object):
     """
     Expression base class.
     """
 
+    __gt__ = _bin_op(">")
+    __lt__ = _bin_op(">")
+    __add__ = _bin_op("+")
+    __sub__ = _bin_op("-")
+    __mul__ = _bin_op("*")
+    __div__ = _bin_op("/")
+    __truediv__ = _bin_op("/")
+    __mod__ = _bin_op("%")
+    __radd__ = _bin_op("+", reverse=True)
+    __rsub__ = _bin_op("-", reverse=True)
+    __rmul__ = _bin_op("*", reverse=True)
+    __rdiv__ = _bin_op("/", reverse=True)
+    __rtruediv__ = _bin_op("/", reverse=True)
+    __pow__ = _bin_op("pow")
+    __rpow__ = _bin_op("pow", reverse=True)
+    __ge__ = _bin_op(">=")
+    __le__ = _bin_op("<=")
+
+    def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
+        """Returns a binary expression with the current column as the left
+        side and the other expression as the right side.
+        """
+        if isinstance(other, get_args(PrimitiveType)):
+            other = LiteralExpression(other)
+        return ScalarFunctionExpression("==", self, other)
+
     def __init__(self) -> None:
         pass
 
@@ -73,20 +113,6 @@ class LiteralExpression(Expression):
         return f"Literal({self._value})"
 
 
-def _bin_op(
-    name: str, doc: str = "binary function", reverse: bool = False
-) -> Callable[["ColumnRef", Any], Expression]:
-    def _(self: "ColumnRef", other: Any) -> Expression:
-        if isinstance(other, get_args(PrimitiveType)):
-            other = LiteralExpression(other)
-        if not reverse:
-            return ScalarFunctionExpression(name, self, other)
-        else:
-            return ScalarFunctionExpression(name, other, self)
-
-    return _
-
-
 class ColumnRef(Expression):
     """Represents a column reference. There is no guarantee that this column
     actually exists. In the context of this project, we refer by its name and
@@ -105,32 +131,6 @@ class ColumnRef(Expression):
         """Returns the qualified name of the column reference."""
         return ".".join(self._parts)
 
-    __gt__ = _bin_op("gt")
-    __lt__ = _bin_op("lt")
-    __add__ = _bin_op("plus")
-    __sub__ = _bin_op("minus")
-    __mul__ = _bin_op("multiply")
-    __div__ = _bin_op("divide")
-    __truediv__ = _bin_op("divide")
-    __mod__ = _bin_op("modulo")
-    __radd__ = _bin_op("plus", reverse=True)
-    __rsub__ = _bin_op("minus", reverse=True)
-    __rmul__ = _bin_op("multiply", reverse=True)
-    __rdiv__ = _bin_op("divide", reverse=True)
-    __rtruediv__ = _bin_op("divide", reverse=True)
-    __pow__ = _bin_op("pow")
-    __rpow__ = _bin_op("pow", reverse=True)
-    __ge__ = _bin_op("greterEquals")
-    __le__ = _bin_op("lessEquals")
-
-    def __eq__(self, other: Any) -> Expression:  # type: ignore[override]
-        """Returns a binary expression with the current column as the left
-        side and the other expression as the right side.
-        """
-        if isinstance(other, get_args(PrimitiveType)):
-            other = LiteralExpression(other)
-        return ScalarFunctionExpression("eq", self, other)
-
     def to_plan(self, session: Optional["RemoteSparkSession"]) -> 
proto.Expression:
         """Returns the Proto representation of the expression."""
         expr = proto.Expression()
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index adc6f38f997..95173de347e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,9 +19,12 @@ import uuid
 import unittest
 import tempfile
 
+import pandas
+
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.connect.client import RemoteSparkSession
 from pyspark.sql.connect.function_builder import udf
+from pyspark.sql.connect.functions import lit
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
 
@@ -79,6 +82,15 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         result = df.explain()
         self.assertGreater(len(result), 0)
 
+    def test_simple_binary_expressions(self):
+        """Test complex expression"""
+        df = self.connect.read.table(self.tbl_name)
+        pd = df.select(df.id).where(df.id % lit(30) == 
lit(0)).sort(df.id.asc()).toPandas()
+        self.assertEqual(len(pd.index), 4)
+
+        res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
+        self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
diff --git 
a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py 
b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 2aa686bbc38..74f5343a9c1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -16,6 +16,7 @@
 #
 
 from pyspark.testing.connectutils import PlanOnlyTestFixture
+from pyspark.sql.connect.proto import Expression as ProtoExpression
 import pyspark.sql.connect as c
 import pyspark.sql.connect.plan as p
 import pyspark.sql.connect.column as col
@@ -51,6 +52,34 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         plan = fun.lit(10).to_plan(None)
         self.assertIs(plan.literal.i32, 10)
 
+    def test_column_expressions(self):
+        """Test a more complex combination of expressions and their 
translation into
+        the protobuf structure."""
+        df = c.DataFrame.withPlan(p.Read("table"))
+
+        expr = df.id % fun.lit(10) == fun.lit(10)
+        expr_plan = expr.to_plan(None)
+        self.assertIsNotNone(expr_plan.unresolved_function)
+        self.assertEqual(expr_plan.unresolved_function.parts[0], "==")
+
+        lit_fun = expr_plan.unresolved_function.arguments[1]
+        self.assertIsInstance(lit_fun, ProtoExpression)
+        self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
+        self.assertEqual(lit_fun.literal.i32, 10)
+
+        mod_fun = expr_plan.unresolved_function.arguments[0]
+        self.assertIsInstance(mod_fun, ProtoExpression)
+        self.assertIsInstance(mod_fun.unresolved_function, 
ProtoExpression.UnresolvedFunction)
+        self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
+        self.assertIsInstance(mod_fun.unresolved_function.arguments[0], 
ProtoExpression)
+        self.assertIsInstance(
+            mod_fun.unresolved_function.arguments[0].unresolved_attribute,
+            ProtoExpression.UnresolvedAttribute,
+        )
+        self.assertEqual(
+            
mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"]
+        )
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py 
b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 8fb33beb367..03cedd56de5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -40,7 +40,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
                 plan.root.filter.condition.unresolved_function, 
proto.Expression.UnresolvedFunction
             )
         )
-        self.assertEqual(plan.root.filter.condition.unresolved_function.parts, 
["gt"])
+        self.assertEqual(plan.root.filter.condition.unresolved_function.parts, 
[">"])
         
self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 
2)
 
     def test_relation_alias(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to