This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f66bcd4fba8 [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull` f66bcd4fba8 is described below commit f66bcd4fba8d0947fd3c7a9c2f9621e78c1fbc0f Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Mon Aug 15 21:29:20 2022 +0800 [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull` ### What changes were proposed in this pull request? ArrayType's parameter `containsNull` means that the array's elements can be null (i.e., it describes element nullability, not the nullability of the array itself), which is easy to misread when following the logic. In this PR, we refactor the comment about `containsNull` and refactor the code in ArrayType-related expressions so that the intent of each code path is clear. ### Why are the changes needed? Refactor code ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Not needed Closes #37453 from AngersZhuuuu/SPARK-40019. 
Lead-authored-by: Angerszhuuuu <angers....@gmail.com> Co-authored-by: AngersZhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/collectionOperations.scala | 79 ++++++++++++---------- .../expressions/complexTypeExtractors.scala | 7 +- .../org/apache/spark/sql/types/ArrayType.scala | 7 +- 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d6a9601f884..f40f5a98232 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -71,6 +71,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]") } } + + protected def leftArrayElementNullable = left.dataType.asInstanceOf[ArrayType].containsNull + protected def rightArrayElementNullable = right.dataType.asInstanceOf[ArrayType].containsNull } @@ -895,7 +898,8 @@ trait ArraySortLike extends ExpectsInputTypes { @transient lazy val elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType - def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + private def resultArrayElementNullable: Boolean = + arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) @@ -930,8 +934,8 @@ trait ArraySortLike extends ExpectsInputTypes { } else { s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" } - val canPerformFastSort = - CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull + val canPerformFastSort = 
CodeGenerator.isPrimitiveType(elementType) && + elementType != BooleanType && !resultArrayElementNullable val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { val javaType = CodeGenerator.javaType(elementType) val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) @@ -1079,6 +1083,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) override def dataType: DataType = child.dataType + private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private[this] var random: RandomIndicesGenerator = _ @@ -1118,7 +1124,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) val initialization = CodeGenerator.createArrayData( arrayData, elementType, numElements, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName, - i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull) + i, s"$indices[$i]", resultArrayElementNullable) s""" |int $numElements = $childName.numElements(); @@ -1162,6 +1168,8 @@ case class Reverse(child: Expression) override def dataType: DataType = child.dataType + private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull + override def nullSafeEval(input: Any): Any = doReverse(input) @transient private lazy val doReverse: Any => Any = dataType match { @@ -1196,7 +1204,7 @@ case class Reverse(child: Expression) val initialization = CodeGenerator.createArrayData( arrayData, elementType, numElements, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment( - arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull) + arrayData, elementType, childName, i, j, resultArrayElementNullable) s""" |final int $numElements = $childName.numElements(); @@ -1347,8 +1355,7 @@ case class ArraysOverlap(left: Expression, right: Expression) 
} override def nullable: Boolean = { - left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || - right.dataType.asInstanceOf[ArrayType].containsNull + left.nullable || right.nullable || leftArrayElementNullable || rightArrayElementNullable } override def nullSafeEval(a1: Any, a2: Any): Any = { @@ -1560,6 +1567,8 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def dataType: DataType = x.dataType + private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) override def first: Expression = x @@ -1632,7 +1641,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) val allocation = CodeGenerator.createArrayData( values, elementType, resLength, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray, - i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull) + i, s"$i + $startIdx", resultArrayElementNullable) s""" |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { @@ -2103,7 +2112,8 @@ case class ElementAt( @transient private lazy val mapValueContainsNull = left.dataType.asInstanceOf[MapType].valueContainsNull - @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + @transient private lazy val arrayElementNullable = + left.dataType.asInstanceOf[ArrayType].containsNull @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) @@ -2189,7 +2199,7 @@ case class ElementAt( } else { array.numElements() + index } - if (arrayContainsNull && array.isNullAt(idx)) { + if (arrayElementNullable && array.isNullAt(idx)) { null } else { array.get(idx, dataType) @@ -2205,7 +2215,7 @@ case class ElementAt( case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val 
nullCheck = if (arrayContainsNull) { + val nullCheck = if (arrayElementNullable) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; @@ -2353,6 +2363,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } + private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2484,8 +2496,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val initialization = CodeGenerator.createArrayData( arrayData, elementType, numElemName, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment( - arrayData, elementType, s"args[$y]", counter, z, - dataType.asInstanceOf[ArrayType].containsNull) + arrayData, elementType, s"args[$y]", counter, z, resultArrayElementNullable) val concat = ctx.freshName("concat") val concatDef = @@ -2535,6 +2546,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran @transient override lazy val dataType: DataType = childDataType.elementType + private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { @@ -2604,8 +2617,7 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran val allocation = CodeGenerator.createArrayData( tempArrayDataName, elementType, numElemName, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment( - tempArrayDataName, elementType, arr, counter, l, - dataType.asInstanceOf[ArrayType].containsNull) + tempArrayDataName, elementType, arr, counter, l, resultArrayElementNullable) s""" |$numElemCode @@ -3486,6 +3498,8 @@ trait ArraySetLike { @transient protected lazy val ordering: Ordering[Any] = 
TypeUtils.getInterpretedOrdering(et) + protected def resultArrayElementNullable = dt.asInstanceOf[ArrayType].containsNull + protected def genGetValue(array: String, i: String): String = CodeGenerator.getValue(array, et, i) @@ -3521,7 +3535,7 @@ trait ArraySetLike { body: String, value: String, nullElementIndex: String): String = { - if (dt.asInstanceOf[ArrayType].containsNull) { + if (resultArrayElementNullable) { s""" |$body |if ($nullElementIndex >= 0) { @@ -3662,7 +3676,7 @@ case class ArrayDistinct(child: Expression) val arrayBuilderClass = s"$arrayBuilder$$of$ptName" // Only need to track null element index when array's element is nullable. - val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + val declareNullTrackVariables = if (resultArrayElementNullable) { s""" |int $nullElementIndex = -1; """.stripMargin @@ -3692,8 +3706,8 @@ case class ArrayDistinct(child: Expression) """.stripMargin) val processArray = SQLOpenHashSet.withNullCheckCode( - dataType.asInstanceOf[ArrayType].containsNull, - dataType.asInstanceOf[ArrayType].containsNull, + resultArrayElementNullable, + resultArrayElementNullable, array, i, hashSet, withNaNCheckCodeGenerator, s""" |$nullElementIndex = $size; @@ -3880,8 +3894,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi """.stripMargin) val processArray = SQLOpenHashSet.withNullCheckCode( - dataType.asInstanceOf[ArrayType].containsNull, - dataType.asInstanceOf[ArrayType].containsNull, + resultArrayElementNullable, + resultArrayElementNullable, array, i, hashSet, withNaNCheckCodeGenerator, s""" |$nullElementIndex = $size; @@ -3890,7 +3904,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi """.stripMargin) // Only need to track null element index when result array's element is nullable. 
- val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + val declareNullTrackVariables = if (resultArrayElementNullable) { s""" |int $nullElementIndex = -1; """.stripMargin @@ -3985,9 +3999,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina private lazy val internalDataType: DataType = { dataTypeCheck - ArrayType(elementType, - left.dataType.asInstanceOf[ArrayType].containsNull && - right.dataType.asInstanceOf[ArrayType].containsNull) + ArrayType(elementType, leftArrayElementNullable && rightArrayElementNullable) } override def dataType: DataType = internalDataType @@ -4122,8 +4134,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina (valueNaN: String) => "") val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode( - right.dataType.asInstanceOf[ArrayType].containsNull, - left.dataType.asInstanceOf[ArrayType].containsNull, + rightArrayElementNullable, leftArrayElementNullable, array2, i, hashSet, withArray2NaNCheckCodeGenerator, "") val body = @@ -4151,8 +4162,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina """.stripMargin) val processArray1 = SQLOpenHashSet.withNullCheckCode( - left.dataType.asInstanceOf[ArrayType].containsNull, - right.dataType.asInstanceOf[ArrayType].containsNull, + leftArrayElementNullable, rightArrayElementNullable, array1, i, hashSetResult, withArray1NaNCheckCodeGenerator, s""" |if ($hashSet.containsNull()) { @@ -4163,7 +4173,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina """.stripMargin) // Only need to track null element index when result array's element is nullable. 
- val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + val declareNullTrackVariables = if (resultArrayElementNullable) { s""" |int $nullElementIndex = -1; """.stripMargin @@ -4340,8 +4350,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL (valueNaN: Any) => "") val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode( - right.dataType.asInstanceOf[ArrayType].containsNull, - left.dataType.asInstanceOf[ArrayType].containsNull, + rightArrayElementNullable, leftArrayElementNullable, array2, i, hashSet, withArray2NaNCheckCodeGenerator, "") val body = @@ -4366,8 +4375,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL """.stripMargin) val processArray1 = SQLOpenHashSet.withNullCheckCode( - left.dataType.asInstanceOf[ArrayType].containsNull, - left.dataType.asInstanceOf[ArrayType].containsNull, + leftArrayElementNullable, + leftArrayElementNullable, array1, i, hashSet, withArray1NaNCheckCodeGenerator, s""" |$nullElementIndex = $size; @@ -4376,7 +4385,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL """.stripMargin) // Only need to track null element index when array1's element is nullable. 
- val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val declareNullTrackVariables = if (leftArrayElementNullable) { s""" |int $nullElementIndex = -1; """.stripMargin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index b6cbb1d0005..7b99b9d1082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -282,7 +282,8 @@ case class GetArrayItem( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") - val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { + val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull + val nullCheck = if (childArrayElementNullable) { s"""else if ($eval1.isNullAt($index)) { ${ev.isNull} = true; } @@ -333,7 +334,7 @@ trait GetArrayItemUtil { ordinal: Expression, failOnError: Boolean, nullability: (Seq[Expression], Int) => Boolean): Boolean = { - val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull + val arrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull if (ordinal.foldable && !ordinal.nullable) { val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() child match { @@ -345,7 +346,7 @@ trait GetArrayItemUtil { true } } else { - if (failOnError) arrayContainsNull else true + if (failOnError) arrayElementNullable else true } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index b5708bae923..e139823b2bd 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -53,11 +53,12 @@ object ArrayType extends AbstractDataType { * Please use `DataTypes.createArrayType()` to create a specific instance. * * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * `containsNull: Boolean`. + * The field of `elementType` is used to specify the type of array elements. + * The field of `containsNull` is used to specify if the array can have `null` values. * * @param elementType The data type of values. - * @param containsNull Indicates if values have `null` values + * @param containsNull Indicates if the array can have `null` values * * @since 1.3.0 */ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org