This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 55b5ff6f45fd [SPARK-47669][SQL][CONNECT][PYTHON] Add `Column.try_cast`
55b5ff6f45fd is described below

commit 55b5ff6f45fde0048cfd4d8d1a41d6e7f65fd121
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Apr 3 08:15:08 2024 +0800

    [SPARK-47669][SQL][CONNECT][PYTHON] Add `Column.try_cast`

    ### What changes were proposed in this pull request?
    Add the `try_cast` function to the Column APIs.

    ### Why are the changes needed?
    For functionality parity.

    ### Does this PR introduce _any_ user-facing change?
    Yes:
    ```
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame(
    ...     [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
    >>> df.select(df.name.try_cast("double")).show()
    +-----+
    | name|
    +-----+
    |123.0|
    | NULL|
    | NULL|
    +-----+
    ```

    ### How was this patch tested?
    New tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #45796 from zhengruifeng/connect_try_cast.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 35 ++++++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  4 +
 .../main/protobuf/spark/connect/expressions.proto  | 10 +++
 .../explain-results/column_try_cast.explain        |  2 +
 .../query-tests/queries/column_try_cast.json       | 29 +++++++
 .../query-tests/queries/column_try_cast.proto.bin  | Bin 0 -> 173 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  | 18 ++--
 .../docs/source/reference/pyspark.sql/column.rst   |  1 +
 python/pyspark/sql/column.py                       | 62 +++++++++++++
 python/pyspark/sql/connect/column.py               | 17 ++++
 python/pyspark/sql/connect/expressions.py          | 14 +++
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 96 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 28 ++++++
 .../main/scala/org/apache/spark/sql/Column.scala   | 37 ++++++++
 14 files changed, 301 insertions(+), 52 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index 4cb99541ccf0..dec699f4f1a8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1090,6 +1090,41 @@ class Column private[sql] (@DeveloperApi val expr: proto.Expression) extends Log
    */
   def cast(to: String): Column = cast(DataTypeParser.parseDataType(to))
 
+  /**
+   * Casts the column to a different data type and the result is null on failure.
+   * {{{
+   *   // Casts colA to IntegerType.
+   *   import org.apache.spark.sql.types.IntegerType
+   *   df.select(df("colA").try_cast(IntegerType))
+   *
+   *   // equivalent to
+   *   df.select(df("colA").try_cast("int"))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 4.0.0
+   */
+  def try_cast(to: DataType): Column = Column { builder =>
+    builder.getCastBuilder
+      .setExpr(expr)
+      .setType(DataTypeProtoConverter.toConnectProtoType(to))
+      .setEvalMode(proto.Expression.Cast.EvalMode.EVAL_MODE_TRY)
+  }
+
+  /**
+   * Casts the column to a different data type and the result is null on failure.
+   * {{{
+   *   // Casts colA to integer.
+   *   df.select(df("colA").try_cast("int"))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 4.0.0
+   */
+  def try_cast(to: String): Column = {
+    try_cast(DataTypeParser.parseDataType(to))
+  }
+
   /**
    * Returns a sort expression based on the descending order of the column.
   * {{{
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 46789057ed3c..5fde8b04735b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -871,6 +871,10 @@ class PlanGenerationTestSuite
     fn.col("a").cast("long")
   }
 
+  columnTest("try_cast") {
+    fn.col("a").try_cast("long")
+  }
+
   orderColumnTest("desc") {
     fn.col("b").desc
   }
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index c3333636bf68..726ae5dd1c21 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -147,6 +147,16 @@ message Expression {
       // If this is set, Server will use Catalyst parser to parse this string to DataType.
       string type_str = 3;
     }
+
+    // (Optional) The expression evaluation mode.
+    EvalMode eval_mode = 4;
+
+    enum EvalMode {
+      EVAL_MODE_UNSPECIFIED = 0;
+      EVAL_MODE_LEGACY = 1;
+      EVAL_MODE_ANSI = 2;
+      EVAL_MODE_TRY = 3;
+    }
   }
 
   message Literal {
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/column_try_cast.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/column_try_cast.explain
new file mode 100644
index 000000000000..b2c5a6e35af8
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/column_try_cast.explain
@@ -0,0 +1,2 @@
+Project [try_cast(a#0 as bigint) AS a#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.json b/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.json
new file mode 100644
index 000000000000..be6d38a429c7
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.json
@@ -0,0 +1,29 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "cast": {
+        "expr": {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        },
+        "type": {
+          "long": {
+          }
+        },
+        "evalMode": "EVAL_MODE_TRY"
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.proto.bin
new file mode 100644
index 000000000000..ca94e6c2f94d
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/column_try_cast.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 313c17c25473..1894ab984490 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2097,11 +2097,19 @@ class SparkConnectPlanner(
   }
 
   private def transformCast(cast: proto.Expression.Cast): Expression = {
-    cast.getCastToTypeCase match {
-      case proto.Expression.Cast.CastToTypeCase.TYPE =>
-        Cast(transformExpression(cast.getExpr), transformDataType(cast.getType))
-      case _ =>
-        Cast(transformExpression(cast.getExpr), parser.parseDataType(cast.getTypeStr))
+    val dataType = cast.getCastToTypeCase match {
+      case proto.Expression.Cast.CastToTypeCase.TYPE => transformDataType(cast.getType)
+      case _ => parser.parseDataType(cast.getTypeStr)
+    }
+    val mode = cast.getEvalMode match {
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY => Some(EvalMode.LEGACY)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI => Some(EvalMode.ANSI)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY => Some(EvalMode.TRY)
+      case _ => None
+    }
+    mode match {
+      case Some(m) => Cast(transformExpression(cast.getExpr), dataType, None, m)
+      case _ => Cast(transformExpression(cast.getExpr), dataType)
     }
   }
 
diff --git a/python/docs/source/reference/pyspark.sql/column.rst b/python/docs/source/reference/pyspark.sql/column.rst
index 08052bcc4683..91211d859a19 100644
--- a/python/docs/source/reference/pyspark.sql/column.rst
+++ b/python/docs/source/reference/pyspark.sql/column.rst
@@ -57,5 +57,6 @@ Column
     Column.rlike
     Column.startswith
     Column.substr
+    Column.try_cast
     Column.when
     Column.withField
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 4a7213593703..bf6192f8c58d 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -1282,6 +1282,68 @@ class Column:
         )
         return Column(jc)
 
+    def try_cast(self, dataType: Union[DataType, str]) -> "Column":
+        """
+        This is a special version of `cast` that performs the same operation, but returns a NULL
+        value instead of raising an error if the invoke method throws exception.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        dataType : :class:`DataType` or str
+            a DataType or Python string literal with a DDL-formatted string
+            to use when parsing the column to the same type.
+
+        Returns
+        -------
+        :class:`Column`
+            Column representing whether each element of Column is cast into new type.
+
+        Examples
+        --------
+        Example 1: Cast with a Datatype
+
+        >>> from pyspark.sql.types import LongType
+        >>> df = spark.createDataFrame(
+        ...      [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
+        >>> df.select(df.name.try_cast(LongType())).show()
+        +----+
+        |name|
+        +----+
+        | 123|
+        |NULL|
+        |NULL|
+        +----+
+
+        Example 2: Cast with a DDL string
+
+        >>> df = spark.createDataFrame(
+        ...      [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
+        >>> df.select(df.name.try_cast("double")).show()
+        +-----+
+        | name|
+        +-----+
+        |123.0|
+        | NULL|
+        | NULL|
+        +-----+
+        """
+        if isinstance(dataType, str):
+            jc = self._jc.try_cast(dataType)
+        elif isinstance(dataType, DataType):
+            from pyspark.sql import SparkSession
+
+            spark = SparkSession._getActiveSessionOrCreate()
+            jdt = spark._jsparkSession.parseDataType(dataType.json())
+            jc = self._jc.try_cast(jdt)
+        else:
+            raise PySparkTypeError(
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
+            )
+        return Column(jc)
+
     def astype(self, dataType: Union[DataType, str]) -> "Column":
         """
         :func:`astype` is an alias for :func:`cast`.
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 052151d5417e..719d592924ad 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -331,6 +331,23 @@ class Column:
 
     astype = cast
 
+    def try_cast(self, dataType: Union[DataType, str]) -> "Column":
+        if isinstance(dataType, (DataType, str)):
+            return Column(
+                CastExpression(
+                    expr=self._expr,
+                    data_type=dataType,
+                    eval_mode="try",
+                )
+            )
+        else:
+            raise PySparkTypeError(
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
+            )
+
+    try_cast.__doc__ = PySparkColumn.try_cast.__doc__
+
     def __repr__(self) -> str:
         return "Column<'%s'>" % self._expr.__repr__()
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 4bc8a0a034e8..b1735f65f520 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -837,10 +837,15 @@ class CastExpression(Expression):
         self,
         expr: Expression,
         data_type: Union[DataType, str],
+        eval_mode: Optional[str] = None,
     ) -> None:
         super().__init__()
         self._expr = expr
         self._data_type = data_type
+        if eval_mode is not None:
+            assert isinstance(eval_mode, str)
+            assert eval_mode in ["legacy", "ansi", "try"]
+        self._eval_mode = eval_mode
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         fun = proto.Expression()
@@ -849,6 +854,15 @@ class CastExpression(Expression):
             fun.cast.type_str = self._data_type
         else:
             fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
+
+        if self._eval_mode is not None:
+            if self._eval_mode == "legacy":
+                fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY
+            elif self._eval_mode == "ansi":
+                fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI
+            elif self._eval_mode == "try":
+                fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_TRY
+
         return fun
 
     def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index fb3ebf30d300..9a4b597b13f7 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xb4-\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xde.\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,7 +44,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5917
+    _EXPRESSION._serialized_end = 6087
     _EXPRESSION_WINDOW._serialized_start = 1645
     _EXPRESSION_WINDOW._serialized_end = 2428
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
@@ -60,49 +60,51 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771
    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856
     _EXPRESSION_CAST._serialized_start = 2859
-    _EXPRESSION_CAST._serialized_end = 3004
-    _EXPRESSION_LITERAL._serialized_start = 3007
-    _EXPRESSION_LITERAL._serialized_end = 4570
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3842
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3959
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3961
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4059
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4062
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4192
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4195
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4422
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4573
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4759
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4762
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4966
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4968
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5018
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5020
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5144
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5146
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5232
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5235
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5367
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5370
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5557
-    _EXPRESSION_ALIAS._serialized_start = 5559
-    _EXPRESSION_ALIAS._serialized_end = 5679
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5682
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5840
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5842
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5904
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5920
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6284
-    _PYTHONUDF._serialized_start = 6287
-    _PYTHONUDF._serialized_end = 6442
-    _SCALARSCALAUDF._serialized_start = 6445
-    _SCALARSCALAUDF._serialized_end = 6629
-    _JAVAUDF._serialized_start = 6632
-    _JAVAUDF._serialized_end = 6781
-    _CALLFUNCTION._serialized_start = 6783
-    _CALLFUNCTION._serialized_end = 6891
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 6893
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 6985
+    _EXPRESSION_CAST._serialized_end = 3174
+    _EXPRESSION_CAST_EVALMODE._serialized_start = 3060
+    _EXPRESSION_CAST_EVALMODE._serialized_end = 3158
+    _EXPRESSION_LITERAL._serialized_start = 3177
+    _EXPRESSION_LITERAL._serialized_end = 4740
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4012
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4129
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4131
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4229
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4232
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4362
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4365
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4592
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4595
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4724
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4743
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4929
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4932
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5136
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5138
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5188
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5190
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5314
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5316
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5402
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5405
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5537
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5540
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5727
+    _EXPRESSION_ALIAS._serialized_start = 5729
+    _EXPRESSION_ALIAS._serialized_end = 5849
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5852
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6010
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6012
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6074
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6090
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6454
+    _PYTHONUDF._serialized_start = 6457
+    _PYTHONUDF._serialized_end = 6612
+    _SCALARSCALAUDF._serialized_start = 6615
+    _SCALARSCALAUDF._serialized_end = 6799
+    _JAVAUDF._serialized_start = 6802
+    _JAVAUDF._serialized_end = 6951
+    _CALLFUNCTION._serialized_start = 6953
+    _CALLFUNCTION._serialized_end = 7061
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 7063
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 7155
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index e397880a73e4..183a839da920 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -309,9 +309,32 @@ class Expression(google.protobuf.message.Message):
     class Cast(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+        class _EvalMode:
+            ValueType = typing.NewType("ValueType", builtins.int)
+            V: typing_extensions.TypeAlias = ValueType
+
+        class _EvalModeEnumTypeWrapper(
+            google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+                Expression.Cast._EvalMode.ValueType
+            ],
+            builtins.type,
+        ):  # noqa: F821
+            DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+            EVAL_MODE_UNSPECIFIED: Expression.Cast._EvalMode.ValueType  # 0
+            EVAL_MODE_LEGACY: Expression.Cast._EvalMode.ValueType  # 1
+            EVAL_MODE_ANSI: Expression.Cast._EvalMode.ValueType  # 2
+            EVAL_MODE_TRY: Expression.Cast._EvalMode.ValueType  # 3
+
+        class EvalMode(_EvalMode, metaclass=_EvalModeEnumTypeWrapper): ...
+        EVAL_MODE_UNSPECIFIED: Expression.Cast.EvalMode.ValueType  # 0
+        EVAL_MODE_LEGACY: Expression.Cast.EvalMode.ValueType  # 1
+        EVAL_MODE_ANSI: Expression.Cast.EvalMode.ValueType  # 2
+        EVAL_MODE_TRY: Expression.Cast.EvalMode.ValueType  # 3
+
         EXPR_FIELD_NUMBER: builtins.int
         TYPE_FIELD_NUMBER: builtins.int
         TYPE_STR_FIELD_NUMBER: builtins.int
+        EVAL_MODE_FIELD_NUMBER: builtins.int
         @property
         def expr(self) -> global___Expression:
             """(Required) the expression to be casted."""
@@ -319,12 +342,15 @@ class Expression(google.protobuf.message.Message):
         def type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
         type_str: builtins.str
         """If this is set, Server will use Catalyst parser to parse this string to DataType."""
+        eval_mode: global___Expression.Cast.EvalMode.ValueType
+        """(Optional) The expression evaluation mode."""
         def __init__(
             self,
             *,
             expr: global___Expression | None = ...,
             type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
             type_str: builtins.str = ...,
+            eval_mode: global___Expression.Cast.EvalMode.ValueType = ...,
         ) -> None: ...
         def HasField(
             self,
@@ -344,6 +370,8 @@ class Expression(google.protobuf.message.Message):
             field_name: typing_extensions.Literal[
                 "cast_to_type",
                 b"cast_to_type",
+                "eval_mode",
+                b"eval_mode",
                 "expr",
                 b"expr",
                 "type",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 39d720c933a8..fdd315a44f1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1222,6 +1222,43 @@ class Column(val expr: Expression) extends Logging {
    */
   def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to))
 
+  /**
+   * Casts the column to a different data type and the result is null on failure.
+   * {{{
+   *   // Casts colA to IntegerType.
+   *   import org.apache.spark.sql.types.IntegerType
+   *   df.select(df("colA").try_cast(IntegerType))
+   *
+   *   // equivalent to
+   *   df.select(df("colA").try_cast("int"))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 4.0.0
+   */
+  def try_cast(to: DataType): Column = withExpr {
+    val cast = Cast(
+      child = expr,
+      dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(to),
+      evalMode = EvalMode.TRY)
+    cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
+    cast
+  }
+
+  /**
+   * Casts the column to a different data type and the result is null on failure.
+   * {{{
+   *   // Casts colA to integer.
+   *   df.select(df("colA").try_cast("int"))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 4.0.0
+   */
+  def try_cast(to: String): Column = {
+    try_cast(CatalystSqlParser.parseDataType(to))
+  }
+
   /**
    * Returns a sort expression based on the descending order of the column.
    * {{{

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
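For readers who want to exercise the new API end to end, here is a minimal PySpark sketch. It is an illustration only, not part of the patch: the `SparkSession` setup and app name are assumptions, and it requires a Spark build that contains this commit.

```
# Minimal sketch of Column.try_cast; assumes a Spark build that includes SPARK-47669.
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType

spark = SparkSession.builder.appName("try_cast_demo").getOrCreate()

df = spark.createDataFrame([(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])

# Both overloads accept either a DataType instance or a DDL type string;
# values that cannot be converted ("Bob") become NULL instead of raising an error.
df.select(
    df.name.try_cast(LongType()).alias("as_long"),
    df.name.try_cast("double").alias("as_double"),
).show()

spark.stop()
```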