This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new eac736e1a62 [SPARK-40875][CONNECT] Improve aggregate in Connect DSL eac736e1a62 is described below commit eac736e1a62bf707cd3103a5c94df1d5a45617df Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Mon Nov 7 18:05:59 2022 +0800 [SPARK-40875][CONNECT] Improve aggregate in Connect DSL ### What changes were proposed in this pull request? This PR adds the aggregate expressions (or named result expressions) for Aggregate in Connect proto and DSL. On the server side, this PR also differentiates named expressions (e.g. with `alias`) and non-named expressions (so the server will wrap the latter in `UnresolvedAlias`, and Catalyst will generate an alias for such an expression). ### Why are the changes needed? Improve API coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38527 from amaliujia/add_aggregate_expression_to_dsl. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 7 +-- .../org/apache/spark/sql/connect/dsl/package.scala | 12 +++++- .../sql/connect/planner/SparkConnectPlanner.scala | 20 ++++----- .../connect/planner/SparkConnectPlannerSuite.scala | 15 +++++-- .../connect/planner/SparkConnectProtoSuite.scala | 17 ++++++++ python/pyspark/sql/connect/plan.py | 9 ++-- python/pyspark/sql/connect/proto/relations_pb2.py | 50 +++++++++++----------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 31 ++------------ 8 files changed, 81 insertions(+), 80 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index deb35525728..8edd8911242 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -161,12 +161,7 @@ message Offset { message Aggregate { Relation input = 
1; repeated Expression grouping_expressions = 2; - repeated AggregateFunction result_expressions = 3; - - message AggregateFunction { - string name = 1; - repeated Expression arguments = 2; - } + repeated Expression result_expressions = 3; } // Relation of type [[Sort]]. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index e2030c9ad31..c40a9eed753 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -93,6 +93,13 @@ package object dsl { .build() } + def proto_min(e: Expression): Expression = + Expression + .newBuilder() + .setUnresolvedFunction( + Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e)) + .build() + /** * Create an unresolved function from name parts. * @@ -383,8 +390,9 @@ package object dsl { for (groupingExpr <- groupingExprs) { agg.addGroupingExpressions(groupingExpr) } - // TODO: support aggregateExprs, which is blocked by supporting any builtin function - // resolution only by name in the analyzer. 
+ for (aggregateExpr <- aggregateExprs) { + agg.addResultExpressions(aggregateExpr) + } Relation.newBuilder().setAggregate(agg.build()).build() } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f5c6980290f..d2b474711ab 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} @@ -285,7 +285,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { isDistinct = false) } - private def transformAlias(alias: proto.Expression.Alias): Expression = { + private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { Alias(transformExpression(alias.getExpr), alias.getName)() } @@ -393,17 +393,15 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { child = transformRelation(rel.getInput), groupingExpressions = groupingExprs.toSeq, aggregateExpressions = - rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq) + 
rel.getResultExpressionsList.asScala.map(transformResultExpression).toSeq) } - private def transformAggregateExpression( - exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = { - val fun = exp.getName - UnresolvedAlias( - UnresolvedFunction( - name = fun, - arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq, - isDistinct = false)) + private def transformResultExpression(exp: proto.Expression): expressions.NamedExpression = { + if (exp.hasAlias) { + transformAlias(exp.getAlias) + } else { + UnresolvedAlias(transformExpression(exp)) + } } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index eda7ade3ec6..d2304581c3a 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -274,12 +274,19 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("left").build()) .build() + val sum = + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addParts("sum") + .addArguments(unresolvedAttribute)) + .build() + val agg = proto.Aggregate.newBuilder .setInput(readRel) - .addResultExpressions( - proto.Aggregate.AggregateFunction.newBuilder - .setName("sum") - .addArguments(unresolvedAttribute)) + .addResultExpressions(sum) .addGroupingExpressions(unresolvedAttribute) .build() diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 94a2bd12461..0aa89d6f640 100644 --- 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -144,6 +145,22 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } } + test("Aggregate expressions") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val connectPlan = + connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr)) + val sparkPlan = + sparkTestRelation.groupBy(Column("id")).agg(min(Column("name"))) + comparePlans(connectPlan, sparkPlan) + + val connectPlan2 = + connectTestRelation.groupBy("id".protoAttr)(proto_min("name".protoAttr).as("agg1")) + val sparkPlan2 = + sparkTestRelation.groupBy(Column("id")).agg(min(Column("name")).as("agg1")) + comparePlans(connectPlan2, sparkPlan2) + } + } + test("Test as(alias: String)") { val connectPlan = connectTestRelation.as("target_table") val sparkPlan = sparkTestRelation.as("target_table") diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index cc59a493d5a..4b28e6cb80a 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -489,15 +489,16 @@ class Aggregate(LogicalPlan): def _convert_measure( self, m: MeasureType, session: Optional["RemoteSparkSession"] - ) -> proto.Aggregate.AggregateFunction: + ) -> proto.Expression: exp, fun = m - measure = proto.Aggregate.AggregateFunction() - measure.name = fun + proto_expr = proto.Expression() + measure = 
proto_expr.unresolved_function + measure.parts.append(fun) if type(exp) is str: measure.arguments.append(self.unresolved_attr(exp)) else: measure.arguments.append(cast(Expression, exp).to_plan(session)) - return measure + return proto_expr def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: assert self._child is not None diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 3d5eb53e5a9..6180c5e13c9 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -76,29 +76,27 @@ if _descriptor._USE_C_DESCRIPTORS == False: _OFFSET._serialized_start = 2626 _OFFSET._serialized_end = 2705 _AGGREGATE._serialized_start = 2708 - _AGGREGATE._serialized_end = 3033 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2937 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 3033 - _SORT._serialized_start = 3036 - _SORT._serialized_end = 3567 - _SORT_SORTFIELD._serialized_start = 3185 - _SORT_SORTFIELD._serialized_end = 3373 - _SORT_SORTDIRECTION._serialized_start = 3375 - _SORT_SORTDIRECTION._serialized_end = 3483 - _SORT_SORTNULLS._serialized_start = 3485 - _SORT_SORTNULLS._serialized_end = 3567 - _DEDUPLICATE._serialized_start = 3570 - _DEDUPLICATE._serialized_end = 3712 - _LOCALRELATION._serialized_start = 3714 - _LOCALRELATION._serialized_end = 3807 - _SAMPLE._serialized_start = 3810 - _SAMPLE._serialized_end = 4050 - _SAMPLE_SEED._serialized_start = 4024 - _SAMPLE_SEED._serialized_end = 4050 - _RANGE._serialized_start = 4053 - _RANGE._serialized_end = 4251 - _RANGE_NUMPARTITIONS._serialized_start = 4197 - _RANGE_NUMPARTITIONS._serialized_end = 4251 - _SUBQUERYALIAS._serialized_start = 4253 - _SUBQUERYALIAS._serialized_end = 4367 + _AGGREGATE._serialized_end = 2918 + _SORT._serialized_start = 2921 + _SORT._serialized_end = 3452 + _SORT_SORTFIELD._serialized_start = 3070 + _SORT_SORTFIELD._serialized_end = 3258 + _SORT_SORTDIRECTION._serialized_start = 3260 + _SORT_SORTDIRECTION._serialized_end = 3368 + _SORT_SORTNULLS._serialized_start = 3370 + _SORT_SORTNULLS._serialized_end = 3452 + _DEDUPLICATE._serialized_start = 3455 + _DEDUPLICATE._serialized_end = 3597 + _LOCALRELATION._serialized_start = 3599 + _LOCALRELATION._serialized_end = 3692 + _SAMPLE._serialized_start = 3695 + _SAMPLE._serialized_end = 3935 + _SAMPLE_SEED._serialized_start = 3909 + _SAMPLE_SEED._serialized_end = 3935 + _RANGE._serialized_start = 3938 + _RANGE._serialized_end = 4136 + 
_RANGE_NUMPARTITIONS._serialized_start = 4082 + _RANGE_NUMPARTITIONS._serialized_end = 4136 + _SUBQUERYALIAS._serialized_start = 4138 + _SUBQUERYALIAS._serialized_end = 4252 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 60f4e2033a8..f5b5c9f90dc 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -661,31 +661,6 @@ class Aggregate(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class AggregateFunction(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - NAME_FIELD_NUMBER: builtins.int - ARGUMENTS_FIELD_NUMBER: builtins.int - name: builtins.str - @property - def arguments( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - pyspark.sql.connect.proto.expressions_pb2.Expression - ]: ... - def __init__( - self, - *, - name: builtins.str = ..., - arguments: collections.abc.Iterable[ - pyspark.sql.connect.proto.expressions_pb2.Expression - ] - | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["arguments", b"arguments", "name", b"name"] - ) -> None: ... - INPUT_FIELD_NUMBER: builtins.int GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int RESULT_EXPRESSIONS_FIELD_NUMBER: builtins.int @@ -701,7 +676,7 @@ class Aggregate(google.protobuf.message.Message): def result_expressions( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - global___Aggregate.AggregateFunction + pyspark.sql.connect.proto.expressions_pb2.Expression ]: ... 
def __init__( self, @@ -711,7 +686,9 @@ class Aggregate(google.protobuf.message.Message): pyspark.sql.connect.proto.expressions_pb2.Expression ] | None = ..., - result_expressions: collections.abc.Iterable[global___Aggregate.AggregateFunction] + result_expressions: collections.abc.Iterable[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ] | None = ..., ) -> None: ... def HasField( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org