This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0def3de6ed1 [SPARK-42579][CONNECT] Part-1: `function.lit` support `Array[_]` dataType 0def3de6ed1 is described below commit 0def3de6ed1000efe72c8bbdd3b3804bb34ce620 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Mon Mar 6 19:26:39 2023 -0400 [SPARK-42579][CONNECT] Part-1: `function.lit` support `Array[_]` dataType ### What changes were proposed in this pull request? This is the first part of SPARK-42579, the pr aims to support `Array[_]` data type for `function.lit`. ### Why are the changes needed? Make `function.lit` support nested dataType ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Add new test - Manually checked Scala 2.13 test ``` build/sbt "connect-client-jvm/test" -Phive -Pscala-2.13 build/sbt "connect/test" -Phive -Pscala-2.13 ``` Closes #40218 from LuciferYang/SPARK-42579. Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../sql/expressions/LiteralProtoConverter.scala | 145 +++++++ .../scala/org/apache/spark/sql/functions.scala | 65 +-- .../apache/spark/sql/PlanGenerationTestSuite.scala | 37 ++ .../main/protobuf/spark/connect/expressions.proto | 6 + .../explain-results/function_lit_array.explain | 2 + .../query-tests/queries/function_lit_array.json | 461 +++++++++++++++++++++ .../queries/function_lit_array.proto.bin | Bin 0 -> 885 bytes .../planner/LiteralValueProtoConverter.scala | 65 +++ .../pyspark/sql/connect/proto/expressions_pb2.py | 79 ++-- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 38 ++ 10 files changed, 805 insertions(+), 93 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala new file mode 100644 index 00000000000..b3b9f53e7bb --- /dev/null +++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.expressions + +import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar, Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short => JShort} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time._ + +import com.google.protobuf.ByteString + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.connect.client.unsupported +import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +object LiteralProtoConverter { + + private lazy val nullType = + proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build() + + /** + * Transforms literal value to the `proto.Expression.Literal.Builder`. 
+ * + * @return + * proto.Expression.Literal.Builder + */ + @scala.annotation.tailrec + def toLiteralProtoBuilder(literal: Any): proto.Expression.Literal.Builder = { + val builder = proto.Expression.Literal.newBuilder() + + def decimalBuilder(precision: Int, scale: Int, value: String) = { + builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) + } + + def calendarIntervalBuilder(months: Int, days: Int, microseconds: Long) = { + builder.getCalendarIntervalBuilder + .setMonths(months) + .setDays(days) + .setMicroseconds(microseconds) + } + + def arrayBuilder(array: Array[_]) = { + val ab = builder.getArrayBuilder + .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) + array.foreach(x => ab.addElement(toLiteralProto(x))) + ab + } + + literal match { + case v: Boolean => builder.setBoolean(v) + case v: Byte => builder.setByte(v) + case v: Short => builder.setShort(v) + case v: Int => builder.setInteger(v) + case v: Long => builder.setLong(v) + case v: Float => builder.setFloat(v) + case v: Double => builder.setDouble(v) + case v: BigDecimal => + builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) + case v: JBigDecimal => + builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) + case v: String => builder.setString(v) + case v: Char => builder.setString(v.toString) + case v: Array[Char] => builder.setString(String.valueOf(v)) + case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) + case v: collection.mutable.WrappedArray[_] => toLiteralProtoBuilder(v.array) + case v: LocalDate => builder.setDate(v.toEpochDay.toInt) + case v: Decimal => + builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) + case v: Instant => builder.setTimestamp(DateTimeUtils.instantToMicros(v)) + case v: Timestamp => builder.setTimestamp(DateTimeUtils.fromJavaTimestamp(v)) + case v: LocalDateTime => builder.setTimestampNtz(DateTimeUtils.localDateTimeToMicros(v)) + 
case v: Date => builder.setDate(DateTimeUtils.fromJavaDate(v)) + case v: Duration => builder.setDayTimeInterval(IntervalUtils.durationToMicros(v)) + case v: Period => builder.setYearMonthInterval(IntervalUtils.periodToMonths(v)) + case v: Array[_] => builder.setArray(arrayBuilder(v)) + case v: CalendarInterval => + builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) + case null => builder.setNull(nullType) + case _ => unsupported(s"literal $literal not supported (yet).") + } + } + + /** + * Transforms literal value to the `proto.Expression.Literal`. + * + * @return + * proto.Expression.Literal + */ + private def toLiteralProto(literal: Any): proto.Expression.Literal = + toLiteralProtoBuilder(literal).build() + + private def toDataType(clz: Class[_]): DataType = clz match { + // primitive types + case JShort.TYPE => ShortType + case JInteger.TYPE => IntegerType + case JLong.TYPE => LongType + case JDouble.TYPE => DoubleType + case JByte.TYPE => ByteType + case JFloat.TYPE => FloatType + case JBoolean.TYPE => BooleanType + case JChar.TYPE => StringType + + // java classes + case _ if clz == classOf[LocalDate] || clz == classOf[Date] => DateType + case _ if clz == classOf[Instant] || clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[LocalDateTime] => TimestampNTZType + case _ if clz == classOf[Duration] => DayTimeIntervalType.DEFAULT + case _ if clz == classOf[Period] => YearMonthIntervalType.DEFAULT + case _ if clz == classOf[JBigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[Array[Char]] => StringType + case _ if clz == classOf[JShort] => ShortType + case _ if clz == classOf[JInteger] => IntegerType + case _ if clz == classOf[JLong] => LongType + case _ if clz == classOf[JDouble] => DoubleType + case _ if clz == classOf[JByte] => ByteType + case _ if clz == classOf[JFloat] => FloatType + case _ if clz == classOf[JBoolean] => BooleanType + + 
// other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] || clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + case _ if clz.isArray => ArrayType(toDataType(clz.getComponentType)) + case _ => + throw new UnsupportedOperationException(s"Unsupported component type $clz in arrays.") + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 76a27686bfd..8ce90886e0f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -16,24 +16,17 @@ */ package org.apache.spark.sql -import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.util.Collections import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.{typeTag, TypeTag} -import com.google.protobuf.ByteString - import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} -import org.apache.spark.sql.connect.client.unsupported import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.expressions.LiteralProtoConverter._ +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.types.DataType.parseTypeWithFallback -import org.apache.spark.unsafe.types.CalendarInterval /** * Commonly used functions available for DataFrame operations. 
Using functions defined here @@ -93,32 +86,10 @@ object functions { */ def column(colName: String): Column = col(colName) - private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column { - builder => - val literalBuilder = proto.Expression.Literal.newBuilder() - f(literalBuilder) - builder.setLiteral(literalBuilder) + private def createLiteral(literalBuilder: proto.Expression.Literal.Builder): Column = Column { + builder => builder.setLiteral(literalBuilder) } - private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column = - createLiteral { builder => - builder.getDecimalBuilder - .setPrecision(precision) - .setScale(scale) - .setValue(value) - } - - private def createCalendarIntervalLiteral(months: Int, days: Int, microseconds: Long): Column = - createLiteral { builder => - builder.getCalendarIntervalBuilder - .setMonths(months) - .setDays(days) - .setMicroseconds(microseconds) - } - - private val nullType = - proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build() - /** * Creates a [[Column]] of literal value. 
* @@ -128,37 +99,11 @@ object functions { * * @since 3.4.0 */ - @scala.annotation.tailrec def lit(literal: Any): Column = { literal match { case c: Column => c case s: Symbol => Column(s.name) - case v: Boolean => createLiteral(_.setBoolean(v)) - case v: Byte => createLiteral(_.setByte(v)) - case v: Short => createLiteral(_.setShort(v)) - case v: Int => createLiteral(_.setInteger(v)) - case v: Long => createLiteral(_.setLong(v)) - case v: Float => createLiteral(_.setFloat(v)) - case v: Double => createLiteral(_.setDouble(v)) - case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString) - case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString) - case v: String => createLiteral(_.setString(v)) - case v: Char => createLiteral(_.setString(v.toString)) - case v: Array[Char] => createLiteral(_.setString(String.valueOf(v))) - case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v))) - case v: collection.mutable.WrappedArray[_] => lit(v.array) - case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt)) - case v: Decimal => createDecimalLiteral(Math.max(v.precision, v.scale), v.scale, v.toString) - case v: Instant => createLiteral(_.setTimestamp(DateTimeUtils.instantToMicros(v))) - case v: Timestamp => createLiteral(_.setTimestamp(DateTimeUtils.fromJavaTimestamp(v))) - case v: LocalDateTime => - createLiteral(_.setTimestampNtz(DateTimeUtils.localDateTimeToMicros(v))) - case v: Date => createLiteral(_.setDate(DateTimeUtils.fromJavaDate(v))) - case v: Duration => createLiteral(_.setDayTimeInterval(IntervalUtils.durationToMicros(v))) - case v: Period => createLiteral(_.setYearMonthInterval(IntervalUtils.periodToMonths(v))) - case v: CalendarInterval => createCalendarIntervalLiteral(v.months, v.days, v.microseconds) - case null => createLiteral(_.setNull(nullType)) - case _ => unsupported(s"literal $literal not supported (yet).") + case _ => createLiteral(toLiteralProtoBuilder(literal)) } } 
////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index f5ffaf9b73a..85523a22d2b 100755 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2012,6 +2012,43 @@ class PlanGenerationTestSuite fn.lit(new CalendarInterval(2, 20, 100L))) } + test("function lit array") { + simple.select( + fn.lit(Array.emptyDoubleArray), + fn.lit(Array(Array(1), Array(2), Array(3))), + fn.lit(Array(Array(Array(1)), Array(Array(2)), Array(Array(3)))), + fn.lit(Array(true, false)), + fn.lit(Array(67.toByte, 68.toByte, 69.toByte)), + fn.lit(Array(9872.toShort, 9873.toShort, 9874.toShort)), + fn.lit(Array(-8726532, 8726532, -8726533)), + fn.lit(Array(7834609328726531L, 7834609328726532L, 7834609328726533L)), + fn.lit(Array(Math.E, 1.toDouble, 2.toDouble)), + fn.lit(Array(-0.8f, -0.7f, -0.9f)), + fn.lit(Array(BigDecimal(8997620, 5), BigDecimal(8997621, 5))), + fn.lit( + Array(BigDecimal(898897667231L, 7).bigDecimal, BigDecimal(898897667231L, 7).bigDecimal)), + fn.lit(Array("connect!", "disconnect!")), + fn.lit(Array('T', 'F')), + fn.lit( + Array( + Array.tabulate(10)(i => ('A' + i).toChar), + Array.tabulate(10)(i => ('B' + i).toChar))), + fn.lit(Array(java.time.LocalDate.of(2020, 10, 10), java.time.LocalDate.of(2020, 10, 11))), + fn.lit( + Array( + java.time.Instant.ofEpochMilli(1677155519808L), + java.time.Instant.ofEpochMilli(1677155519809L))), + fn.lit(Array(new java.sql.Timestamp(12345L), new java.sql.Timestamp(23456L))), + fn.lit( + Array( + java.time.LocalDateTime.of(2023, 2, 23, 20, 36), + java.time.LocalDateTime.of(2023, 2, 23, 21, 36))), + fn.lit(Array(java.sql.Date.valueOf("2023-02-23"), 
java.sql.Date.valueOf("2023-03-01"))), + fn.lit(Array(java.time.Duration.ofSeconds(100L), java.time.Duration.ofSeconds(200L))), + fn.lit(Array(java.time.Period.ofDays(100), java.time.Period.ofDays(200))), + fn.lit(Array(new CalendarInterval(2, 20, 100L), new CalendarInterval(2, 21, 200L)))) + } + /* Window API */ test("window") { simple.select( diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index e37a13ee959..6eb769ad27e 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -172,6 +172,7 @@ message Expression { CalendarInterval calendar_interval = 19; int32 year_month_interval = 20; int64 day_time_interval = 21; + Array array = 22; } message Decimal { @@ -189,6 +190,11 @@ message Expression { int32 days = 2; int64 microseconds = 3; } + + message Array { + DataType elementType = 1; + repeated Literal element = 2; + } } // An unresolved attribute that is not explicitly bound to a specific column, but the column diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain new file mode 100644 index 00000000000..74d512b6910 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain @@ -0,0 +1,2 @@ +Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 
7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0, [...] ++- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json new file mode 100644 index 00000000000..c9441c9e77c --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -0,0 +1,461 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "expressions": [{ + "literal": { + "array": { + "elementType": { + "double": { + } + } + } + } + }, { + "literal": { + "array": { + "elementType": { + "array": { + "elementType": { + "integer": { + } + }, + "containsNull": true + } + }, + "element": [{ + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 1 + }] + } + }, { + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 2 + }] + } + }, { + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 3 + }] + } + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "array": { + "elementType": { + "array": { + "elementType": { + "integer": { + } + }, + "containsNull": true + } + }, + "containsNull": true + } + }, + "element": [{ + "array": { + "elementType": { + "array": { + "elementType": { + "integer": { + } + }, + "containsNull": true + } + }, + "element": [{ + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 1 + }] + } + }] + } + }, { + "array": { + "elementType": { + "array": { + "elementType": { + "integer": { + } + }, + "containsNull": true + } + }, + "element": [{ + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 2 + }] + } + }] + } + }, { + "array": { + "elementType": { + 
"array": { + "elementType": { + "integer": { + } + }, + "containsNull": true + } + }, + "element": [{ + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": 3 + }] + } + }] + } + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "boolean": { + } + }, + "element": [{ + "boolean": true + }, { + "boolean": false + }] + } + } + }, { + "literal": { + "binary": "Q0RF" + } + }, { + "literal": { + "array": { + "elementType": { + "short": { + } + }, + "element": [{ + "short": 9872 + }, { + "short": 9873 + }, { + "short": 9874 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "integer": { + } + }, + "element": [{ + "integer": -8726532 + }, { + "integer": 8726532 + }, { + "integer": -8726533 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "long": { + } + }, + "element": [{ + "long": "7834609328726531" + }, { + "long": "7834609328726532" + }, { + "long": "7834609328726533" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "double": { + } + }, + "element": [{ + "double": 2.718281828459045 + }, { + "double": 1.0 + }, { + "double": 2.0 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "float": { + } + }, + "element": [{ + "float": -0.8 + }, { + "float": -0.7 + }, { + "float": -0.9 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "decimal": { + "scale": 18, + "precision": 38 + } + }, + "element": [{ + "decimal": { + "value": "89.97620", + "precision": 7, + "scale": 5 + } + }, { + "decimal": { + "value": "89.97621", + "precision": 7, + "scale": 5 + } + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "decimal": { + "scale": 18, + "precision": 38 + } + }, + "element": [{ + "decimal": { + "value": "89889.7667231", + "precision": 12, + "scale": 7 + } + }, { + "decimal": { + "value": "89889.7667231", + "precision": 12, + "scale": 7 + } + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "string": { + } + }, + 
"element": [{ + "string": "connect!" + }, { + "string": "disconnect!" + }] + } + } + }, { + "literal": { + "string": "TF" + } + }, { + "literal": { + "array": { + "elementType": { + "string": { + } + }, + "element": [{ + "string": "ABCDEFGHIJ" + }, { + "string": "BCDEFGHIJK" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "date": { + } + }, + "element": [{ + "date": 18545 + }, { + "date": 18546 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "timestamp": { + } + }, + "element": [{ + "timestamp": "1677155519808000" + }, { + "timestamp": "1677155519809000" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "timestamp": { + } + }, + "element": [{ + "timestamp": "12345000" + }, { + "timestamp": "23456000" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "timestampNtz": { + } + }, + "element": [{ + "timestampNtz": "1677184560000000" + }, { + "timestampNtz": "1677188160000000" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "date": { + } + }, + "element": [{ + "date": 19411 + }, { + "date": 19417 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "dayTimeInterval": { + "startField": 0, + "endField": 3 + } + }, + "element": [{ + "dayTimeInterval": "100000000" + }, { + "dayTimeInterval": "200000000" + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "yearMonthInterval": { + "startField": 0, + "endField": 1 + } + }, + "element": [{ + "yearMonthInterval": 0 + }, { + "yearMonthInterval": 0 + }] + } + } + }, { + "literal": { + "array": { + "elementType": { + "calendarInterval": { + } + }, + "element": [{ + "calendarInterval": { + "months": 2, + "days": 20, + "microseconds": "100" + } + }, { + "calendarInterval": { + "months": 2, + "days": 21, + "microseconds": "200" + } + }] + } + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin 
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin new file mode 100644 index 00000000000..9763bed6b50 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala index 6ddaabb1b88..79c489b9f5b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.connect.planner +import scala.collection.mutable +import scala.reflect.ClassTag + import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -97,6 +101,10 @@ object LiteralValueProtoConverter { case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + expressions.Literal.create( + toArrayData(lit.getArray), + ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType))) case _ => throw InvalidPlanInput( s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + @@ -130,4 +138,61 @@ object LiteralValueProtoConverter { case o => throw new Exception(s"Unsupported value type: $o") } } + + private def toArrayData(array: proto.Expression.Literal.Array): Any = { + def makeArrayData[T](converter: 
proto.Expression.Literal => T)(implicit + tag: ClassTag[T]): Array[T] = { + val builder = mutable.ArrayBuilder.make[T] + val elementList = array.getElementList + builder.sizeHint(elementList.size()) + val iter = elementList.iterator() + while (iter.hasNext) { + builder += converter(iter.next()) + } + builder.result() + } + + val elementType = array.getElementType + if (elementType.hasShort) { + makeArrayData(v => v.getShort.toShort) + } else if (elementType.hasInteger) { + makeArrayData(v => v.getInteger) + } else if (elementType.hasLong) { + makeArrayData(v => v.getLong) + } else if (elementType.hasDouble) { + makeArrayData(v => v.getDouble) + } else if (elementType.hasByte) { + makeArrayData(v => v.getByte.toByte) + } else if (elementType.hasFloat) { + makeArrayData(v => v.getFloat) + } else if (elementType.hasBoolean) { + makeArrayData(v => v.getBoolean) + } else if (elementType.hasString) { + makeArrayData(v => v.getString) + } else if (elementType.hasBinary) { + makeArrayData(v => v.getBinary.toByteArray) + } else if (elementType.hasDate) { + makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate)) + } else if (elementType.hasTimestamp) { + makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp)) + } else if (elementType.hasTimestampNtz) { + makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz)) + } else if (elementType.hasDayTimeInterval) { + makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval)) + } else if (elementType.hasYearMonthInterval) { + makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval)) + } else if (elementType.hasDecimal) { + makeArrayData(v => Decimal(v.getDecimal.getValue)) + } else if (elementType.hasCalendarInterval) { + makeArrayData(v => { + val interval = v.getCalendarInterval + new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) + }) + } else if (elementType.hasArray) { + makeArrayData(v => toArrayData(v.getArray)) + } else { + throw 
InvalidPlanInput(s"Unsupported Literal Type: $elementType)") + } + } + } diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 6e515235c7d..d0db2ad56cc 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe6%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...] 
) @@ -49,6 +49,7 @@ _EXPRESSION_CAST = _EXPRESSION.nested_types_by_name["Cast"] _EXPRESSION_LITERAL = _EXPRESSION.nested_types_by_name["Literal"] _EXPRESSION_LITERAL_DECIMAL = _EXPRESSION_LITERAL.nested_types_by_name["Decimal"] _EXPRESSION_LITERAL_CALENDARINTERVAL = _EXPRESSION_LITERAL.nested_types_by_name["CalendarInterval"] +_EXPRESSION_LITERAL_ARRAY = _EXPRESSION_LITERAL.nested_types_by_name["Array"] _EXPRESSION_UNRESOLVEDATTRIBUTE = _EXPRESSION.nested_types_by_name["UnresolvedAttribute"] _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFunction"] _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"] @@ -142,6 +143,15 @@ Expression = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.CalendarInterval) }, ), + "Array": _reflection.GeneratedProtocolMessageType( + "Array", + (_message.Message,), + { + "DESCRIPTOR": _EXPRESSION_LITERAL_ARRAY, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Array) + }, + ), "DESCRIPTOR": _EXPRESSION_LITERAL, "__module__": "spark.connect.expressions_pb2" # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal) @@ -251,6 +261,7 @@ _sym_db.RegisterMessage(Expression.Cast) _sym_db.RegisterMessage(Expression.Literal) _sym_db.RegisterMessage(Expression.Literal.Decimal) _sym_db.RegisterMessage(Expression.Literal.CalendarInterval) +_sym_db.RegisterMessage(Expression.Literal.Array) _sym_db.RegisterMessage(Expression.UnresolvedAttribute) _sym_db.RegisterMessage(Expression.UnresolvedFunction) _sym_db.RegisterMessage(Expression.ExpressionString) @@ -300,7 +311,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 4943 + _EXPRESSION._serialized_end = 5137 
_EXPRESSION_WINDOW._serialized_start = 1475 _EXPRESSION_WINDOW._serialized_end = 2258 _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765 @@ -318,35 +329,37 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_CAST._serialized_start = 2689 _EXPRESSION_CAST._serialized_end = 2834 _EXPRESSION_LITERAL._serialized_start = 2837 - _EXPRESSION_LITERAL._serialized_end = 3713 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3480 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3597 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3827 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3830 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4034 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4036 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4086 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4258 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4261 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4393 - _EXPRESSION_UPDATEFIELDS._serialized_start = 4396 - _EXPRESSION_UPDATEFIELDS._serialized_end = 4583 - _EXPRESSION_ALIAS._serialized_start = 4585 - _EXPRESSION_ALIAS._serialized_end = 4705 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4708 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4866 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4868 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4930 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4946 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5257 - _PYTHONUDF._serialized_start = 5260 - _PYTHONUDF._serialized_end = 5390 - _SCALARSCALAUDF._serialized_start = 5393 - _SCALARSCALAUDF._serialized_end = 5577 + 
_EXPRESSION_LITERAL._serialized_end = 3907 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3545 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3662 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3664 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3762 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 3764 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 3891 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3909 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4021 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4024 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4228 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4230 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4280 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4282 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4364 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4366 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4452 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4455 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4587 + _EXPRESSION_UPDATEFIELDS._serialized_start = 4590 + _EXPRESSION_UPDATEFIELDS._serialized_end = 4777 + _EXPRESSION_ALIAS._serialized_start = 4779 + _EXPRESSION_ALIAS._serialized_end = 4899 + _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4902 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5060 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5451 + _PYTHONUDF._serialized_start = 5454 + _PYTHONUDF._serialized_end = 5584 + _SCALARSCALAUDF._serialized_start = 5587 + _SCALARSCALAUDF._serialized_end = 5771 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 996de7fef2d..37db24ff91a 100644 --- 
a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -440,6 +440,35 @@ class Expression(google.protobuf.message.Message): ], ) -> None: ... + class Array(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ELEMENTTYPE_FIELD_NUMBER: builtins.int + ELEMENT_FIELD_NUMBER: builtins.int + @property + def elementType(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... + @property + def element( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___Expression.Literal + ]: ... + def __init__( + self, + *, + elementType: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + element: collections.abc.Iterable[global___Expression.Literal] | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["elementType", b"elementType"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "element", b"element", "elementType", b"elementType" + ], + ) -> None: ... + NULL_FIELD_NUMBER: builtins.int BINARY_FIELD_NUMBER: builtins.int BOOLEAN_FIELD_NUMBER: builtins.int @@ -457,6 +486,7 @@ class Expression(google.protobuf.message.Message): CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int + ARRAY_FIELD_NUMBER: builtins.int @property def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... binary: builtins.bytes @@ -480,6 +510,8 @@ class Expression(google.protobuf.message.Message): def calendar_interval(self) -> global___Expression.Literal.CalendarInterval: ... year_month_interval: builtins.int day_time_interval: builtins.int + @property + def array(self) -> global___Expression.Literal.Array: ... 
def __init__( self, *, @@ -500,10 +532,13 @@ class Expression(google.protobuf.message.Message): calendar_interval: global___Expression.Literal.CalendarInterval | None = ..., year_month_interval: builtins.int = ..., day_time_interval: builtins.int = ..., + array: global___Expression.Literal.Array | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ + "array", + b"array", "binary", b"binary", "boolean", @@ -545,6 +580,8 @@ class Expression(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ + "array", + b"array", "binary", b"binary", "boolean", @@ -603,6 +640,7 @@ class Expression(google.protobuf.message.Message): "calendar_interval", "year_month_interval", "day_time_interval", + "array", ] | None: ... class UnresolvedAttribute(google.protobuf.message.Message): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org