This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new cf5956f [SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get cf5956f is described below commit cf5956f058607dac1866fddbee495f0d46c19c05 Author: iRakson <raksonrak...@gmail.com> AuthorDate: Fri Mar 6 16:45:06 2020 +0800 [SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get ### What changes were proposed in this pull request? Introduced a new parameter `emptyCollection` for `CreateMap` and `CreateArray` function to remove dependency on SQLConf.get. ### Why are the changes needed? This allows avoiding the issue when the configuration changes between different phases of planning, and this can silently break a query plan which can lead to crashes or data corruption. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing UTs. Closes #27657 from iRakson/SPARK-30899. Authored-by: iRakson <raksonrak...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit cba17e07e9f15673f274de1728f6137d600026e1) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 8 ++--- .../catalyst/expressions/complexTypeCreator.scala | 35 +++++++++++++++++++--- .../expressions/complexTypeExtractors.scala | 4 +-- .../sql/catalyst/optimizer/ComplexTypes.scala | 8 ++--- .../optimizer/NormalizeFloatingNumbers.scala | 8 ++--- 5 files changed, 45 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f416e8e..0a0bef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -553,10 +553,10 @@ object TypeCoercion { // Skip nodes who's 
children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => + case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) + case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType))) case None => a } @@ -592,7 +592,7 @@ object TypeCoercion { case None => m } - case m @ CreateMap(children) if m.keys.length == m.values.length && + case m @ CreateMap(children, _) if m.keys.length == m.values.length && (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => val keyTypes = m.keys.map(_.dataType) val newKeys = findWiderCommonType(keyTypes) match { @@ -606,7 +606,7 @@ object TypeCoercion { case None => m.values } - CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) + m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4bd85d3..6c31511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -37,16 +37,23 @@ import org.apache.spark.unsafe.types.UTF8String > SELECT _FUNC_(1, 2, 3); [1,2,3] """) -case class CreateArray(children: Seq[Expression]) extends Expression { +case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) + extends Expression { + + def this(children: Seq[Expression]) = { + this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) + } override def foldable: Boolean = children.forall(_.foldable) + override def stringArgs: Iterator[Any] = super.stringArgs.take(1) + override def checkInputDataTypes(): TypeCheckResult = { TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } private val defaultElementType: DataType = { - if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { + if (useStringTypeWhenEmpty) { StringType } else { NullType @@ -79,6 +86,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def prettyName: String = "array" } +object CreateArray { + def apply(children: Seq[Expression]): CreateArray = { + new CreateArray(children) + } +} + private [sql] object GenArrayData { /** * Return Java code pieces based on DataType and array size to allocate ArrayData class @@ -141,12 +154,18 @@ private [sql] object GenArrayData { > SELECT _FUNC_(1.0, '2', 3.0, '4'); {1.0:"2",3.0:"4"} """) -case class CreateMap(children: Seq[Expression]) extends Expression { +case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) + extends 
Expression { + + def this(children: Seq[Expression]) = { + this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) + } + lazy val keys = children.indices.filter(_ % 2 == 0).map(children) lazy val values = children.indices.filter(_ % 2 != 0).map(children) private val defaultElementType: DataType = { - if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { + if (useStringTypeWhenEmpty) { StringType } else { NullType @@ -155,6 +174,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) + override def stringArgs: Iterator[Any] = super.stringArgs.take(1) + override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( @@ -215,6 +236,12 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +object CreateMap { + def apply(children: Seq[Expression]): CreateMap = { + new CreateMap(children) + } +} + /** * Returns a catalyst Map containing the two arrays in children expressions as keys and values. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index e9d60ed..9c600c9d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -275,9 +275,9 @@ trait GetArrayItemUtil { if (ordinal.foldable && !ordinal.nullable) { val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() child match { - case CreateArray(ar) if intOrdinal < ar.length => + case CreateArray(ar, _) if intOrdinal < ar.length => ar(intOrdinal).nullable - case GetArrayStructFields(CreateArray(elements), field, _, _, _) + case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) if intOrdinal < elements.length => elements(intOrdinal).nullable || field.nullable case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 28dc8e9..f79dabf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -41,14 +41,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { createNamedStruct.valExprs(ordinal) // Remove redundant array indexing. - case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => // Instead of selecting the field on the entire array, select it from each member // of the array. Pushing down the operation this way may open other optimizations // opportunities (i.e. 
struct(...,x,...).x) - CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) + CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty) // Remove redundant map lookup. - case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => + case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) => // Instead of creating the array and then selecting one row, remove array creation // altogether. if (idx >= 0 && idx < elems.size) { @@ -58,7 +58,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) + case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index ea01d9e..5f94af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -114,11 +114,11 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case CreateNamedStruct(children) => CreateNamedStruct(children.map(normalize)) - case CreateArray(children) => - CreateArray(children.map(normalize)) + case CreateArray(children, useStringTypeWhenEmpty) => + CreateArray(children.map(normalize), useStringTypeWhenEmpty) - case CreateMap(children) => - CreateMap(children.map(normalize)) + case CreateMap(children, useStringTypeWhenEmpty) => + CreateMap(children.map(normalize), useStringTypeWhenEmpty) case _ if expr.dataType == FloatType || expr.dataType == DoubleType => KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org