This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ec5d5471856 [SPARK-41349][CONNECT] Implement DataFrame.hint ec5d5471856 is described below commit ec5d547185645126dee87470835ea1d55936dcd0 Author: dengziming <dengzim...@bytedance.com> AuthorDate: Wed Dec 7 16:08:09 2022 +0800 [SPARK-41349][CONNECT] Implement DataFrame.hint ### What changes were proposed in this pull request? 1. Implement `DataFrame.hint` for scala API 2. Implement `DataFrame.hint` for python API ### Why are the changes needed? API coverage ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38899 from dengziming/SPARK-41349. Authored-by: dengziming <dengzim...@bytedance.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../org/apache/spark/sql/connect/dsl/package.scala | 33 +++-- .../planner/LiteralValueProtoConverter.scala | 157 +++++++++++++++++++++ .../sql/connect/planner/SparkConnectPlanner.scala | 121 ++-------------- .../planner/LiteralValueProtoConverterSuite.scala | 32 +++++ .../connect/planner/SparkConnectPlannerSuite.scala | 46 ++++++ .../connect/planner/SparkConnectProtoSuite.scala | 4 + 6 files changed, 270 insertions(+), 123 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 8b1d69e03db..ec2d0cad95b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -26,6 +26,7 @@ import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.connect.proto.SetOperation.SetOpType import org.apache.spark.sql.SaveMode import org.apache.spark.sql.connect.planner.DataTypeProtoConverter +import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue /** * A collection of implicit conversions that create a DSL for constructing connect protos. @@ -241,16 +242,6 @@ package object dsl { implicit class DslNAFunctions(val logicalPlan: Relation) { - private def convertValue(value: Any) = { - value match { - case b: Boolean => Expression.Literal.newBuilder().setBoolean(b).build() - case l: Long => Expression.Literal.newBuilder().setLong(l).build() - case d: Double => Expression.Literal.newBuilder().setDouble(d).build() - case s: String => Expression.Literal.newBuilder().setString(s).build() - case o => throw new Exception(s"Unsupported value type: $o") - } - } - def fillValue(value: Any): Relation = { Relation .newBuilder() @@ -258,7 +249,7 @@ package object dsl { proto.NAFill .newBuilder() .setInput(logicalPlan) - .addAllValues(Seq(convertValue(value)).asJava) + .addAllValues(Seq(toConnectProtoValue(value)).asJava) .build()) .build() } @@ -271,13 +262,13 @@ package object dsl { .newBuilder() .setInput(logicalPlan) .addAllCols(cols.toSeq.asJava) - .addAllValues(Seq(convertValue(value)).asJava) + .addAllValues(Seq(toConnectProtoValue(value)).asJava) .build()) .build() } def fillValueMap(valueMap: Map[String, Any]): Relation = { - val (cols, values) = valueMap.mapValues(convertValue).toSeq.unzip + val (cols, values) = valueMap.mapValues(toConnectProtoValue).toSeq.unzip Relation .newBuilder() .setFillNa( @@ -338,8 +329,8 @@ package object dsl { replace.addReplacements( proto.NAReplace.Replacement .newBuilder() - .setOldValue(convertValue(oldValue)) - .setNewValue(convertValue(newValue))) + .setOldValue(toConnectProtoValue(oldValue)) + .setNewValue(toConnectProtoValue(newValue))) } Relation @@ -694,6 +685,18 @@ package object dsl { .build() } + def hint(name: String, parameters: Any*): Relation = { + Relation + .newBuilder() + .setHint( + Hint + .newBuilder() + .setInput(logicalPlan) + .setName(name) + .addAllParameters(parameters.map(toConnectProtoValue).asJava)) + .build() + } + private def createSetOperation( left: Relation, right: Relation, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala new file mode 100644 index 00000000000..5a54ad9ac64 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import scala.collection.JavaConverters._ + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateStruct} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +object LiteralValueProtoConverter { + + /** + * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. + * + * @return + * Expression + */ + def toCatalystExpression(lit: proto.Expression.Literal): expressions.Expression = { + lit.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.NULL => + expressions.Literal(null, NullType) + + case proto.Expression.Literal.LiteralTypeCase.BINARY => + expressions.Literal(lit.getBinary.toByteArray, BinaryType) + + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => + expressions.Literal(lit.getBoolean, BooleanType) + + case proto.Expression.Literal.LiteralTypeCase.BYTE => + expressions.Literal(lit.getByte.toByte, ByteType) + + case proto.Expression.Literal.LiteralTypeCase.SHORT => + expressions.Literal(lit.getShort.toShort, ShortType) + + case proto.Expression.Literal.LiteralTypeCase.INTEGER => + expressions.Literal(lit.getInteger, IntegerType) + + case proto.Expression.Literal.LiteralTypeCase.LONG => + expressions.Literal(lit.getLong, LongType) + + case proto.Expression.Literal.LiteralTypeCase.FLOAT => + expressions.Literal(lit.getFloat, FloatType) + + case proto.Expression.Literal.LiteralTypeCase.DOUBLE => + expressions.Literal(lit.getDouble, DoubleType) + + case proto.Expression.Literal.LiteralTypeCase.DECIMAL => + val decimal = Decimal.apply(lit.getDecimal.getValue) + var precision = decimal.precision + if (lit.getDecimal.hasPrecision) { + precision = math.max(precision, lit.getDecimal.getPrecision) + } + var scale = decimal.scale + if (lit.getDecimal.hasScale) { + scale = math.max(scale, lit.getDecimal.getScale) + } + expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale)) + + case proto.Expression.Literal.LiteralTypeCase.STRING => + expressions.Literal(UTF8String.fromString(lit.getString), StringType) + + case proto.Expression.Literal.LiteralTypeCase.DATE => + expressions.Literal(lit.getDate, DateType) + + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => + expressions.Literal(lit.getTimestamp, TimestampType) + + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => + expressions.Literal(lit.getTimestampNtz, TimestampNTZType) + + case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => + val interval = new CalendarInterval( + lit.getCalendarInterval.getMonths, + lit.getCalendarInterval.getDays, + lit.getCalendarInterval.getMicroseconds) + expressions.Literal(interval, CalendarIntervalType) + + case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => + expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType()) + + case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => + expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType()) + + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + val literals = lit.getArray.getValuesList.asScala.toArray.map(toCatalystExpression) + CreateArray(literals) + + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + val literals = lit.getStruct.getFieldsList.asScala.toArray.map(toCatalystExpression) + CreateStruct(literals) + + case proto.Expression.Literal.LiteralTypeCase.MAP => + val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair => + toCatalystExpression(pair.getKey) :: toCatalystExpression(pair.getValue) :: Nil + } + CreateMap(literals) + + case _ => + throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") + } + } + + def toCatalystValue(lit: proto.Expression.Literal): Any = { + lit.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + lit.getArray.getValuesList.asScala.toArray.map(toCatalystValue) + + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + val literals = lit.getStruct.getFieldsList.asScala.map(toCatalystValue).toSeq + InternalRow(literals: _*) + + case proto.Expression.Literal.LiteralTypeCase.MAP => + lit.getMap.getPairsList.asScala.toArray.map { pair => + toCatalystValue(pair.getKey) -> toCatalystValue(pair.getValue) + }.toMap + + case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString + + case _ => toCatalystExpression(lit).asInstanceOf[expressions.Literal].value + } + } + + def toConnectProtoValue(value: Any): proto.Expression.Literal = { + value match { + case null => proto.Expression.Literal.newBuilder().setNull(true).build() + case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build() + case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build() + case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build() + case i: Int => proto.Expression.Literal.newBuilder().setInteger(i).build() + case l: Long => proto.Expression.Literal.newBuilder().setLong(l).build() + case f: Float => proto.Expression.Literal.newBuilder().setFloat(f).build() + case d: Double => proto.Expression.Literal.newBuilder().setDouble(d).build() + case s: String => proto.Expression.Literal.newBuilder().setString(s).build() + case o => throw new Exception(s"Unsupported value type: $o") + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 982f9188e1d..d8b7843fbe7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -33,15 +33,15 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union} +import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, UnresolvedHint} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils final case class InvalidPlanInput( @@ -93,6 +93,7 @@ class SparkConnectPlanner(session: SparkSession) { case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP => transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap) case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) + case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -199,17 +200,7 @@ class SparkConnectPlanner(session: SparkSession) { } else { val valueMap = mutable.Map.empty[String, Any] cols.zip(values).foreach { case (col, value) => - value.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - valueMap.update(col, value.getBoolean) - case proto.Expression.Literal.LiteralTypeCase.LONG => - valueMap.update(col, value.getLong) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - valueMap.update(col, value.getDouble) - case proto.Expression.Literal.LiteralTypeCase.STRING => - valueMap.update(col, value.getString) - case other => throw InvalidPlanInput(s"Unsupported value type: $other") - } + valueMap.update(col, toCatalystValue(value)) } dataset.na.fill(valueMap = valueMap.toMap).logicalPlan } @@ -233,19 +224,11 @@ class SparkConnectPlanner(session: SparkSession) { } private def transformReplace(rel: proto.NAReplace): LogicalPlan = { - def convert(value: proto.Expression.Literal): Any = { - value.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.NULL => null - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => value.getBoolean - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => value.getDouble - case proto.Expression.Literal.LiteralTypeCase.STRING => value.getString - case other => throw InvalidPlanInput(s"Unsupported value type: $other") - } - } - val replacement = mutable.Map.empty[Any, Any] rel.getReplacementsList.asScala.foreach { replace => - replacement.update(convert(replace.getOldValue), convert(replace.getNewValue)) + replacement.update( + toCatalystValue(replace.getOldValue), + toCatalystValue(replace.getNewValue)) } if (rel.getColsCount == 0) { @@ -313,6 +296,11 @@ class SparkConnectPlanner(session: SparkSession) { .logicalPlan } + private def transformHint(rel: proto.Hint): LogicalPlan = { + val params = rel.getParametersList.asScala.map(toCatalystValue).toSeq + UnresolvedHint(rel.getName, params, transformRelation(rel.getInput)) + } + private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { if (!rel.hasInput) { throw InvalidPlanInput("Deduplicate needs a plan input") @@ -426,90 +414,7 @@ class SparkConnectPlanner(session: SparkSession) { * Expression */ private def transformLiteral(lit: proto.Expression.Literal): Expression = { - lit.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.NULL => - expressions.Literal(null, NullType) - - case proto.Expression.Literal.LiteralTypeCase.BINARY => - expressions.Literal(lit.getBinary.toByteArray, BinaryType) - - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - expressions.Literal(lit.getBoolean, BooleanType) - - case proto.Expression.Literal.LiteralTypeCase.BYTE => - expressions.Literal(lit.getByte, ByteType) - - case proto.Expression.Literal.LiteralTypeCase.SHORT => - expressions.Literal(lit.getShort, ShortType) - - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - expressions.Literal(lit.getInteger, IntegerType) - - case proto.Expression.Literal.LiteralTypeCase.LONG => - expressions.Literal(lit.getLong, LongType) - - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - expressions.Literal(lit.getFloat, FloatType) - - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - expressions.Literal(lit.getDouble, DoubleType) - - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(lit.getDecimal.getValue) - var precision = decimal.precision - if (lit.getDecimal.hasPrecision) { - precision = math.max(precision, lit.getDecimal.getPrecision) - } - var scale = decimal.scale - if (lit.getDecimal.hasScale) { - scale = math.max(scale, lit.getDecimal.getScale) - } - expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale)) - - case proto.Expression.Literal.LiteralTypeCase.STRING => - expressions.Literal(UTF8String.fromString(lit.getString), StringType) - - case proto.Expression.Literal.LiteralTypeCase.DATE => - expressions.Literal(lit.getDate, DateType) - - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - expressions.Literal(lit.getTimestamp, TimestampType) - - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - expressions.Literal(lit.getTimestampNtz, TimestampNTZType) - - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - val interval = new CalendarInterval( - lit.getCalendarInterval.getMonths, - lit.getCalendarInterval.getDays, - lit.getCalendarInterval.getMicroseconds) - expressions.Literal(interval, CalendarIntervalType) - - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType()) - - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType()) - - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - val literals = lit.getArray.getValuesList.asScala.toArray.map(transformLiteral) - CreateArray(literals) - - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - val literals = lit.getStruct.getFieldsList.asScala.toArray.map(transformLiteral) - CreateStruct(literals) - - case proto.Expression.Literal.LiteralTypeCase.MAP => - val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair => - transformLiteral(pair.getKey) :: transformLiteral(pair.getValue) :: Nil - } - CreateMap(literals) - - case _ => - throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") - } + toCatalystExpression(lit) } private def transformLimit(limit: proto.Limit): LogicalPlan = { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala new file mode 100644 index 00000000000..dc8254c47f3 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue} + +class LiteralValueProtoConverterSuite extends AnyFunSuite { + + test("basic proto value and catalyst value conversion") { + val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark") + for (v <- values) { + assertResult(v)(toCatalystValue(toConnectProtoValue(v))) + } + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 362973a90ef..5362453da50 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -571,4 +572,49 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build())) assert(df.schema.fields.toSeq.map(_.name) == Seq("id")) } + + test("Hint") { + val input = proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery("select id from range(10)") + .build()) + + val logical = transform( + proto.Relation + .newBuilder() + .setHint( + proto.Hint + .newBuilder() + .setInput(input) + .setName("REPARTITION") + .addParameters(toConnectProtoValue(10000))) + .build()) + + val df = Dataset.ofRows(spark, logical) + assert(df.rdd.partitions.length == 10000) + } + + test("Hint with illegal name will be ignored") { + val input = proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery("select id from range(10)") + .build()) + + val logical = transform( + proto.Relation + .newBuilder() + .setHint( + proto.Hint + .newBuilder() + .setInput(input) + .setName("illegal")) + .build()) + assert(10 === Dataset.ofRows(spark, logical).count()) + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 1e4e18c3c8f..074372b6c8d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -546,6 +546,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { sparkTestRelation.select(col("id").cast(StringType))) } + test("Test Hint") { + comparePlans(connectTestRelation.hint("COALESCE", 3), sparkTestRelation.hint("COALESCE", 3)) + } + private def createLocalRelationProtoByAttributeReferences( attrs: Seq[AttributeReference]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org