This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eac736e1a62 [SPARK-40875][CONNECT] Improve aggregate in Connect DSL
eac736e1a62 is described below

commit eac736e1a62bf707cd3103a5c94df1d5a45617df
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Mon Nov 7 18:05:59 2022 +0800

    [SPARK-40875][CONNECT] Improve aggregate in Connect DSL
    
    ### What changes were proposed in this pull request?
    
    This PR adds aggregate expressions (or named result expressions) for
    Aggregate in the Connect proto and DSL. On the server side, this PR also
    differentiates named expressions (e.g. with `alias`) from non-named
    expressions: the server wraps the latter in `UnresolvedAlias`, so that
    Catalyst will generate an alias for such expressions.
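
    For example, with the Connect DSL (as exercised in the new
    SparkConnectProtoSuite test below), both forms are accepted; a short
    sketch reusing names from that test:

        connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr))
        connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))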
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT
    
    Closes #38527 from amaliujia/add_aggregate_expression_to_dsl.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  7 +--
 .../org/apache/spark/sql/connect/dsl/package.scala | 12 +++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 20 ++++-----
 .../connect/planner/SparkConnectPlannerSuite.scala | 15 +++++--
 .../connect/planner/SparkConnectProtoSuite.scala   | 17 ++++++++
 python/pyspark/sql/connect/plan.py                 |  9 ++--
 python/pyspark/sql/connect/proto/relations_pb2.py  | 50 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 31 ++------------
 8 files changed, 81 insertions(+), 80 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index deb35525728..8edd8911242 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -161,12 +161,7 @@ message Offset {
 message Aggregate {
   Relation input = 1;
   repeated Expression grouping_expressions = 2;
-  repeated AggregateFunction result_expressions = 3;
-
-  message AggregateFunction {
-    string name = 1;
-    repeated Expression arguments = 2;
-  }
+  repeated Expression result_expressions = 3;
 }
 
 // Relation of type [[Sort]].
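
With the nested AggregateFunction message removed, a result expression is a
plain Expression. A minimal sketch of building one from Scala against the
generated proto API (names borrowed from the updated SparkConnectPlannerSuite
below):

    val sum = proto.Expression
      .newBuilder()
      .setUnresolvedFunction(
        proto.Expression.UnresolvedFunction
          .newBuilder()
          .addParts("sum")
          .addArguments(unresolvedAttribute))
      .build()

    // Result expressions are now plain Expressions, not Aggregate.AggregateFunction.
    val agg = proto.Aggregate.newBuilder
      .setInput(readRel)
      .addGroupingExpressions(unresolvedAttribute)
      .addResultExpressions(sum)
      .build()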
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index e2030c9ad31..c40a9eed753 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -93,6 +93,13 @@ package object dsl {
           .build()
     }
 
+    def proto_min(e: Expression): Expression =
+      Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e))
+        .build()
+
     /**
      * Create an unresolved function from name parts.
      *
@@ -383,8 +390,9 @@ package object dsl {
         for (groupingExpr <- groupingExprs) {
           agg.addGroupingExpressions(groupingExpr)
         }
-        // TODO: support aggregateExprs, which is blocked by supporting any builtin function
-        // resolution only by name in the analyzer.
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addResultExpressions(aggregateExpr)
+        }
         Relation.newBuilder().setAggregate(agg.build()).build()
       }
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f5c6980290f..d2b474711ab 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
@@ -285,7 +285,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       isDistinct = false)
   }
 
-  private def transformAlias(alias: proto.Expression.Alias): Expression = {
+  private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
     Alias(transformExpression(alias.getExpr), alias.getName)()
   }
 
@@ -393,17 +393,15 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       child = transformRelation(rel.getInput),
       groupingExpressions = groupingExprs.toSeq,
       aggregateExpressions =
-        rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
+        rel.getResultExpressionsList.asScala.map(transformResultExpression).toSeq)
   }
 
-  private def transformAggregateExpression(
-      exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
-    val fun = exp.getName
-    UnresolvedAlias(
-      UnresolvedFunction(
-        name = fun,
-        arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
-        isDistinct = false))
+  private def transformResultExpression(exp: proto.Expression): expressions.NamedExpression = {
+    if (exp.hasAlias) {
+      transformAlias(exp.getAlias)
+    } else {
+      UnresolvedAlias(transformExpression(exp))
+    }
   }
 
 }
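
For a `min(name)` result expression, transformResultExpression above yields
Catalyst expressions of the following shapes (a sketch, assuming the imports
shown in this file):

    // Non-named: wrapped in UnresolvedAlias so Catalyst generates the name.
    val unnamed = UnresolvedAlias(
      UnresolvedFunction("min", Seq(UnresolvedAttribute("name")), isDistinct = false))

    // Named via `alias`: the user-provided name is preserved.
    val named = Alias(
      UnresolvedFunction("min", Seq(UnresolvedAttribute("name")), isDistinct = false),
      "agg1")()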
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index eda7ade3ec6..d2304581c3a 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -274,12 +274,19 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("left").build())
       .build()
 
+    val sum =
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .addParts("sum")
+            .addArguments(unresolvedAttribute))
+        .build()
+
     val agg = proto.Aggregate.newBuilder
       .setInput(readRel)
-      .addResultExpressions(
-        proto.Aggregate.AggregateFunction.newBuilder
-          .setName("sum")
-          .addArguments(unresolvedAttribute))
+      .addResultExpressions(sum)
       .addGroupingExpressions(unresolvedAttribute)
       .build()
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 94a2bd12461..0aa89d6f640 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
@@ -144,6 +145,22 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }
 
+  test("Aggregate expressions") {
+    withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+      val connectPlan =
+        connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr))
+      val sparkPlan =
+        sparkTestRelation.groupBy(Column("id")).agg(min(Column("name")))
+      comparePlans(connectPlan, sparkPlan)
+
+      val connectPlan2 =
+        connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr).as("agg1"))
+      val sparkPlan2 =
+        sparkTestRelation.groupBy(Column("id")).agg(min(Column("name")).as("agg1"))
+      comparePlans(connectPlan2, sparkPlan2)
+    }
+  }
+
   test("Test as(alias: String)") {
     val connectPlan = connectTestRelation.as("target_table")
     val sparkPlan = sparkTestRelation.as("target_table")
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index cc59a493d5a..4b28e6cb80a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -489,15 +489,16 @@ class Aggregate(LogicalPlan):
 
     def _convert_measure(
         self, m: MeasureType, session: Optional["RemoteSparkSession"]
-    ) -> proto.Aggregate.AggregateFunction:
+    ) -> proto.Expression:
         exp, fun = m
-        measure = proto.Aggregate.AggregateFunction()
-        measure.name = fun
+        proto_expr = proto.Expression()
+        measure = proto_expr.unresolved_function
+        measure.parts.append(fun)
         if type(exp) is str:
             measure.arguments.append(self.unresolved_attr(exp))
         else:
             measure.arguments.append(cast(Expression, exp).to_plan(session))
-        return measure
+        return proto_expr
 
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         assert self._child is not None
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3d5eb53e5a9..6180c5e13c9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -76,29 +76,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _OFFSET._serialized_start = 2626
     _OFFSET._serialized_end = 2705
     _AGGREGATE._serialized_start = 2708
-    _AGGREGATE._serialized_end = 3033
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2937
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 3033
-    _SORT._serialized_start = 3036
-    _SORT._serialized_end = 3567
-    _SORT_SORTFIELD._serialized_start = 3185
-    _SORT_SORTFIELD._serialized_end = 3373
-    _SORT_SORTDIRECTION._serialized_start = 3375
-    _SORT_SORTDIRECTION._serialized_end = 3483
-    _SORT_SORTNULLS._serialized_start = 3485
-    _SORT_SORTNULLS._serialized_end = 3567
-    _DEDUPLICATE._serialized_start = 3570
-    _DEDUPLICATE._serialized_end = 3712
-    _LOCALRELATION._serialized_start = 3714
-    _LOCALRELATION._serialized_end = 3807
-    _SAMPLE._serialized_start = 3810
-    _SAMPLE._serialized_end = 4050
-    _SAMPLE_SEED._serialized_start = 4024
-    _SAMPLE_SEED._serialized_end = 4050
-    _RANGE._serialized_start = 4053
-    _RANGE._serialized_end = 4251
-    _RANGE_NUMPARTITIONS._serialized_start = 4197
-    _RANGE_NUMPARTITIONS._serialized_end = 4251
-    _SUBQUERYALIAS._serialized_start = 4253
-    _SUBQUERYALIAS._serialized_end = 4367
+    _AGGREGATE._serialized_end = 2918
+    _SORT._serialized_start = 2921
+    _SORT._serialized_end = 3452
+    _SORT_SORTFIELD._serialized_start = 3070
+    _SORT_SORTFIELD._serialized_end = 3258
+    _SORT_SORTDIRECTION._serialized_start = 3260
+    _SORT_SORTDIRECTION._serialized_end = 3368
+    _SORT_SORTNULLS._serialized_start = 3370
+    _SORT_SORTNULLS._serialized_end = 3452
+    _DEDUPLICATE._serialized_start = 3455
+    _DEDUPLICATE._serialized_end = 3597
+    _LOCALRELATION._serialized_start = 3599
+    _LOCALRELATION._serialized_end = 3692
+    _SAMPLE._serialized_start = 3695
+    _SAMPLE._serialized_end = 3935
+    _SAMPLE_SEED._serialized_start = 3909
+    _SAMPLE_SEED._serialized_end = 3935
+    _RANGE._serialized_start = 3938
+    _RANGE._serialized_end = 4136
+    _RANGE_NUMPARTITIONS._serialized_start = 4082
+    _RANGE_NUMPARTITIONS._serialized_end = 4136
+    _SUBQUERYALIAS._serialized_start = 4138
+    _SUBQUERYALIAS._serialized_end = 4252
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 60f4e2033a8..f5b5c9f90dc 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -661,31 +661,6 @@ class Aggregate(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class AggregateFunction(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        NAME_FIELD_NUMBER: builtins.int
-        ARGUMENTS_FIELD_NUMBER: builtins.int
-        name: builtins.str
-        @property
-        def arguments(
-            self,
-        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-            pyspark.sql.connect.proto.expressions_pb2.Expression
-        ]: ...
-        def __init__(
-            self,
-            *,
-            name: builtins.str = ...,
-            arguments: collections.abc.Iterable[
-                pyspark.sql.connect.proto.expressions_pb2.Expression
-            ]
-            | None = ...,
-        ) -> None: ...
-        def ClearField(
-            self, field_name: typing_extensions.Literal["arguments", b"arguments", "name", b"name"]
-        ) -> None: ...
-
     INPUT_FIELD_NUMBER: builtins.int
     GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
     RESULT_EXPRESSIONS_FIELD_NUMBER: builtins.int
@@ -701,7 +676,7 @@ class Aggregate(google.protobuf.message.Message):
     def result_expressions(
         self,
     ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-        global___Aggregate.AggregateFunction
+        pyspark.sql.connect.proto.expressions_pb2.Expression
     ]: ...
     def __init__(
         self,
@@ -711,7 +686,9 @@ class Aggregate(google.protobuf.message.Message):
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
         | None = ...,
-        result_expressions: collections.abc.Iterable[global___Aggregate.AggregateFunction]
+        result_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
         | None = ...,
     ) -> None: ...
     def HasField(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
