This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ec5d5471856 [SPARK-41349][CONNECT] Implement DataFrame.hint
ec5d5471856 is described below

commit ec5d547185645126dee87470835ea1d55936dcd0
Author: dengziming <dengzim...@bytedance.com>
AuthorDate: Wed Dec 7 16:08:09 2022 +0800

    [SPARK-41349][CONNECT] Implement DataFrame.hint
    
    ### What changes were proposed in this pull request?
    
    1. Implement `DataFrame.hint` for the Scala API
    2. Implement `DataFrame.hint` for the Python API
    
    ### Why are the changes needed?
    
    API coverage: `DataFrame.hint` was not yet available in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests (`LiteralValueProtoConverterSuite`, `SparkConnectPlannerSuite`, `SparkConnectProtoSuite`).
    
    Closes #38899 from dengziming/SPARK-41349.
    
    Authored-by: dengziming <dengzim...@bytedance.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/connect/dsl/package.scala |  33 +++--
 .../planner/LiteralValueProtoConverter.scala       | 157 +++++++++++++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 121 ++--------------
 .../planner/LiteralValueProtoConverterSuite.scala  |  32 +++++
 .../connect/planner/SparkConnectPlannerSuite.scala |  46 ++++++
 .../connect/planner/SparkConnectProtoSuite.scala   |   4 +
 6 files changed, 270 insertions(+), 123 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 8b1d69e03db..ec2d0cad95b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -26,6 +26,7 @@ import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.connect.proto.SetOperation.SetOpType
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
+import 
org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
 
 /**
  * A collection of implicit conversions that create a DSL for constructing 
connect protos.
@@ -241,16 +242,6 @@ package object dsl {
 
     implicit class DslNAFunctions(val logicalPlan: Relation) {
 
-      private def convertValue(value: Any) = {
-        value match {
-          case b: Boolean => 
Expression.Literal.newBuilder().setBoolean(b).build()
-          case l: Long => Expression.Literal.newBuilder().setLong(l).build()
-          case d: Double => 
Expression.Literal.newBuilder().setDouble(d).build()
-          case s: String => 
Expression.Literal.newBuilder().setString(s).build()
-          case o => throw new Exception(s"Unsupported value type: $o")
-        }
-      }
-
       def fillValue(value: Any): Relation = {
         Relation
           .newBuilder()
@@ -258,7 +249,7 @@ package object dsl {
             proto.NAFill
               .newBuilder()
               .setInput(logicalPlan)
-              .addAllValues(Seq(convertValue(value)).asJava)
+              .addAllValues(Seq(toConnectProtoValue(value)).asJava)
               .build())
           .build()
       }
@@ -271,13 +262,13 @@ package object dsl {
               .newBuilder()
               .setInput(logicalPlan)
               .addAllCols(cols.toSeq.asJava)
-              .addAllValues(Seq(convertValue(value)).asJava)
+              .addAllValues(Seq(toConnectProtoValue(value)).asJava)
               .build())
           .build()
       }
 
       def fillValueMap(valueMap: Map[String, Any]): Relation = {
-        val (cols, values) = valueMap.mapValues(convertValue).toSeq.unzip
+        val (cols, values) = 
valueMap.mapValues(toConnectProtoValue).toSeq.unzip
         Relation
           .newBuilder()
           .setFillNa(
@@ -338,8 +329,8 @@ package object dsl {
           replace.addReplacements(
             proto.NAReplace.Replacement
               .newBuilder()
-              .setOldValue(convertValue(oldValue))
-              .setNewValue(convertValue(newValue)))
+              .setOldValue(toConnectProtoValue(oldValue))
+              .setNewValue(toConnectProtoValue(newValue)))
         }
 
         Relation
@@ -694,6 +685,18 @@ package object dsl {
           .build()
       }
 
+      /** Wraps this relation in a `Hint`; parameters are converted by `toConnectProtoValue`. */
+      def hint(name: String, parameters: Any*): Relation = {
+        val relationBuilder = Relation.newBuilder()
+        val hintBuilder = Hint
+          .newBuilder()
+          .setInput(logicalPlan)
+          .setName(name)
+          .addAllParameters(parameters.map(toConnectProtoValue).asJava)
+        relationBuilder.setHint(hintBuilder)
+        relationBuilder.build()
+      }
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
new file mode 100644
index 00000000000..5a54ad9ac64
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateStruct}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+/**
+ * Conversions between Spark Connect protocol buffers literals (`proto.Expression.Literal`)
+ * and Catalyst literal expressions or plain Scala values.
+ */
+object LiteralValueProtoConverter {
+
+  /**
+   * Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
+   *
+   * Array, struct and map literals are converted recursively into `CreateArray`,
+   * `CreateStruct` and `CreateMap` expressions rather than into a single `Literal`.
+   *
+   * @param lit
+   *   the protocol buffers literal to convert
+   * @return
+   *   Expression
+   * @throws InvalidPlanInput
+   *   if the literal type is not supported
+   */
+  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Expression = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL =>
+        expressions.Literal(null, NullType)
+
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        expressions.Literal(lit.getBinary.toByteArray, BinaryType)
+
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        expressions.Literal(lit.getBoolean, BooleanType)
+
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        expressions.Literal(lit.getByte.toByte, ByteType)
+
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        expressions.Literal(lit.getShort.toShort, ShortType)
+
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        expressions.Literal(lit.getInteger, IntegerType)
+
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        expressions.Literal(lit.getLong, LongType)
+
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        expressions.Literal(lit.getFloat, FloatType)
+
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        expressions.Literal(lit.getDouble, DoubleType)
+
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        val decimal = Decimal.apply(lit.getDecimal.getValue)
+        var precision = decimal.precision
+        if (lit.getDecimal.hasPrecision) {
+          precision = math.max(precision, lit.getDecimal.getPrecision)
+        }
+        var scale = decimal.scale
+        if (lit.getDecimal.hasScale) {
+          scale = math.max(scale, lit.getDecimal.getScale)
+        }
+        expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
+
+      case proto.Expression.Literal.LiteralTypeCase.STRING =>
+        expressions.Literal(UTF8String.fromString(lit.getString), StringType)
+
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        expressions.Literal(lit.getDate, DateType)
+
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        expressions.Literal(lit.getTimestamp, TimestampType)
+
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
+
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        val interval = new CalendarInterval(
+          lit.getCalendarInterval.getMonths,
+          lit.getCalendarInterval.getDays,
+          lit.getCalendarInterval.getMicroseconds)
+        expressions.Literal(interval, CalendarIntervalType)
+
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
+
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        val literals = lit.getArray.getValuesList.asScala.toArray.map(toCatalystExpression)
+        CreateArray(literals)
+
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        val literals = lit.getStruct.getFieldsList.asScala.toArray.map(toCatalystExpression)
+        CreateStruct(literals)
+
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair =>
+          toCatalystExpression(pair.getKey) :: toCatalystExpression(pair.getValue) :: Nil
+        }
+        CreateMap(literals)
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+            s"(${lit.getLiteralTypeCase.name})")
+    }
+  }
+
+  /**
+   * Converts a protocol buffers literal into a plain Scala value, e.g. for the parameters of
+   * hints or of `na.fill` / `na.replace`.
+   *
+   * Strings are returned as `java.lang.String` rather than `UTF8String`, arrays as `Array[Any]`,
+   * structs as an `InternalRow` and maps as an immutable `Map`; nested values are converted
+   * recursively. Any other literal yields the value of the Catalyst literal built by
+   * [[toCatalystExpression]].
+   *
+   * @param lit
+   *   the protocol buffers literal to convert
+   * @return
+   *   the converted value (`null` for a NULL literal)
+   * @throws InvalidPlanInput
+   *   if the literal type is not supported
+   */
+  def toCatalystValue(lit: proto.Expression.Literal): Any = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        lit.getArray.getValuesList.asScala.toArray.map(toCatalystValue)
+
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        val literals = lit.getStruct.getFieldsList.asScala.map(toCatalystValue).toSeq
+        InternalRow(literals: _*)
+
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        lit.getMap.getPairsList.asScala.toArray.map { pair =>
+          toCatalystValue(pair.getKey) -> toCatalystValue(pair.getValue)
+        }.toMap
+
+      // Keep the JVM String rather than the UTF8String held by the Catalyst literal; the
+      // planner previously passed `getString` directly to `na.fill` / `na.replace`.
+      case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString
+
+      // ARRAY, STRUCT and MAP are handled above, so any remaining supported type is a Literal.
+      case _ => toCatalystExpression(lit).asInstanceOf[expressions.Literal].value
+    }
+  }
+
+  /**
+   * Converts a Scala value into a protocol buffers literal.
+   *
+   * Supported values are `null`, `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`
+   * and `String`.
+   *
+   * @param value
+   *   the value to convert
+   * @return
+   *   the corresponding literal
+   * @throws IllegalArgumentException
+   *   if the value's type is not supported
+   */
+  def toConnectProtoValue(value: Any): proto.Expression.Literal = {
+    value match {
+      case null => proto.Expression.Literal.newBuilder().setNull(true).build()
+      case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build()
+      case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build()
+      case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build()
+      case i: Int => proto.Expression.Literal.newBuilder().setInteger(i).build()
+      case l: Long => proto.Expression.Literal.newBuilder().setLong(l).build()
+      case f: Float => proto.Expression.Literal.newBuilder().setFloat(f).build()
+      case d: Double => proto.Expression.Literal.newBuilder().setDouble(d).build()
+      case s: String => proto.Expression.Literal.newBuilder().setString(s).build()
+      case other =>
+        throw new IllegalArgumentException(
+          s"Unsupported value type: ${other.getClass.getName} ($other)")
+    }
+  }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 982f9188e1d..d8b7843fbe7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -33,15 +33,15 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, 
JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
-import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, 
Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, 
Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, 
UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import 
org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression,
 toCatalystValue}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 final case class InvalidPlanInput(
@@ -93,6 +93,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
         
transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap)
       case proto.Relation.RelTypeCase.WITH_COLUMNS => 
transformWithColumns(rel.getWithColumns)
+      case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but 
is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -199,17 +200,7 @@ class SparkConnectPlanner(session: SparkSession) {
     } else {
       val valueMap = mutable.Map.empty[String, Any]
       cols.zip(values).foreach { case (col, value) =>
-        value.getLiteralTypeCase match {
-          case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-            valueMap.update(col, value.getBoolean)
-          case proto.Expression.Literal.LiteralTypeCase.LONG =>
-            valueMap.update(col, value.getLong)
-          case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-            valueMap.update(col, value.getDouble)
-          case proto.Expression.Literal.LiteralTypeCase.STRING =>
-            valueMap.update(col, value.getString)
-          case other => throw InvalidPlanInput(s"Unsupported value type: 
$other")
-        }
+        valueMap.update(col, toCatalystValue(value))
       }
       dataset.na.fill(valueMap = valueMap.toMap).logicalPlan
     }
@@ -233,19 +224,11 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformReplace(rel: proto.NAReplace): LogicalPlan = {
-    def convert(value: proto.Expression.Literal): Any = {
-      value.getLiteralTypeCase match {
-        case proto.Expression.Literal.LiteralTypeCase.NULL => null
-        case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => 
value.getBoolean
-        case proto.Expression.Literal.LiteralTypeCase.DOUBLE => value.getDouble
-        case proto.Expression.Literal.LiteralTypeCase.STRING => value.getString
-        case other => throw InvalidPlanInput(s"Unsupported value type: $other")
-      }
-    }
-
     val replacement = mutable.Map.empty[Any, Any]
     rel.getReplacementsList.asScala.foreach { replace =>
-      replacement.update(convert(replace.getOldValue), 
convert(replace.getNewValue))
+      replacement.update(
+        toCatalystValue(replace.getOldValue),
+        toCatalystValue(replace.getNewValue))
     }
 
     if (rel.getColsCount == 0) {
@@ -313,6 +296,11 @@ class SparkConnectPlanner(session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformHint(rel: proto.Hint): LogicalPlan = {
+    val hintParameters: Seq[Any] = rel.getParametersList.asScala.map(toCatalystValue).toSeq
+    UnresolvedHint(rel.getName, hintParameters, transformRelation(rel.getInput))
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -426,90 +414,7 @@ class SparkConnectPlanner(session: SparkSession) {
    *   Expression
    */
   private def transformLiteral(lit: proto.Expression.Literal): Expression = {
-    lit.getLiteralTypeCase match {
-      case proto.Expression.Literal.LiteralTypeCase.NULL =>
-        expressions.Literal(null, NullType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
-        expressions.Literal(lit.getBinary.toByteArray, BinaryType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-        expressions.Literal(lit.getBoolean, BooleanType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
-        expressions.Literal(lit.getByte, ByteType)
-
-      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
-        expressions.Literal(lit.getShort, ShortType)
-
-      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
-        expressions.Literal(lit.getInteger, IntegerType)
-
-      case proto.Expression.Literal.LiteralTypeCase.LONG =>
-        expressions.Literal(lit.getLong, LongType)
-
-      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
-        expressions.Literal(lit.getFloat, FloatType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-        expressions.Literal(lit.getDouble, DoubleType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
-        val decimal = Decimal.apply(lit.getDecimal.getValue)
-        var precision = decimal.precision
-        if (lit.getDecimal.hasPrecision) {
-          precision = math.max(precision, lit.getDecimal.getPrecision)
-        }
-        var scale = decimal.scale
-        if (lit.getDecimal.hasScale) {
-          scale = math.max(scale, lit.getDecimal.getScale)
-        }
-        expressions.Literal(decimal, DecimalType(math.max(precision, scale), 
scale))
-
-      case proto.Expression.Literal.LiteralTypeCase.STRING =>
-        expressions.Literal(UTF8String.fromString(lit.getString), StringType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DATE =>
-        expressions.Literal(lit.getDate, DateType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
-        expressions.Literal(lit.getTimestamp, TimestampType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
-        expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
-
-      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
-        val interval = new CalendarInterval(
-          lit.getCalendarInterval.getMonths,
-          lit.getCalendarInterval.getDays,
-          lit.getCalendarInterval.getMicroseconds)
-        expressions.Literal(interval, CalendarIntervalType)
-
-      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
-        expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
-
-      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
-        expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
-
-      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-        val literals = 
lit.getArray.getValuesList.asScala.toArray.map(transformLiteral)
-        CreateArray(literals)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val literals = 
lit.getStruct.getFieldsList.asScala.toArray.map(transformLiteral)
-        CreateStruct(literals)
-
-      case proto.Expression.Literal.LiteralTypeCase.MAP =>
-        val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair 
=>
-          transformLiteral(pair.getKey) :: transformLiteral(pair.getValue) :: 
Nil
-        }
-        CreateMap(literals)
-
-      case _ =>
-        throw InvalidPlanInput(
-          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
-            s"(${lit.getLiteralTypeCase.name})")
-    }
+    toCatalystExpression(lit)
   }
 
   private def transformLimit(limit: proto.Limit): LogicalPlan = {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
new file mode 100644
index 00000000000..dc8254c47f3
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue}
+
+class LiteralValueProtoConverterSuite extends AnyFunSuite {
+
+  test("basic proto value and catalyst value conversion") {
+    val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark")
+    for (v <- values) {
+      val result = toCatalystValue(toConnectProtoValue(v))
+      assertResult(v)(result)
+      // `==` on boxed numbers is cooperative (1.toByte == 1), so also check the runtime class.
+      if (v != null) {
+        assertResult(v.getClass)(result.getClass)
+      }
+    }
+  }
+
+  test("unsupported value types are rejected") {
+    intercept[Exception] {
+      toConnectProtoValue(Seq(1, 2, 3))
+    }
+  }
+}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 362973a90ef..5362453da50 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical
+import 
org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
@@ -571,4 +572,49 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
       Dataset.ofRows(spark, 
transform(proto.Relation.newBuilder.setProject(project).build()))
     assert(df.schema.fields.toSeq.map(_.name) == Seq("id"))
   }
+
+  // Builds a REPARTITION hint with one literal parameter over a SQL relation and checks
+  // that the resulting Dataset honors it.
+  test("Hint") {
+    val sqlInput = proto.Relation
+      .newBuilder()
+      .setSql(proto.SQL.newBuilder().setQuery("select id from range(10)").build())
+
+    val repartitionHint = proto.Hint
+      .newBuilder()
+      .setInput(sqlInput)
+      .setName("REPARTITION")
+      .addParameters(toConnectProtoValue(10000))
+    val plan = transform(
+      proto.Relation
+        .newBuilder()
+        .setHint(repartitionHint)
+        .build())
+
+    // `range(10)` is expected to be repartitioned into exactly 10000 partitions.
+    val partitionCount = Dataset.ofRows(spark, plan).rdd.partitions.length
+    assert(partitionCount == 10000)
+  }
+
+  // Builds a hint with an unrecognized name and checks that the query still runs and
+  // returns every row of its input.
+  test("Hint with illegal name will be ignored") {
+    val sqlInput = proto.Relation
+      .newBuilder()
+      .setSql(proto.SQL.newBuilder().setQuery("select id from range(10)").build())
+
+    val unknownHint = proto.Hint
+      .newBuilder()
+      .setInput(sqlInput)
+      .setName("illegal")
+    val plan = transform(
+      proto.Relation
+        .newBuilder()
+        .setHint(unknownHint)
+        .build())
+
+    // No rows may be lost: `range(10)` still yields 10 rows.
+    val rowCount = Dataset.ofRows(spark, plan).count()
+    assert(10 === rowCount)
+  }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 1e4e18c3c8f..074372b6c8d 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -546,6 +546,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.select(col("id").cast(StringType)))
   }
 
+  test("Test Hint") {
+    val connectPlan = connectTestRelation.hint("COALESCE", 3)
+    comparePlans(connectPlan, sparkTestRelation.hint("COALESCE", 3))
+  }
+
   private def createLocalRelationProtoByAttributeReferences(
       attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to