This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4006d195111 [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes 4006d195111 is described below commit 4006d195111334b4b795680e547dea9dd0acda22 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Nov 25 09:26:25 2022 +0800 [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes ### What changes were proposed in this pull request? 1, on the server side, make `proto_datatype` <-> `catalyst_datatype` conversion support all the built-in SQL datatypes; 2, on the client side, make `proto_datatype` <-> `pyspark_catalyst_datatype` conversion support [all the datatypes that are supported in pyspark now.](https://github.com/apache/spark/blob/master/python/pyspark/sql/types.py#L60-L83) ### Why are the changes needed? Right now, only `long`, `string`, `struct` are supported; any other type fails with an error such as: ``` grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with: status = StatusCode.UNKNOWN details = "Does not support convert float to connect proto types." debug_error_string = "{"created":"1669206685.760099000","description":"Error received from peer ipv6:[::1]:15002","file":"src/core/lib/surface/call.cc","file_line":1064,"grpc_message":"Does not support convert float to connect proto types.","grpc_status":2}" ``` This PR makes the schema and literal expr support more datatypes. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UT. Closes #38770 from zhengruifeng/connect_support_more_datatypes. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../main/protobuf/spark/connect/expressions.proto | 2 +- .../src/main/protobuf/spark/connect/types.proto | 123 +++-- .../org/apache/spark/sql/connect/dsl/package.scala | 5 +- .../connect/planner/DataTypeProtoConverter.scala | 281 ++++++++++-- .../connect/planner/SparkConnectServiceSuite.scala | 4 +- python/pyspark/sql/connect/client.py | 108 ++++- .../pyspark/sql/connect/proto/expressions_pb2.py | 66 +-- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 16 +- python/pyspark/sql/connect/proto/types_pb2.py | 261 ++++++----- python/pyspark/sql/connect/proto/types_pb2.pyi | 507 +++++++++++++-------- .../sql/tests/connect/test_connect_basic.py | 67 +++ 11 files changed, 972 insertions(+), 468 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto index ac5fe24d349..7ff06aeb196 100644 --- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto @@ -68,7 +68,7 @@ message Expression { bytes uuid = 28; DataType null = 29; // a typed null literal List list = 30; - DataType.List empty_list = 31; + DataType.Array empty_array = 31; DataType.Map empty_map = 32; UserDefined user_defined = 33; } diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto index ad043d85947..56dbf28665e 100644 --- a/connector/connect/src/main/protobuf/spark/connect/types.proto +++ b/connector/connect/src/main/protobuf/spark/connect/types.proto @@ -26,31 +26,46 @@ option java_package = "org.apache.spark.connect.proto"; // itself but only describes it. 
message DataType { oneof kind { - Boolean bool = 1; - I8 i8 = 2; - I16 i16 = 3; - I32 i32 = 5; - I64 i64 = 7; - FP32 fp32 = 10; - FP64 fp64 = 11; - String string = 12; - Binary binary = 13; - Timestamp timestamp = 14; - Date date = 16; - Time time = 17; - IntervalYear interval_year = 19; - IntervalDay interval_day = 20; - TimestampTZ timestamp_tz = 29; - UUID uuid = 32; - - FixedChar fixed_char = 21; - VarChar varchar = 22; - FixedBinary fixed_binary = 23; - Decimal decimal = 24; - - Struct struct = 25; - List list = 27; - Map map = 28; + NULL null = 1; + + Binary binary = 2; + + Boolean boolean = 3; + + // Numeric types + Byte byte = 4; + Short short = 5; + Integer integer = 6; + Long long = 7; + + Float float = 8; + Double double = 9; + Decimal decimal = 10; + + // String types + String string = 11; + Char char = 12; + VarChar var_char = 13; + + // Datatime types + Date date = 14; + Timestamp timestamp = 15; + TimestampNTZ timestamp_ntz = 16; + + // Interval types + CalendarInterval calendar_interval = 17; + YearMonthInterval year_month_interval = 18; + DayTimeInterval day_time_interval = 19; + + // Complex types + Array array = 20; + Struct struct = 21; + Map map = 22; + + + UUID uuid = 25; + + FixedBinary fixed_binary = 26; uint32 user_defined_type_reference = 31; } @@ -59,27 +74,27 @@ message DataType { uint32 type_variation_reference = 1; } - message I8 { + message Byte { uint32 type_variation_reference = 1; } - message I16 { + message Short { uint32 type_variation_reference = 1; } - message I32 { + message Integer { uint32 type_variation_reference = 1; } - message I64 { + message Long { uint32 type_variation_reference = 1; } - message FP32 { + message Float { uint32 type_variation_reference = 1; } - message FP64 { + message Double { uint32 type_variation_reference = 1; } @@ -91,6 +106,10 @@ message DataType { uint32 type_variation_reference = 1; } + message NULL { + uint32 type_variation_reference = 1; + } + message Timestamp { uint32 
type_variation_reference = 1; } @@ -99,20 +118,24 @@ message DataType { uint32 type_variation_reference = 1; } - message Time { + message TimestampNTZ { uint32 type_variation_reference = 1; } - message TimestampTZ { + message CalendarInterval { uint32 type_variation_reference = 1; } - message IntervalYear { - uint32 type_variation_reference = 1; + message YearMonthInterval { + optional int32 start_field = 1; + optional int32 end_field = 2; + uint32 type_variation_reference = 3; } - message IntervalDay { - uint32 type_variation_reference = 1; + message DayTimeInterval { + optional int32 start_field = 1; + optional int32 end_field = 2; + uint32 type_variation_reference = 3; } message UUID { @@ -120,7 +143,7 @@ message DataType { } // Start compound types. - message FixedChar { + message Char { int32 length = 1; uint32 type_variation_reference = 2; } @@ -136,14 +159,14 @@ message DataType { } message Decimal { - int32 scale = 1; - int32 precision = 2; + optional int32 scale = 1; + optional int32 precision = 2; uint32 type_variation_reference = 3; } message StructField { - DataType type = 1; - string name = 2; + string name = 1; + DataType data_type = 2; bool nullable = 3; map<string, string> metadata = 4; } @@ -153,16 +176,16 @@ message DataType { uint32 type_variation_reference = 2; } - message List { - DataType DataType = 1; - uint32 type_variation_reference = 2; - bool element_nullable = 3; + message Array { + DataType element_type = 1; + bool contains_null = 2; + uint32 type_variation_reference = 3; } message Map { - DataType key = 1; - DataType value = 2; - uint32 type_variation_reference = 3; - bool value_nullable = 4; + DataType key_type = 1; + DataType value_type = 2; + bool value_contains_null = 3; + uint32 type_variation_reference = 4; } } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 1827aa4e3c0..efebb67aeda 100644 --- 
a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -54,7 +54,8 @@ package object dsl { for (attr <- attrs) { val structField = DataType.StructField.newBuilder() structField.setName(attr.getName) - structField.setType(attr.getType) + structField.setDataType(attr.getType) + structField.setNullable(true) structExpr.addFields(structField) } Expression.QualifiedAttribute @@ -66,7 +67,7 @@ package object dsl { /** Creates a new AttributeReference of type int */ def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType( - DataType.newBuilder().setI32(DataType.I32.newBuilder()).build()) + DataType.newBuilder().setInteger(DataType.Integer.newBuilder()).build()) private def protoQualifiedAttrWithType(dataType: DataType): Expression.QualifiedAttribute = Expression.QualifiedAttribute diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index 088030b2dbc..0b8d79596c3 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._ import org.apache.spark.connect.proto import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ /** * This object offers methods to convert to/from connect proto to catalyst types. 
@@ -29,65 +29,264 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, Str object DataTypeProtoConverter { def toCatalystType(t: proto.DataType): DataType = { t.getKindCase match { - case proto.DataType.KindCase.I32 => IntegerType + case proto.DataType.KindCase.NULL => NullType + + case proto.DataType.KindCase.BINARY => BinaryType + + case proto.DataType.KindCase.BOOLEAN => BooleanType + + case proto.DataType.KindCase.BYTE => ByteType + case proto.DataType.KindCase.SHORT => ShortType + case proto.DataType.KindCase.INTEGER => IntegerType + case proto.DataType.KindCase.LONG => LongType + + case proto.DataType.KindCase.FLOAT => FloatType + case proto.DataType.KindCase.DOUBLE => DoubleType + case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal) + case proto.DataType.KindCase.STRING => StringType - case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct) - case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap) + case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength) + case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength) + + case proto.DataType.KindCase.DATE => DateType + case proto.DataType.KindCase.TIMESTAMP => TimestampType + case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType + + case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType + case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => + toCatalystYearMonthIntervalType(t.getYearMonthInterval) + case proto.DataType.KindCase.DAY_TIME_INTERVAL => + toCatalystDayTimeIntervalType(t.getDayTimeInterval) + + case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray) + case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct) + case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap) case _ => throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.") } } - private def convertProtoDataTypeToCatalyst(t: 
proto.DataType.Struct): StructType = { - // TODO: handle nullability - val structFields = - t.getFieldsList.map(f => StructField(f.getName, toCatalystType(f.getType))).toList - StructType.apply(structFields) + private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = { + (t.hasPrecision, t.hasScale) match { + case (true, true) => DecimalType(t.getPrecision, t.getScale) + case (true, false) => new DecimalType(t.getPrecision) + case _ => new DecimalType() + } + } + + private def toCatalystYearMonthIntervalType(t: proto.DataType.YearMonthInterval) = { + (t.hasStartField, t.hasEndField) match { + case (true, true) => YearMonthIntervalType(t.getStartField.toByte, t.getEndField.toByte) + case (true, false) => YearMonthIntervalType(t.getStartField.toByte) + case _ => YearMonthIntervalType() + } } - private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = { - MapType(toCatalystType(t.getKey), toCatalystType(t.getValue)) + private def toCatalystDayTimeIntervalType(t: proto.DataType.DayTimeInterval) = { + (t.hasStartField, t.hasEndField) match { + case (true, true) => DayTimeIntervalType(t.getStartField.toByte, t.getEndField.toByte) + case (true, false) => DayTimeIntervalType(t.getStartField.toByte) + case _ => DayTimeIntervalType() + } + } + + private def toCatalystArrayType(t: proto.DataType.Array): ArrayType = { + ArrayType(toCatalystType(t.getElementType), t.getContainsNull) + } + + private def toCatalystStructType(t: proto.DataType.Struct): StructType = { + // TODO: support metadata + val fields = t.getFieldsList.toSeq.map { protoField => + StructField( + name = protoField.getName, + dataType = toCatalystType(protoField.getDataType), + nullable = protoField.getNullable, + metadata = Metadata.empty) + } + StructType.apply(fields) + } + + private def toCatalystMapType(t: proto.DataType.Map): MapType = { + MapType(toCatalystType(t.getKeyType), toCatalystType(t.getValueType), t.getValueContainsNull) } def toConnectProtoType(t: 
DataType): proto.DataType = { t match { + case NullType => + proto.DataType + .newBuilder() + .setNull(proto.DataType.NULL.getDefaultInstance) + .build() + + case BooleanType => + proto.DataType + .newBuilder() + .setBoolean(proto.DataType.Boolean.getDefaultInstance) + .build() + + case BinaryType => + proto.DataType + .newBuilder() + .setBinary(proto.DataType.Binary.getDefaultInstance) + .build() + + case ByteType => + proto.DataType + .newBuilder() + .setByte(proto.DataType.Byte.getDefaultInstance) + .build() + + case ShortType => + proto.DataType + .newBuilder() + .setShort(proto.DataType.Short.getDefaultInstance) + .build() + case IntegerType => - proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build() - case StringType => - proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build() + proto.DataType + .newBuilder() + .setInteger(proto.DataType.Integer.getDefaultInstance) + .build() + case LongType => - proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build() - case struct: StructType => - toConnectProtoStructType(struct) - case map: MapType => toConnectProtoMapType(map) - case _ => - throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") - } - } + proto.DataType + .newBuilder() + .setLong(proto.DataType.Long.getDefaultInstance) + .build() - def toConnectProtoMapType(schema: MapType): proto.DataType = { - proto.DataType - .newBuilder() - .setMap( - proto.DataType.Map + case FloatType => + proto.DataType .newBuilder() - .setKey(toConnectProtoType(schema.keyType)) - .setValue(toConnectProtoType(schema.valueType)) - .build()) - .build() - } + .setFloat(proto.DataType.Float.getDefaultInstance) + .build() + + case DoubleType => + proto.DataType + .newBuilder() + .setDouble(proto.DataType.Double.getDefaultInstance) + .build() + + case DecimalType.Fixed(precision, scale) => + proto.DataType + .newBuilder() + .setDecimal( + 
proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) + .build() + + case StringType => + proto.DataType + .newBuilder() + .setString(proto.DataType.String.getDefaultInstance) + .build() + + case CharType(length) => + proto.DataType + .newBuilder() + .setChar(proto.DataType.Char.newBuilder().setLength(length).build()) + .build() + + case VarcharType(length) => + proto.DataType + .newBuilder() + .setVarChar(proto.DataType.VarChar.newBuilder().setLength(length).build()) + .build() + + case DateType => + proto.DataType + .newBuilder() + .setDate(proto.DataType.Date.getDefaultInstance) + .build() - def toConnectProtoStructType(schema: StructType): proto.DataType = { - val struct = proto.DataType.Struct.newBuilder() - for (structField <- schema.fields) { - struct.addFields( - proto.DataType.StructField + case TimestampType => + proto.DataType .newBuilder() - .setName(structField.name) - .setType(toConnectProtoType(structField.dataType)) - .setNullable(structField.nullable)) + .setTimestamp(proto.DataType.Timestamp.getDefaultInstance) + .build() + + case TimestampNTZType => + proto.DataType + .newBuilder() + .setTimestampNtz(proto.DataType.TimestampNTZ.getDefaultInstance) + .build() + + case CalendarIntervalType => + proto.DataType + .newBuilder() + .setCalendarInterval(proto.DataType.CalendarInterval.getDefaultInstance) + .build() + + case YearMonthIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setYearMonthInterval( + proto.DataType.YearMonthInterval + .newBuilder() + .setStartField(startField) + .setEndField(endField) + .build()) + .build() + + case DayTimeIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setDayTimeInterval( + proto.DataType.DayTimeInterval + .newBuilder() + .setStartField(startField) + .setEndField(endField) + .build()) + .build() + + case ArrayType(elementType: DataType, containsNull: Boolean) => + proto.DataType + .newBuilder() + .setArray( + proto.DataType.Array + 
.newBuilder() + .setElementType(toConnectProtoType(elementType)) + .setContainsNull(containsNull) + .build()) + .build() + + case StructType(fields: Array[StructField]) => + // TODO: support metadata + val protoFields = fields.toSeq.map { + case StructField( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata) => + proto.DataType.StructField + .newBuilder() + .setName(name) + .setDataType(toConnectProtoType(dataType)) + .setNullable(nullable) + .build() + } + proto.DataType + .newBuilder() + .setStruct( + proto.DataType.Struct + .newBuilder() + .addAllFields(protoFields) + .build()) + .build() + + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => + proto.DataType + .newBuilder() + .setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(toConnectProtoType(keyType)) + .setValueType(toConnectProtoType(valueType)) + .setValueContainsNull(valueContainsNull) + .build()) + .build() + + case _ => + throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } - proto.DataType.newBuilder().setStruct(struct).build() } def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 5f18b0d45c5..6ca3c2430c4 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -61,10 +61,10 @@ class SparkConnectServiceSuite extends SharedSparkSession { assert(schema.getFieldsCount == 2) assert( schema.getFields(0).getName == "col1" - && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32) + && schema.getFields(0).getDataType.getKindCase == proto.DataType.KindCase.INTEGER) assert( 
schema.getFields(1).getName == "col2" - && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING) + && schema.getFields(1).getDataType.getKindCase == proto.DataType.KindCase.STRING) } } diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 24d104a0418..eb2e2227fb9 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -32,7 +32,29 @@ from pyspark import cloudpickle from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.plan import SQL, Range -from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType +from pyspark.sql.types import ( + DataType, + ByteType, + ShortType, + IntegerType, + FloatType, + DateType, + TimestampType, + DayTimeIntervalType, + MapType, + StringType, + CharType, + VarcharType, + StructType, + StructField, + ArrayType, + DoubleType, + LongType, + DecimalType, + BinaryType, + BooleanType, + NullType, +) from typing import Iterable, Optional, Any, Union, List, Tuple, Dict @@ -356,38 +378,78 @@ class RemoteSparkSession(object): return self._execute_and_fetch(req) def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: - if schema.HasField("struct"): - structFields = [] - for proto_field in schema.struct.fields: - structFields.append( - StructField( - proto_field.name, - self._proto_schema_to_pyspark_schema(proto_field.type), - proto_field.nullable, - ) - ) - return StructType(structFields) - elif schema.HasField("i64"): + if schema.HasField("null"): + return NullType() + elif schema.HasField("boolean"): + return BooleanType() + elif schema.HasField("binary"): + return BinaryType() + elif schema.HasField("byte"): + return ByteType() + elif schema.HasField("short"): + return ShortType() + elif schema.HasField("integer"): + return IntegerType() + elif schema.HasField("long"): return LongType() + elif 
schema.HasField("float"): + return FloatType() + elif schema.HasField("double"): + return DoubleType() + elif schema.HasField("decimal"): + p = schema.decimal.precision if schema.decimal.HasField("precision") else 10 + s = schema.decimal.scale if schema.decimal.HasField("scale") else 0 + return DecimalType(precision=p, scale=s) elif schema.HasField("string"): return StringType() + elif schema.HasField("char"): + return CharType(schema.char.length) + elif schema.HasField("var_char"): + return VarcharType(schema.var_char.length) + elif schema.HasField("date"): + return DateType() + elif schema.HasField("timestamp"): + return TimestampType() + elif schema.HasField("day_time_interval"): + return DayTimeIntervalType() + elif schema.HasField("array"): + return ArrayType( + self._proto_schema_to_pyspark_schema(schema.array.element_type), + schema.array.contains_null, + ) + elif schema.HasField("struct"): + fields = [ + StructField( + f.name, + self._proto_schema_to_pyspark_schema(f.data_type), + f.nullable, + ) + for f in schema.struct.fields + ] + return StructType(fields) + elif schema.HasField("map"): + return MapType( + self._proto_schema_to_pyspark_schema(schema.map.key_type), + self._proto_schema_to_pyspark_schema(schema.map.value_type), + schema.map.value_contains_null, + ) else: - raise Exception("Only support long, string, struct conversion") + raise Exception(f"Unsupported data type {schema}") def schema(self, plan: pb2.Plan) -> StructType: proto_schema = self._analyze(plan).schema # Server side should populate the struct field which is the schema. 
assert proto_schema.HasField("struct") - structFields = [] - for proto_field in proto_schema.struct.fields: - structFields.append( - StructField( - proto_field.name, - self._proto_schema_to_pyspark_schema(proto_field.type), - proto_field.nullable, - ) + + fields = [ + StructField( + f.name, + self._proto_schema_to_pyspark_schema(f.data_type), + f.nullable, ) - return StructType(structFields) + for f in proto_schema.struct.fields + ] + return StructType(fields) def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: result = self._analyze(plan, explain_mode) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index c372df7d324..7435a54d7ec 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...] 
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf3\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...] ) @@ -235,37 +235,37 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 3161 + _EXPRESSION._serialized_end = 3164 _EXPRESSION_LITERAL._serialized_start = 613 - _EXPRESSION_LITERAL._serialized_end = 2696 - _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923 - _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1978 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1980 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2063 - _EXPRESSION_LITERAL_MAP._serialized_start = 2066 - _EXPRESSION_LITERAL_MAP._serialized_end = 2272 - _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2152 - _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2272 - _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2274 - _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2341 - _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2343 - _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2446 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 2448 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 2515 - _EXPRESSION_LITERAL_LIST._serialized_start = 2517 - _EXPRESSION_LITERAL_LIST._serialized_end = 2582 - _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2584 - _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2680 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2698 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end 
= 2768 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2770 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2869 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2871 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2921 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2923 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2939 - _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941 - _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026 - _EXPRESSION_ALIAS._serialized_start = 3028 - _EXPRESSION_ALIAS._serialized_end = 3148 + _EXPRESSION_LITERAL._serialized_end = 2699 + _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1926 + _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1981 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1983 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2066 + _EXPRESSION_LITERAL_MAP._serialized_start = 2069 + _EXPRESSION_LITERAL_MAP._serialized_end = 2275 + _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2155 + _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2275 + _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2277 + _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2344 + _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2346 + _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2449 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 2451 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 2518 + _EXPRESSION_LITERAL_LIST._serialized_start = 2520 + _EXPRESSION_LITERAL_LIST._serialized_end = 2585 + _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2587 + _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2683 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2701 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2771 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2773 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2872 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2874 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2924 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2926 
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2942 + _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2944 + _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3029 + _EXPRESSION_ALIAS._serialized_start = 3031 + _EXPRESSION_ALIAS._serialized_end = 3151 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index ea538b2ebec..05c8cbe6385 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -280,7 +280,7 @@ class Expression(google.protobuf.message.Message): UUID_FIELD_NUMBER: builtins.int NULL_FIELD_NUMBER: builtins.int LIST_FIELD_NUMBER: builtins.int - EMPTY_LIST_FIELD_NUMBER: builtins.int + EMPTY_ARRAY_FIELD_NUMBER: builtins.int EMPTY_MAP_FIELD_NUMBER: builtins.int USER_DEFINED_FIELD_NUMBER: builtins.int NULLABLE_FIELD_NUMBER: builtins.int @@ -323,7 +323,7 @@ class Expression(google.protobuf.message.Message): @property def list(self) -> global___Expression.Literal.List: ... @property - def empty_list(self) -> pyspark.sql.connect.proto.types_pb2.DataType.List: ... + def empty_array(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: ... @property def empty_map(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: ... 
@property @@ -365,7 +365,7 @@ class Expression(google.protobuf.message.Message): uuid: builtins.bytes = ..., null: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., list: global___Expression.Literal.List | None = ..., - empty_list: pyspark.sql.connect.proto.types_pb2.DataType.List | None = ..., + empty_array: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ..., empty_map: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ..., user_defined: global___Expression.Literal.UserDefined | None = ..., nullable: builtins.bool = ..., @@ -382,8 +382,8 @@ class Expression(google.protobuf.message.Message): b"date", "decimal", b"decimal", - "empty_list", - b"empty_list", + "empty_array", + b"empty_array", "empty_map", b"empty_map", "fixed_binary", @@ -443,8 +443,8 @@ class Expression(google.protobuf.message.Message): b"date", "decimal", b"decimal", - "empty_list", - b"empty_list", + "empty_array", + b"empty_array", "empty_map", b"empty_map", "fixed_binary", @@ -524,7 +524,7 @@ class Expression(google.protobuf.message.Message): "uuid", "null", "list", - "empty_list", + "empty_array", "empty_map", "user_defined", ] | None: ... 
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py index 3507b03602c..dd6567d96a2 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.py +++ b/python/pyspark/sql/connect/proto/types_pb2.py @@ -30,35 +30,36 @@ _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xc1\x1c\n\x08\x44\x61taType\x12\x35\n\x04\x62ool\x18\x01 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x04\x62ool\x12,\n\x02i8\x18\x02 \x01(\x0b\x32\x1a.spark.connect.DataType.I8H\x00R\x02i8\x12/\n\x03i16\x18\x03 \x01(\x0b\x32\x1b.spark.connect.DataType.I16H\x00R\x03i16\x12/\n\x03i32\x18\x05 \x01(\x0b\x32\x1b.spark.connect.DataType.I32H\x00R\x03i32\x12/\n\x03i64\x18\x07 \x01(\x0b\x32\x1b.spark.connect.DataType.I64H\x00R\x [...] + b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xce \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...] 
) _DATATYPE = DESCRIPTOR.message_types_by_name["DataType"] _DATATYPE_BOOLEAN = _DATATYPE.nested_types_by_name["Boolean"] -_DATATYPE_I8 = _DATATYPE.nested_types_by_name["I8"] -_DATATYPE_I16 = _DATATYPE.nested_types_by_name["I16"] -_DATATYPE_I32 = _DATATYPE.nested_types_by_name["I32"] -_DATATYPE_I64 = _DATATYPE.nested_types_by_name["I64"] -_DATATYPE_FP32 = _DATATYPE.nested_types_by_name["FP32"] -_DATATYPE_FP64 = _DATATYPE.nested_types_by_name["FP64"] +_DATATYPE_BYTE = _DATATYPE.nested_types_by_name["Byte"] +_DATATYPE_SHORT = _DATATYPE.nested_types_by_name["Short"] +_DATATYPE_INTEGER = _DATATYPE.nested_types_by_name["Integer"] +_DATATYPE_LONG = _DATATYPE.nested_types_by_name["Long"] +_DATATYPE_FLOAT = _DATATYPE.nested_types_by_name["Float"] +_DATATYPE_DOUBLE = _DATATYPE.nested_types_by_name["Double"] _DATATYPE_STRING = _DATATYPE.nested_types_by_name["String"] _DATATYPE_BINARY = _DATATYPE.nested_types_by_name["Binary"] +_DATATYPE_NULL = _DATATYPE.nested_types_by_name["NULL"] _DATATYPE_TIMESTAMP = _DATATYPE.nested_types_by_name["Timestamp"] _DATATYPE_DATE = _DATATYPE.nested_types_by_name["Date"] -_DATATYPE_TIME = _DATATYPE.nested_types_by_name["Time"] -_DATATYPE_TIMESTAMPTZ = _DATATYPE.nested_types_by_name["TimestampTZ"] -_DATATYPE_INTERVALYEAR = _DATATYPE.nested_types_by_name["IntervalYear"] -_DATATYPE_INTERVALDAY = _DATATYPE.nested_types_by_name["IntervalDay"] +_DATATYPE_TIMESTAMPNTZ = _DATATYPE.nested_types_by_name["TimestampNTZ"] +_DATATYPE_CALENDARINTERVAL = _DATATYPE.nested_types_by_name["CalendarInterval"] +_DATATYPE_YEARMONTHINTERVAL = _DATATYPE.nested_types_by_name["YearMonthInterval"] +_DATATYPE_DAYTIMEINTERVAL = _DATATYPE.nested_types_by_name["DayTimeInterval"] _DATATYPE_UUID = _DATATYPE.nested_types_by_name["UUID"] -_DATATYPE_FIXEDCHAR = _DATATYPE.nested_types_by_name["FixedChar"] +_DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"] _DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"] _DATATYPE_FIXEDBINARY = 
_DATATYPE.nested_types_by_name["FixedBinary"] _DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"] _DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"] _DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"] _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"] -_DATATYPE_LIST = _DATATYPE.nested_types_by_name["List"] +_DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"] _DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"] DataType = _reflection.GeneratedProtocolMessageType( "DataType", @@ -73,58 +74,58 @@ DataType = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.DataType.Boolean) }, ), - "I8": _reflection.GeneratedProtocolMessageType( - "I8", + "Byte": _reflection.GeneratedProtocolMessageType( + "Byte", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_I8, + "DESCRIPTOR": _DATATYPE_BYTE, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.I8) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Byte) }, ), - "I16": _reflection.GeneratedProtocolMessageType( - "I16", + "Short": _reflection.GeneratedProtocolMessageType( + "Short", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_I16, + "DESCRIPTOR": _DATATYPE_SHORT, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.I16) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Short) }, ), - "I32": _reflection.GeneratedProtocolMessageType( - "I32", + "Integer": _reflection.GeneratedProtocolMessageType( + "Integer", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_I32, + "DESCRIPTOR": _DATATYPE_INTEGER, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.I32) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Integer) }, ), - "I64": _reflection.GeneratedProtocolMessageType( - "I64", + "Long": 
_reflection.GeneratedProtocolMessageType( + "Long", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_I64, + "DESCRIPTOR": _DATATYPE_LONG, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.I64) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Long) }, ), - "FP32": _reflection.GeneratedProtocolMessageType( - "FP32", + "Float": _reflection.GeneratedProtocolMessageType( + "Float", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_FP32, + "DESCRIPTOR": _DATATYPE_FLOAT, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP32) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Float) }, ), - "FP64": _reflection.GeneratedProtocolMessageType( - "FP64", + "Double": _reflection.GeneratedProtocolMessageType( + "Double", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_FP64, + "DESCRIPTOR": _DATATYPE_DOUBLE, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP64) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Double) }, ), "String": _reflection.GeneratedProtocolMessageType( @@ -145,6 +146,15 @@ DataType = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.DataType.Binary) }, ), + "NULL": _reflection.GeneratedProtocolMessageType( + "NULL", + (_message.Message,), + { + "DESCRIPTOR": _DATATYPE_NULL, + "__module__": "spark.connect.types_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.DataType.NULL) + }, + ), "Timestamp": _reflection.GeneratedProtocolMessageType( "Timestamp", (_message.Message,), @@ -163,40 +173,40 @@ DataType = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.DataType.Date) }, ), - "Time": _reflection.GeneratedProtocolMessageType( - "Time", + "TimestampNTZ": _reflection.GeneratedProtocolMessageType( + "TimestampNTZ", (_message.Message,), { - "DESCRIPTOR": 
_DATATYPE_TIME, + "DESCRIPTOR": _DATATYPE_TIMESTAMPNTZ, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.Time) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampNTZ) }, ), - "TimestampTZ": _reflection.GeneratedProtocolMessageType( - "TimestampTZ", + "CalendarInterval": _reflection.GeneratedProtocolMessageType( + "CalendarInterval", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_TIMESTAMPTZ, + "DESCRIPTOR": _DATATYPE_CALENDARINTERVAL, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampTZ) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.CalendarInterval) }, ), - "IntervalYear": _reflection.GeneratedProtocolMessageType( - "IntervalYear", + "YearMonthInterval": _reflection.GeneratedProtocolMessageType( + "YearMonthInterval", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_INTERVALYEAR, + "DESCRIPTOR": _DATATYPE_YEARMONTHINTERVAL, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalYear) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.YearMonthInterval) }, ), - "IntervalDay": _reflection.GeneratedProtocolMessageType( - "IntervalDay", + "DayTimeInterval": _reflection.GeneratedProtocolMessageType( + "DayTimeInterval", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_INTERVALDAY, + "DESCRIPTOR": _DATATYPE_DAYTIMEINTERVAL, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalDay) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.DayTimeInterval) }, ), "UUID": _reflection.GeneratedProtocolMessageType( @@ -208,13 +218,13 @@ DataType = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.DataType.UUID) }, ), - "FixedChar": _reflection.GeneratedProtocolMessageType( - "FixedChar", + "Char": _reflection.GeneratedProtocolMessageType( + 
"Char", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_FIXEDCHAR, + "DESCRIPTOR": _DATATYPE_CHAR, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.FixedChar) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Char) }, ), "VarChar": _reflection.GeneratedProtocolMessageType( @@ -271,13 +281,13 @@ DataType = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.DataType.Struct) }, ), - "List": _reflection.GeneratedProtocolMessageType( - "List", + "Array": _reflection.GeneratedProtocolMessageType( + "Array", (_message.Message,), { - "DESCRIPTOR": _DATATYPE_LIST, + "DESCRIPTOR": _DATATYPE_ARRAY, "__module__": "spark.connect.types_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.DataType.List) + # @@protoc_insertion_point(class_scope:spark.connect.DataType.Array) }, ), "Map": _reflection.GeneratedProtocolMessageType( @@ -296,29 +306,30 @@ DataType = _reflection.GeneratedProtocolMessageType( ) _sym_db.RegisterMessage(DataType) _sym_db.RegisterMessage(DataType.Boolean) -_sym_db.RegisterMessage(DataType.I8) -_sym_db.RegisterMessage(DataType.I16) -_sym_db.RegisterMessage(DataType.I32) -_sym_db.RegisterMessage(DataType.I64) -_sym_db.RegisterMessage(DataType.FP32) -_sym_db.RegisterMessage(DataType.FP64) +_sym_db.RegisterMessage(DataType.Byte) +_sym_db.RegisterMessage(DataType.Short) +_sym_db.RegisterMessage(DataType.Integer) +_sym_db.RegisterMessage(DataType.Long) +_sym_db.RegisterMessage(DataType.Float) +_sym_db.RegisterMessage(DataType.Double) _sym_db.RegisterMessage(DataType.String) _sym_db.RegisterMessage(DataType.Binary) +_sym_db.RegisterMessage(DataType.NULL) _sym_db.RegisterMessage(DataType.Timestamp) _sym_db.RegisterMessage(DataType.Date) -_sym_db.RegisterMessage(DataType.Time) -_sym_db.RegisterMessage(DataType.TimestampTZ) -_sym_db.RegisterMessage(DataType.IntervalYear) -_sym_db.RegisterMessage(DataType.IntervalDay) 
+_sym_db.RegisterMessage(DataType.TimestampNTZ) +_sym_db.RegisterMessage(DataType.CalendarInterval) +_sym_db.RegisterMessage(DataType.YearMonthInterval) +_sym_db.RegisterMessage(DataType.DayTimeInterval) _sym_db.RegisterMessage(DataType.UUID) -_sym_db.RegisterMessage(DataType.FixedChar) +_sym_db.RegisterMessage(DataType.Char) _sym_db.RegisterMessage(DataType.VarChar) _sym_db.RegisterMessage(DataType.FixedBinary) _sym_db.RegisterMessage(DataType.Decimal) _sym_db.RegisterMessage(DataType.StructField) _sym_db.RegisterMessage(DataType.StructField.MetadataEntry) _sym_db.RegisterMessage(DataType.Struct) -_sym_db.RegisterMessage(DataType.List) +_sym_db.RegisterMessage(DataType.Array) _sym_db.RegisterMessage(DataType.Map) if _descriptor._USE_C_DESCRIPTORS == False: @@ -328,55 +339,57 @@ if _descriptor._USE_C_DESCRIPTORS == False: _DATATYPE_STRUCTFIELD_METADATAENTRY._options = None _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001" _DATATYPE._serialized_start = 45 - _DATATYPE._serialized_end = 3694 - _DATATYPE_BOOLEAN._serialized_start = 1461 - _DATATYPE_BOOLEAN._serialized_end = 1528 - _DATATYPE_I8._serialized_start = 1530 - _DATATYPE_I8._serialized_end = 1592 - _DATATYPE_I16._serialized_start = 1594 - _DATATYPE_I16._serialized_end = 1657 - _DATATYPE_I32._serialized_start = 1659 - _DATATYPE_I32._serialized_end = 1722 - _DATATYPE_I64._serialized_start = 1724 - _DATATYPE_I64._serialized_end = 1787 - _DATATYPE_FP32._serialized_start = 1789 - _DATATYPE_FP32._serialized_end = 1853 - _DATATYPE_FP64._serialized_start = 1855 - _DATATYPE_FP64._serialized_end = 1919 - _DATATYPE_STRING._serialized_start = 1921 - _DATATYPE_STRING._serialized_end = 1987 - _DATATYPE_BINARY._serialized_start = 1989 - _DATATYPE_BINARY._serialized_end = 2055 - _DATATYPE_TIMESTAMP._serialized_start = 2057 - _DATATYPE_TIMESTAMP._serialized_end = 2126 - _DATATYPE_DATE._serialized_start = 2128 - _DATATYPE_DATE._serialized_end = 2192 - _DATATYPE_TIME._serialized_start = 2194 - 
_DATATYPE_TIME._serialized_end = 2258 - _DATATYPE_TIMESTAMPTZ._serialized_start = 2260 - _DATATYPE_TIMESTAMPTZ._serialized_end = 2331 - _DATATYPE_INTERVALYEAR._serialized_start = 2333 - _DATATYPE_INTERVALYEAR._serialized_end = 2405 - _DATATYPE_INTERVALDAY._serialized_start = 2407 - _DATATYPE_INTERVALDAY._serialized_end = 2478 - _DATATYPE_UUID._serialized_start = 2480 - _DATATYPE_UUID._serialized_end = 2544 - _DATATYPE_FIXEDCHAR._serialized_start = 2546 - _DATATYPE_FIXEDCHAR._serialized_end = 2639 - _DATATYPE_VARCHAR._serialized_start = 2641 - _DATATYPE_VARCHAR._serialized_end = 2732 - _DATATYPE_FIXEDBINARY._serialized_start = 2734 - _DATATYPE_FIXEDBINARY._serialized_end = 2829 - _DATATYPE_DECIMAL._serialized_start = 2831 - _DATATYPE_DECIMAL._serialized_end = 2950 - _DATATYPE_STRUCTFIELD._serialized_start = 2953 - _DATATYPE_STRUCTFIELD._serialized_end = 3199 - _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3140 - _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3199 - _DATATYPE_STRUCT._serialized_start = 3201 - _DATATYPE_STRUCT._serialized_end = 3328 - _DATATYPE_LIST._serialized_start = 3331 - _DATATYPE_LIST._serialized_end = 3491 - _DATATYPE_MAP._serialized_start = 3494 - _DATATYPE_MAP._serialized_end = 3686 + _DATATYPE._serialized_end = 4219 + _DATATYPE_BOOLEAN._serialized_start = 1612 + _DATATYPE_BOOLEAN._serialized_end = 1679 + _DATATYPE_BYTE._serialized_start = 1681 + _DATATYPE_BYTE._serialized_end = 1745 + _DATATYPE_SHORT._serialized_start = 1747 + _DATATYPE_SHORT._serialized_end = 1812 + _DATATYPE_INTEGER._serialized_start = 1814 + _DATATYPE_INTEGER._serialized_end = 1881 + _DATATYPE_LONG._serialized_start = 1883 + _DATATYPE_LONG._serialized_end = 1947 + _DATATYPE_FLOAT._serialized_start = 1949 + _DATATYPE_FLOAT._serialized_end = 2014 + _DATATYPE_DOUBLE._serialized_start = 2016 + _DATATYPE_DOUBLE._serialized_end = 2082 + _DATATYPE_STRING._serialized_start = 2084 + _DATATYPE_STRING._serialized_end = 2150 + _DATATYPE_BINARY._serialized_start = 
2152 + _DATATYPE_BINARY._serialized_end = 2218 + _DATATYPE_NULL._serialized_start = 2220 + _DATATYPE_NULL._serialized_end = 2284 + _DATATYPE_TIMESTAMP._serialized_start = 2286 + _DATATYPE_TIMESTAMP._serialized_end = 2355 + _DATATYPE_DATE._serialized_start = 2357 + _DATATYPE_DATE._serialized_end = 2421 + _DATATYPE_TIMESTAMPNTZ._serialized_start = 2423 + _DATATYPE_TIMESTAMPNTZ._serialized_end = 2495 + _DATATYPE_CALENDARINTERVAL._serialized_start = 2497 + _DATATYPE_CALENDARINTERVAL._serialized_end = 2573 + _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2576 + _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2755 + _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2758 + _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2935 + _DATATYPE_UUID._serialized_start = 2937 + _DATATYPE_UUID._serialized_end = 3001 + _DATATYPE_CHAR._serialized_start = 3003 + _DATATYPE_CHAR._serialized_end = 3091 + _DATATYPE_VARCHAR._serialized_start = 3093 + _DATATYPE_VARCHAR._serialized_end = 3184 + _DATATYPE_FIXEDBINARY._serialized_start = 3186 + _DATATYPE_FIXEDBINARY._serialized_end = 3281 + _DATATYPE_DECIMAL._serialized_start = 3284 + _DATATYPE_DECIMAL._serialized_end = 3437 + _DATATYPE_STRUCTFIELD._serialized_start = 3440 + _DATATYPE_STRUCTFIELD._serialized_end = 3695 + _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3636 + _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3695 + _DATATYPE_STRUCT._serialized_start = 3697 + _DATATYPE_STRUCT._serialized_end = 3824 + _DATATYPE_ARRAY._serialized_start = 3827 + _DATATYPE_ARRAY._serialized_end = 3989 + _DATATYPE_MAP._serialized_start = 3992 + _DATATYPE_MAP._serialized_end = 4211 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi index 3bf36fc790c..647f625659b 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.pyi +++ b/python/pyspark/sql/connect/proto/types_pb2.pyi @@ -39,6 +39,7 @@ import google.protobuf.descriptor import 
google.protobuf.internal.containers import google.protobuf.message import sys +import typing if sys.version_info >= (3, 8): import typing as typing_extensions @@ -71,7 +72,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class I8(google.protobuf.message.Message): + class Byte(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -88,7 +89,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class I16(google.protobuf.message.Message): + class Short(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -105,7 +106,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class I32(google.protobuf.message.Message): + class Integer(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -122,7 +123,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class I64(google.protobuf.message.Message): + class Long(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -139,7 +140,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class FP32(google.protobuf.message.Message): + class Float(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -156,7 +157,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class FP64(google.protobuf.message.Message): + class Double(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -207,6 +208,23 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... 
+ class NULL(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int + type_variation_reference: builtins.int + def __init__( + self, + *, + type_variation_reference: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "type_variation_reference", b"type_variation_reference" + ], + ) -> None: ... + class Timestamp(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -241,7 +259,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class Time(google.protobuf.message.Message): + class TimestampNTZ(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -258,7 +276,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class TimestampTZ(google.protobuf.message.Message): + class CalendarInterval(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int @@ -275,39 +293,111 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class IntervalYear(google.protobuf.message.Message): + class YearMonthInterval(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + START_FIELD_FIELD_NUMBER: builtins.int + END_FIELD_FIELD_NUMBER: builtins.int TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int + start_field: builtins.int + end_field: builtins.int type_variation_reference: builtins.int def __init__( self, *, + start_field: builtins.int | None = ..., + end_field: builtins.int | None = ..., type_variation_reference: builtins.int = ..., ) -> None: ... 
+ def HasField( + self, + field_name: typing_extensions.Literal[ + "_end_field", + b"_end_field", + "_start_field", + b"_start_field", + "end_field", + b"end_field", + "start_field", + b"start_field", + ], + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "type_variation_reference", b"type_variation_reference" + "_end_field", + b"_end_field", + "_start_field", + b"_start_field", + "end_field", + b"end_field", + "start_field", + b"start_field", + "type_variation_reference", + b"type_variation_reference", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"] + ) -> typing_extensions.Literal["end_field"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"] + ) -> typing_extensions.Literal["start_field"] | None: ... - class IntervalDay(google.protobuf.message.Message): + class DayTimeInterval(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + START_FIELD_FIELD_NUMBER: builtins.int + END_FIELD_FIELD_NUMBER: builtins.int TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int + start_field: builtins.int + end_field: builtins.int type_variation_reference: builtins.int def __init__( self, *, + start_field: builtins.int | None = ..., + end_field: builtins.int | None = ..., type_variation_reference: builtins.int = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_end_field", + b"_end_field", + "_start_field", + b"_start_field", + "end_field", + b"end_field", + "start_field", + b"start_field", + ], + ) -> builtins.bool: ... 
def ClearField( self, field_name: typing_extensions.Literal[ - "type_variation_reference", b"type_variation_reference" + "_end_field", + b"_end_field", + "_start_field", + b"_start_field", + "end_field", + b"end_field", + "start_field", + b"start_field", + "type_variation_reference", + b"type_variation_reference", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"] + ) -> typing_extensions.Literal["end_field"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"] + ) -> typing_extensions.Literal["start_field"] | None: ... class UUID(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -326,7 +416,7 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... - class FixedChar(google.protobuf.message.Message): + class Char(google.protobuf.message.Message): """Start compound types.""" DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -400,13 +490,30 @@ class DataType(google.protobuf.message.Message): def __init__( self, *, - scale: builtins.int = ..., - precision: builtins.int = ..., + scale: builtins.int | None = ..., + precision: builtins.int | None = ..., type_variation_reference: builtins.int = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_precision", + b"_precision", + "_scale", + b"_scale", + "precision", + b"precision", + "scale", + b"scale", + ], + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "_precision", + b"_precision", + "_scale", + b"_scale", "precision", b"precision", "scale", @@ -415,6 +522,14 @@ class DataType(google.protobuf.message.Message): b"type_variation_reference", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_precision", b"_precision"] + ) -> typing_extensions.Literal["precision"] | None: ... 
+ @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_scale", b"_scale"] + ) -> typing_extensions.Literal["scale"] | None: ... class StructField(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -436,13 +551,13 @@ class DataType(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] ) -> None: ... - TYPE_FIELD_NUMBER: builtins.int NAME_FIELD_NUMBER: builtins.int + DATA_TYPE_FIELD_NUMBER: builtins.int NULLABLE_FIELD_NUMBER: builtins.int METADATA_FIELD_NUMBER: builtins.int - @property - def type(self) -> global___DataType: ... name: builtins.str + @property + def data_type(self) -> global___DataType: ... nullable: builtins.bool @property def metadata( @@ -451,18 +566,25 @@ class DataType(google.protobuf.message.Message): def __init__( self, *, - type: global___DataType | None = ..., name: builtins.str = ..., + data_type: global___DataType | None = ..., nullable: builtins.bool = ..., metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["type", b"type"] + self, field_name: typing_extensions.Literal["data_type", b"data_type"] ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "metadata", b"metadata", "name", b"name", "nullable", b"nullable", "type", b"type" + "data_type", + b"data_type", + "metadata", + b"metadata", + "name", + b"name", + "nullable", + b"nullable", ], ) -> None: ... @@ -491,33 +613,33 @@ class DataType(google.protobuf.message.Message): ], ) -> None: ... 
- class List(google.protobuf.message.Message): + class Array(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - DATATYPE_FIELD_NUMBER: builtins.int + ELEMENT_TYPE_FIELD_NUMBER: builtins.int + CONTAINS_NULL_FIELD_NUMBER: builtins.int TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int - ELEMENT_NULLABLE_FIELD_NUMBER: builtins.int @property - def DataType(self) -> global___DataType: ... + def element_type(self) -> global___DataType: ... + contains_null: builtins.bool type_variation_reference: builtins.int - element_nullable: builtins.bool def __init__( self, *, - DataType: global___DataType | None = ..., + element_type: global___DataType | None = ..., + contains_null: builtins.bool = ..., type_variation_reference: builtins.int = ..., - element_nullable: builtins.bool = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["DataType", b"DataType"] + self, field_name: typing_extensions.Literal["element_type", b"element_type"] ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "DataType", - b"DataType", - "element_nullable", - b"element_nullable", + "contains_null", + b"contains_null", + "element_type", + b"element_type", "type_variation_reference", b"type_variation_reference", ], @@ -526,276 +648,293 @@ class DataType(google.protobuf.message.Message): class Map(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int + KEY_TYPE_FIELD_NUMBER: builtins.int + VALUE_TYPE_FIELD_NUMBER: builtins.int + VALUE_CONTAINS_NULL_FIELD_NUMBER: builtins.int TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int - VALUE_NULLABLE_FIELD_NUMBER: builtins.int @property - def key(self) -> global___DataType: ... + def key_type(self) -> global___DataType: ... @property - def value(self) -> global___DataType: ... + def value_type(self) -> global___DataType: ... 
+ value_contains_null: builtins.bool type_variation_reference: builtins.int - value_nullable: builtins.bool def __init__( self, *, - key: global___DataType | None = ..., - value: global___DataType | None = ..., + key_type: global___DataType | None = ..., + value_type: global___DataType | None = ..., + value_contains_null: builtins.bool = ..., type_variation_reference: builtins.int = ..., - value_nullable: builtins.bool = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + self, + field_name: typing_extensions.Literal[ + "key_type", b"key_type", "value_type", b"value_type" + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "key", - b"key", + "key_type", + b"key_type", "type_variation_reference", b"type_variation_reference", - "value", - b"value", - "value_nullable", - b"value_nullable", + "value_contains_null", + b"value_contains_null", + "value_type", + b"value_type", ], ) -> None: ... 
- BOOL_FIELD_NUMBER: builtins.int - I8_FIELD_NUMBER: builtins.int - I16_FIELD_NUMBER: builtins.int - I32_FIELD_NUMBER: builtins.int - I64_FIELD_NUMBER: builtins.int - FP32_FIELD_NUMBER: builtins.int - FP64_FIELD_NUMBER: builtins.int - STRING_FIELD_NUMBER: builtins.int + NULL_FIELD_NUMBER: builtins.int BINARY_FIELD_NUMBER: builtins.int - TIMESTAMP_FIELD_NUMBER: builtins.int - DATE_FIELD_NUMBER: builtins.int - TIME_FIELD_NUMBER: builtins.int - INTERVAL_YEAR_FIELD_NUMBER: builtins.int - INTERVAL_DAY_FIELD_NUMBER: builtins.int - TIMESTAMP_TZ_FIELD_NUMBER: builtins.int - UUID_FIELD_NUMBER: builtins.int - FIXED_CHAR_FIELD_NUMBER: builtins.int - VARCHAR_FIELD_NUMBER: builtins.int - FIXED_BINARY_FIELD_NUMBER: builtins.int + BOOLEAN_FIELD_NUMBER: builtins.int + BYTE_FIELD_NUMBER: builtins.int + SHORT_FIELD_NUMBER: builtins.int + INTEGER_FIELD_NUMBER: builtins.int + LONG_FIELD_NUMBER: builtins.int + FLOAT_FIELD_NUMBER: builtins.int + DOUBLE_FIELD_NUMBER: builtins.int DECIMAL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + CHAR_FIELD_NUMBER: builtins.int + VAR_CHAR_FIELD_NUMBER: builtins.int + DATE_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + TIMESTAMP_NTZ_FIELD_NUMBER: builtins.int + CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int + YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int + DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int + ARRAY_FIELD_NUMBER: builtins.int STRUCT_FIELD_NUMBER: builtins.int - LIST_FIELD_NUMBER: builtins.int MAP_FIELD_NUMBER: builtins.int + UUID_FIELD_NUMBER: builtins.int + FIXED_BINARY_FIELD_NUMBER: builtins.int USER_DEFINED_TYPE_REFERENCE_FIELD_NUMBER: builtins.int @property - def bool(self) -> global___DataType.Boolean: ... + def null(self) -> global___DataType.NULL: ... @property - def i8(self) -> global___DataType.I8: ... + def binary(self) -> global___DataType.Binary: ... @property - def i16(self) -> global___DataType.I16: ... + def boolean(self) -> global___DataType.Boolean: ... 
@property - def i32(self) -> global___DataType.I32: ... + def byte(self) -> global___DataType.Byte: + """Numeric types""" @property - def i64(self) -> global___DataType.I64: ... + def short(self) -> global___DataType.Short: ... @property - def fp32(self) -> global___DataType.FP32: ... + def integer(self) -> global___DataType.Integer: ... @property - def fp64(self) -> global___DataType.FP64: ... + def long(self) -> global___DataType.Long: ... @property - def string(self) -> global___DataType.String: ... + def float(self) -> global___DataType.Float: ... @property - def binary(self) -> global___DataType.Binary: ... + def double(self) -> global___DataType.Double: ... @property - def timestamp(self) -> global___DataType.Timestamp: ... + def decimal(self) -> global___DataType.Decimal: ... @property - def date(self) -> global___DataType.Date: ... + def string(self) -> global___DataType.String: + """String types""" @property - def time(self) -> global___DataType.Time: ... + def char(self) -> global___DataType.Char: ... @property - def interval_year(self) -> global___DataType.IntervalYear: ... + def var_char(self) -> global___DataType.VarChar: ... @property - def interval_day(self) -> global___DataType.IntervalDay: ... + def date(self) -> global___DataType.Date: + """Datetime types""" @property - def timestamp_tz(self) -> global___DataType.TimestampTZ: ... + def timestamp(self) -> global___DataType.Timestamp: ... @property - def uuid(self) -> global___DataType.UUID: ... + def timestamp_ntz(self) -> global___DataType.TimestampNTZ: ... @property - def fixed_char(self) -> global___DataType.FixedChar: ... + def calendar_interval(self) -> global___DataType.CalendarInterval: + """Interval types""" @property - def varchar(self) -> global___DataType.VarChar: ... + def year_month_interval(self) -> global___DataType.YearMonthInterval: ... @property - def fixed_binary(self) -> global___DataType.FixedBinary: ... + def day_time_interval(self) -> global___DataType.DayTimeInterval: ... 
@property - def decimal(self) -> global___DataType.Decimal: ... + def array(self) -> global___DataType.Array: + """Complex types""" @property def struct(self) -> global___DataType.Struct: ... @property - def list(self) -> global___DataType.List: ... - @property def map(self) -> global___DataType.Map: ... + @property + def uuid(self) -> global___DataType.UUID: ... + @property + def fixed_binary(self) -> global___DataType.FixedBinary: ... user_defined_type_reference: builtins.int def __init__( self, *, - bool: global___DataType.Boolean | None = ..., - i8: global___DataType.I8 | None = ..., - i16: global___DataType.I16 | None = ..., - i32: global___DataType.I32 | None = ..., - i64: global___DataType.I64 | None = ..., - fp32: global___DataType.FP32 | None = ..., - fp64: global___DataType.FP64 | None = ..., - string: global___DataType.String | None = ..., + null: global___DataType.NULL | None = ..., binary: global___DataType.Binary | None = ..., - timestamp: global___DataType.Timestamp | None = ..., - date: global___DataType.Date | None = ..., - time: global___DataType.Time | None = ..., - interval_year: global___DataType.IntervalYear | None = ..., - interval_day: global___DataType.IntervalDay | None = ..., - timestamp_tz: global___DataType.TimestampTZ | None = ..., - uuid: global___DataType.UUID | None = ..., - fixed_char: global___DataType.FixedChar | None = ..., - varchar: global___DataType.VarChar | None = ..., - fixed_binary: global___DataType.FixedBinary | None = ..., + boolean: global___DataType.Boolean | None = ..., + byte: global___DataType.Byte | None = ..., + short: global___DataType.Short | None = ..., + integer: global___DataType.Integer | None = ..., + long: global___DataType.Long | None = ..., + float: global___DataType.Float | None = ..., + double: global___DataType.Double | None = ..., decimal: global___DataType.Decimal | None = ..., + string: global___DataType.String | None = ..., + char: global___DataType.Char | None = ..., + var_char: 
global___DataType.VarChar | None = ..., + date: global___DataType.Date | None = ..., + timestamp: global___DataType.Timestamp | None = ..., + timestamp_ntz: global___DataType.TimestampNTZ | None = ..., + calendar_interval: global___DataType.CalendarInterval | None = ..., + year_month_interval: global___DataType.YearMonthInterval | None = ..., + day_time_interval: global___DataType.DayTimeInterval | None = ..., + array: global___DataType.Array | None = ..., struct: global___DataType.Struct | None = ..., - list: global___DataType.List | None = ..., map: global___DataType.Map | None = ..., + uuid: global___DataType.UUID | None = ..., + fixed_binary: global___DataType.FixedBinary | None = ..., user_defined_type_reference: builtins.int = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ + "array", + b"array", "binary", b"binary", - "bool", - b"bool", + "boolean", + b"boolean", + "byte", + b"byte", + "calendar_interval", + b"calendar_interval", + "char", + b"char", "date", b"date", + "day_time_interval", + b"day_time_interval", "decimal", b"decimal", + "double", + b"double", "fixed_binary", b"fixed_binary", - "fixed_char", - b"fixed_char", - "fp32", - b"fp32", - "fp64", - b"fp64", - "i16", - b"i16", - "i32", - b"i32", - "i64", - b"i64", - "i8", - b"i8", - "interval_day", - b"interval_day", - "interval_year", - b"interval_year", + "float", + b"float", + "integer", + b"integer", "kind", b"kind", - "list", - b"list", + "long", + b"long", "map", b"map", + "null", + b"null", + "short", + b"short", "string", b"string", "struct", b"struct", - "time", - b"time", "timestamp", b"timestamp", - "timestamp_tz", - b"timestamp_tz", + "timestamp_ntz", + b"timestamp_ntz", "user_defined_type_reference", b"user_defined_type_reference", "uuid", b"uuid", - "varchar", - b"varchar", + "var_char", + b"var_char", + "year_month_interval", + b"year_month_interval", ], ) -> builtins.bool: ... 
def ClearField( self, field_name: typing_extensions.Literal[ + "array", + b"array", "binary", b"binary", - "bool", - b"bool", + "boolean", + b"boolean", + "byte", + b"byte", + "calendar_interval", + b"calendar_interval", + "char", + b"char", "date", b"date", + "day_time_interval", + b"day_time_interval", "decimal", b"decimal", + "double", + b"double", "fixed_binary", b"fixed_binary", - "fixed_char", - b"fixed_char", - "fp32", - b"fp32", - "fp64", - b"fp64", - "i16", - b"i16", - "i32", - b"i32", - "i64", - b"i64", - "i8", - b"i8", - "interval_day", - b"interval_day", - "interval_year", - b"interval_year", + "float", + b"float", + "integer", + b"integer", "kind", b"kind", - "list", - b"list", + "long", + b"long", "map", b"map", + "null", + b"null", + "short", + b"short", "string", b"string", "struct", b"struct", - "time", - b"time", "timestamp", b"timestamp", - "timestamp_tz", - b"timestamp_tz", + "timestamp_ntz", + b"timestamp_ntz", "user_defined_type_reference", b"user_defined_type_reference", "uuid", b"uuid", - "varchar", - b"varchar", + "var_char", + b"var_char", + "year_month_interval", + b"year_month_interval", ], ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["kind", b"kind"] ) -> typing_extensions.Literal[ - "bool", - "i8", - "i16", - "i32", - "i64", - "fp32", - "fp64", - "string", + "null", "binary", - "timestamp", - "date", - "time", - "interval_year", - "interval_day", - "timestamp_tz", - "uuid", - "fixed_char", - "varchar", - "fixed_binary", + "boolean", + "byte", + "short", + "integer", + "long", + "float", + "double", "decimal", + "string", + "char", + "var_char", + "date", + "timestamp", + "timestamp_ntz", + "calendar_interval", + "year_month_interval", + "day_time_interval", + "array", "struct", - "list", "map", + "uuid", + "fixed_binary", "user_defined_type_reference", ] | None: ... 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index fa27b609941..ab454c53491 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -144,6 +144,73 @@ class SparkConnectTests(SparkConnectSQLTestCase): schema, ) + # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType + query = """ + SELECT * FROM VALUES + (float(1.0), double(1.0), 1.0, "1", true, NULL), + (float(2.0), double(2.0), 2.0, "2", false, NULL), + (float(3.0), double(3.0), NULL, "3", false, NULL) + AS tab(a, b, c, d, e, f) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + + # test TimestampType, DateType + query = """ + SELECT * FROM VALUES + (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')), + (TIMESTAMP('2019-04-12 15:50:00'), NULL), + (NULL, DATE('2022-02-22')) + AS tab(a, b) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + + # test MapType + query = """ + SELECT * FROM VALUES + (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)), + (MAP('x', 'yz'), MAP('x', NULL), NULL), + (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4)) + AS tab(a, b, c) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + + # test ArrayType + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)), + (ARRAY('x', NULL), NULL, ARRAY(1, 3)), + (NULL, ARRAY(-1, -2, -3), Array()) + AS tab(a, b, c) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + + # test StructType + query = """ + SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES + (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)), + (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)), + 
(float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL) + AS tab(a, b, c, d, e, f, g, h) + """ + # compare the __repr__() to ignore the metadata for now + # the metadata is not supported in Connect for now + self.assertEqual( + self.spark.sql(query).schema.__repr__(), + self.connect.sql(query).schema.__repr__(), + ) + def test_simple_binary_expressions(self): """Test complex expression""" df = self.connect.read.table(self.tbl_name) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org