This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c85c294d689 [SPARK-40971][CONNECT][DSL] Do not need to use `proto.` to 
refer generated classes in Connect DSL
c85c294d689 is described below

commit c85c294d689ffcd66d5aafe6f54b79566af70dd7
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Mon Oct 31 19:21:30 2022 +0800

    [SPARK-40971][CONNECT][DSL] Do not need to use `proto.` to refer generated 
classes in Connect DSL
    
    ### What changes were proposed in this pull request?
    
    The Connect DSL only deals with Connect protos, so there is no need to use 
`proto.` to refer to the proto-generated classes. Instead, we can import more 
from the proto package to simplify the DSL code. This change introduces no 
ambiguity in class references.
    
    ### Why are the changes needed?
    
    Codebase simplification.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No

    ### How was this patch tested?
    
    Existing UT
    
    Closes #38445 from amaliujia/clean_up_dsl.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/connect/dsl/package.scala | 173 ++++++++++-----------
 1 file changed, 82 insertions(+), 91 deletions(-)

diff --git 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 9ffc4c4a1fe..3ba773e4c04 100644
--- 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 
-import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto._
 import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.connect.proto.SetOperation.SetOpType
 import org.apache.spark.sql.SaveMode
@@ -36,55 +36,53 @@ package object dsl {
 
   object expressions { // scalastyle:ignore
     implicit class DslString(val s: String) {
-      def protoAttr: proto.Expression =
-        proto.Expression
+      def protoAttr: Expression =
+        Expression
           .newBuilder()
           .setUnresolvedAttribute(
-            proto.Expression.UnresolvedAttribute
+            Expression.UnresolvedAttribute
               .newBuilder()
               .setUnparsedIdentifier(s))
           .build()
 
-      def struct(
-          attrs: proto.Expression.QualifiedAttribute*): 
proto.Expression.QualifiedAttribute = {
-        val structExpr = proto.DataType.Struct.newBuilder()
+      def struct(attrs: Expression.QualifiedAttribute*): 
Expression.QualifiedAttribute = {
+        val structExpr = DataType.Struct.newBuilder()
         for (attr <- attrs) {
-          val structField = proto.DataType.StructField.newBuilder()
+          val structField = DataType.StructField.newBuilder()
           structField.setName(attr.getName)
           structField.setType(attr.getType)
           structExpr.addFields(structField)
         }
-        proto.Expression.QualifiedAttribute
+        Expression.QualifiedAttribute
           .newBuilder()
           .setName(s)
-          .setType(proto.DataType.newBuilder().setStruct(structExpr))
+          .setType(DataType.newBuilder().setStruct(structExpr))
           .build()
       }
 
       /** Creates a new AttributeReference of type int */
-      def int: proto.Expression.QualifiedAttribute = 
protoQualifiedAttrWithType(
-        
proto.DataType.newBuilder().setI32(proto.DataType.I32.newBuilder()).build())
+      def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType(
+        DataType.newBuilder().setI32(DataType.I32.newBuilder()).build())
 
-      private def protoQualifiedAttrWithType(
-          dataType: proto.DataType): proto.Expression.QualifiedAttribute =
-        proto.Expression.QualifiedAttribute
+      private def protoQualifiedAttrWithType(dataType: DataType): 
Expression.QualifiedAttribute =
+        Expression.QualifiedAttribute
           .newBuilder()
           .setName(s)
           .setType(dataType)
           .build()
     }
 
-    implicit class DslExpression(val expr: proto.Expression) {
-      def as(alias: String): proto.Expression = proto.Expression
+    implicit class DslExpression(val expr: Expression) {
+      def as(alias: String): Expression = Expression
         .newBuilder()
-        
.setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
+        .setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr))
         .build()
 
-      def <(other: proto.Expression): proto.Expression =
-        proto.Expression
+      def <(other: Expression): Expression =
+        Expression
           .newBuilder()
           .setUnresolvedFunction(
-            proto.Expression.UnresolvedFunction
+            Expression.UnresolvedFunction
               .newBuilder()
               .addParts("<")
               .addArguments(expr)
@@ -100,11 +98,11 @@ package object dsl {
      * @return
      *   Expression wrapping the unresolved function.
      */
-    def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): 
proto.Expression = {
-      proto.Expression
+    def callFunction(nameParts: Seq[String], args: Seq[Expression]): 
Expression = {
+      Expression
         .newBuilder()
         .setUnresolvedFunction(
-          proto.Expression.UnresolvedFunction
+          Expression.UnresolvedFunction
             .newBuilder()
             .addAllParts(nameParts.asJava)
             .addAllArguments(args.asJava))
@@ -119,26 +117,26 @@ package object dsl {
      * @return
      *   Expression wrapping the unresolved function.
      */
-    def callFunction(name: String, args: Seq[proto.Expression]): 
proto.Expression = {
-      proto.Expression
+    def callFunction(name: String, args: Seq[Expression]): Expression = {
+      Expression
         .newBuilder()
         .setUnresolvedFunction(
-          proto.Expression.UnresolvedFunction
+          Expression.UnresolvedFunction
             .newBuilder()
             .addParts(name)
             .addAllArguments(args.asJava))
         .build()
     }
 
-    implicit def intToLiteral(i: Int): proto.Expression =
-      proto.Expression
+    implicit def intToLiteral(i: Int): Expression =
+      Expression
         .newBuilder()
-        .setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
+        .setLiteral(Expression.Literal.newBuilder().setI32(i))
         .build()
   }
 
   object commands { // scalastyle:ignore
-    implicit class DslCommands(val logicalPlan: proto.Relation) {
+    implicit class DslCommands(val logicalPlan: Relation) {
       def write(
           format: Option[String] = None,
           path: Option[String] = None,
@@ -147,8 +145,8 @@ package object dsl {
           sortByColumns: Seq[String] = Seq.empty,
           partitionByCols: Seq[String] = Seq.empty,
           bucketByCols: Seq[String] = Seq.empty,
-          numBuckets: Option[Int] = None): proto.Command = {
-        val writeOp = proto.WriteOperation.newBuilder()
+          numBuckets: Option[Int] = None): Command = {
+        val writeOp = WriteOperation.newBuilder()
         format.foreach(writeOp.setSource(_))
 
         mode
@@ -165,24 +163,24 @@ package object dsl {
         partitionByCols.foreach(writeOp.addPartitioningColumns(_))
 
         if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
-          val op = proto.WriteOperation.BucketBy.newBuilder()
+          val op = WriteOperation.BucketBy.newBuilder()
           numBuckets.foreach(op.setNumBuckets(_))
           bucketByCols.foreach(op.addBucketColumnNames(_))
           writeOp.setBucketBy(op.build())
         }
         writeOp.setInput(logicalPlan)
-        proto.Command.newBuilder().setWriteOperation(writeOp.build()).build()
+        Command.newBuilder().setWriteOperation(writeOp.build()).build()
       }
     }
   }
 
   object plans { // scalastyle:ignore
-    implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
-      def select(exprs: proto.Expression*): proto.Relation = {
-        proto.Relation
+    implicit class DslLogicalPlan(val logicalPlan: Relation) {
+      def select(exprs: Expression*): Relation = {
+        Relation
           .newBuilder()
           .setProject(
-            proto.Project
+            Project
               .newBuilder()
               .setInput(logicalPlan)
               .addAllExpressions(exprs.toIterable.asJava)
@@ -190,88 +188,85 @@ package object dsl {
           .build()
       }
 
-      def limit(limit: Int): proto.Relation = {
-        proto.Relation
+      def limit(limit: Int): Relation = {
+        Relation
           .newBuilder()
           .setLimit(
-            proto.Limit
+            Limit
               .newBuilder()
               .setInput(logicalPlan)
               .setLimit(limit))
           .build()
       }
 
-      def offset(offset: Int): proto.Relation = {
-        proto.Relation
+      def offset(offset: Int): Relation = {
+        Relation
           .newBuilder()
           .setOffset(
-            proto.Offset
+            Offset
               .newBuilder()
               .setInput(logicalPlan)
               .setOffset(offset))
           .build()
       }
 
-      def where(condition: proto.Expression): proto.Relation = {
-        proto.Relation
+      def where(condition: Expression): Relation = {
+        Relation
           .newBuilder()
-          
.setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
+          
.setFilter(Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
           .build()
       }
 
-      def deduplicate(colNames: Seq[String]): proto.Relation =
-        proto.Relation
+      def deduplicate(colNames: Seq[String]): Relation =
+        Relation
           .newBuilder()
           .setDeduplicate(
-            proto.Deduplicate
+            Deduplicate
               .newBuilder()
               .setInput(logicalPlan)
               .addAllColumnNames(colNames.asJava))
           .build()
 
-      def distinct(): proto.Relation =
-        proto.Relation
+      def distinct(): Relation =
+        Relation
           .newBuilder()
           .setDeduplicate(
-            proto.Deduplicate
+            Deduplicate
               .newBuilder()
               .setInput(logicalPlan)
               .setAllColumnsAsKeys(true))
           .build()
 
       def join(
-          otherPlan: proto.Relation,
+          otherPlan: Relation,
           joinType: JoinType,
-          condition: Option[proto.Expression]): proto.Relation = {
+          condition: Option[Expression]): Relation = {
         join(otherPlan, joinType, Seq(), condition)
       }
 
-      def join(otherPlan: proto.Relation, condition: 
Option[proto.Expression]): proto.Relation = {
+      def join(otherPlan: Relation, condition: Option[Expression]): Relation = 
{
         join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition)
       }
 
-      def join(otherPlan: proto.Relation): proto.Relation = {
+      def join(otherPlan: Relation): Relation = {
         join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None)
       }
 
-      def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation 
= {
+      def join(otherPlan: Relation, joinType: JoinType): Relation = {
         join(otherPlan, joinType, Seq(), None)
       }
 
-      def join(
-          otherPlan: proto.Relation,
-          joinType: JoinType,
-          usingColumns: Seq[String]): proto.Relation = {
+      def join(otherPlan: Relation, joinType: JoinType, usingColumns: 
Seq[String]): Relation = {
         join(otherPlan, joinType, usingColumns, None)
       }
 
       private def join(
-          otherPlan: proto.Relation,
+          otherPlan: Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
           usingColumns: Seq[String],
-          condition: Option[proto.Expression]): proto.Relation = {
-        val relation = proto.Relation.newBuilder()
-        val join = proto.Join.newBuilder()
+          condition: Option[Expression]): Relation = {
+        val relation = Relation.newBuilder()
+        val join = Join.newBuilder()
         join
           .setLeft(logicalPlan)
           .setRight(otherPlan)
@@ -285,10 +280,10 @@ package object dsl {
         relation.setJoin(join).build()
       }
 
-      def as(alias: String): proto.Relation = {
-        proto.Relation
+      def as(alias: String): Relation = {
+        Relation
           .newBuilder(logicalPlan)
-          .setCommon(proto.RelationCommon.newBuilder().setAlias(alias))
+          .setCommon(RelationCommon.newBuilder().setAlias(alias))
           .build()
       }
 
@@ -296,24 +291,23 @@ package object dsl {
           lowerBound: Double,
           upperBound: Double,
           withReplacement: Boolean,
-          seed: Long): proto.Relation = {
-        proto.Relation
+          seed: Long): Relation = {
+        Relation
           .newBuilder()
           .setSample(
-            proto.Sample
+            Sample
               .newBuilder()
               .setInput(logicalPlan)
               .setUpperBound(upperBound)
               .setLowerBound(lowerBound)
               .setWithReplacement(withReplacement)
-              .setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build())
+              .setSeed(Sample.Seed.newBuilder().setSeed(seed).build())
               .build())
           .build()
       }
 
-      def groupBy(groupingExprs: proto.Expression*)(
-          aggregateExprs: proto.Expression*): proto.Relation = {
-        val agg = proto.Aggregate.newBuilder()
+      def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): 
Relation = {
+        val agg = Aggregate.newBuilder()
         agg.setInput(logicalPlan)
 
         for (groupingExpr <- groupingExprs) {
@@ -321,29 +315,26 @@ package object dsl {
         }
         // TODO: support aggregateExprs, which is blocked by supporting any 
builtin function
         // resolution only by name in the analyzer.
-        proto.Relation.newBuilder().setAggregate(agg.build()).build()
+        Relation.newBuilder().setAggregate(agg.build()).build()
       }
 
-      def except(otherPlan: proto.Relation, isAll: Boolean): proto.Relation = {
-        proto.Relation
+      def except(otherPlan: Relation, isAll: Boolean): Relation = {
+        Relation
           .newBuilder()
           .setSetOp(
             createSetOperation(logicalPlan, otherPlan, 
SetOpType.SET_OP_TYPE_EXCEPT, isAll))
           .build()
       }
 
-      def intersect(otherPlan: proto.Relation, isAll: Boolean): proto.Relation 
=
-        proto.Relation
+      def intersect(otherPlan: Relation, isAll: Boolean): Relation =
+        Relation
           .newBuilder()
           .setSetOp(
             createSetOperation(logicalPlan, otherPlan, 
SetOpType.SET_OP_TYPE_INTERSECT, isAll))
           .build()
 
-      def union(
-          otherPlan: proto.Relation,
-          isAll: Boolean = true,
-          byName: Boolean = false): proto.Relation =
-        proto.Relation
+      def union(otherPlan: Relation, isAll: Boolean = true, byName: Boolean = 
false): Relation =
+        Relation
           .newBuilder()
           .setSetOp(
             createSetOperation(
@@ -355,12 +346,12 @@ package object dsl {
           .build()
 
       private def createSetOperation(
-          left: proto.Relation,
-          right: proto.Relation,
+          left: Relation,
+          right: Relation,
           t: SetOpType,
           isAll: Boolean = true,
-          byName: Boolean = false): proto.SetOperation.Builder = {
-        val setOp = proto.SetOperation
+          byName: Boolean = false): SetOperation.Builder = {
+        val setOp = SetOperation
           .newBuilder()
           .setLeftInput(left)
           .setRightInput(right)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to