This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new cf5956f  [SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get
cf5956f is described below

commit cf5956f058607dac1866fddbee495f0d46c19c05
Author: iRakson <raksonrak...@gmail.com>
AuthorDate: Fri Mar 6 16:45:06 2020 +0800

    [SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get
    
    ### What changes were proposed in this pull request?
    Introduced a new parameter, `useStringTypeWhenEmpty`, for the `CreateMap` and `CreateArray` functions to remove their dependency on SQLConf.get.
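    
    As a minimal, self-contained sketch of the pattern (hypothetical names only; `LegacyConf` stands in for reading `SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE` via `SQLConf.get`): the flag becomes an explicit constructor field, an auxiliary constructor reads the conf exactly once, and a companion `apply` keeps existing call sites source-compatible.
    
    ```scala
    // Sketch only, not Spark's actual classes. `LegacyConf` is a hypothetical
    // stand-in for the session conf consulted through SQLConf.get.
    object LegacyConf {
      @volatile var useStringTypeWhenEmpty: Boolean = false
    }
    
    case class CreateArraySketch(children: Seq[String], useStringTypeWhenEmpty: Boolean) {
      // Auxiliary constructor: the conf is read once, at construction time.
      def this(children: Seq[String]) = this(children, LegacyConf.useStringTypeWhenEmpty)
    
      // The element type of an empty array depends only on the captured flag,
      // never on the conf's current value.
      def defaultElementType: String = if (useStringTypeWhenEmpty) "string" else "null"
    }
    
    object CreateArraySketch {
      // Mirrors the companion apply added by this patch, so call sites that
      // pass only `children` keep working.
      def apply(children: Seq[String]): CreateArraySketch = new CreateArraySketch(children)
    }
    ```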
    
    ### Why are the changes needed?
    This avoids the issue of the configuration changing between different phases of planning, which can silently break a query plan and lead to crashes or data corruption.
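    
    For illustration, using the hypothetical sketch above: once the flag is captured, a later conf change (e.g. between analysis and optimization) no longer alters the type of an already-built expression, and `copy` carries the captured flag along.
    
    ```scala
    val arr = CreateArraySketch(Nil)            // conf read once: "null" element type
    LegacyConf.useStringTypeWhenEmpty = true    // conf flips mid-planning
    val copied = arr.copy()                     // copy keeps the captured flag
    assert(arr.defaultElementType == "null")
    assert(copied.defaultElementType == "null") // stable; a conf-reading val could flip here
    ```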
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs.
    
    Closes #27657 from iRakson/SPARK-30899.
    
    Authored-by: iRakson <raksonrak...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit cba17e07e9f15673f274de1728f6137d600026e1)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  8 ++---
 .../catalyst/expressions/complexTypeCreator.scala  | 35 +++++++++++++++++++---
 .../expressions/complexTypeExtractors.scala        |  4 +--
 .../sql/catalyst/optimizer/ComplexTypes.scala      |  8 ++---
 .../optimizer/NormalizeFloatingNumbers.scala       |  8 ++---
 5 files changed, 45 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index f416e8e..0a0bef6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -553,10 +553,10 @@ object TypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) =>
+      case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) =>
         val types = children.map(_.dataType)
         findWiderCommonType(types) match {
-          case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
+          case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType)))
           case None => a
         }
 
@@ -592,7 +592,7 @@ object TypeCoercion {
           case None => m
         }
 
-      case m @ CreateMap(children) if m.keys.length == m.values.length &&
+      case m @ CreateMap(children, _) if m.keys.length == m.values.length &&
          (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
         val keyTypes = m.keys.map(_.dataType)
         val newKeys = findWiderCommonType(keyTypes) match {
@@ -606,7 +606,7 @@ object TypeCoercion {
           case None => m.values
         }
 
-        CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
+        m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
 
       // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
       case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4bd85d3..6c31511 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -37,16 +37,23 @@ import org.apache.spark.unsafe.types.UTF8String
       > SELECT _FUNC_(1, 2, 3);
        [1,2,3]
   """)
-case class CreateArray(children: Seq[Expression]) extends Expression {
+case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
+  extends Expression {
+
+  def this(children: Seq[Expression]) = {
+    this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
+  }
 
   override def foldable: Boolean = children.forall(_.foldable)
 
+  override def stringArgs: Iterator[Any] = super.stringArgs.take(1)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
   }
 
   private val defaultElementType: DataType = {
-    if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) {
+    if (useStringTypeWhenEmpty) {
       StringType
     } else {
       NullType
@@ -79,6 +86,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def prettyName: String = "array"
 }
 
+object CreateArray {
+  def apply(children: Seq[Expression]): CreateArray = {
+    new CreateArray(children)
+  }
+}
+
 private [sql] object GenArrayData {
   /**
    * Return Java code pieces based on DataType and array size to allocate ArrayData class
@@ -141,12 +154,18 @@ private [sql] object GenArrayData {
       > SELECT _FUNC_(1.0, '2', 3.0, '4');
        {1.0:"2",3.0:"4"}
   """)
-case class CreateMap(children: Seq[Expression]) extends Expression {
+case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
+  extends Expression {
+
+  def this(children: Seq[Expression]) = {
+    this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
+  }
+
   lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
   lazy val values = children.indices.filter(_ % 2 != 0).map(children)
 
   private val defaultElementType: DataType = {
-    if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) {
+    if (useStringTypeWhenEmpty) {
       StringType
     } else {
       NullType
@@ -155,6 +174,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
+  override def stringArgs: Iterator[Any] = super.stringArgs.take(1)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size % 2 != 0) {
       TypeCheckResult.TypeCheckFailure(
@@ -215,6 +236,12 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
   override def prettyName: String = "map"
 }
 
+object CreateMap {
+  def apply(children: Seq[Expression]): CreateMap = {
+    new CreateMap(children)
+  }
+}
+
 /**
  * Returns a catalyst Map containing the two arrays in children expressions as keys and values.
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index e9d60ed..9c600c9d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -275,9 +275,9 @@ trait GetArrayItemUtil {
     if (ordinal.foldable && !ordinal.nullable) {
       val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
       child match {
-        case CreateArray(ar) if intOrdinal < ar.length =>
+        case CreateArray(ar, _) if intOrdinal < ar.length =>
           ar(intOrdinal).nullable
-        case GetArrayStructFields(CreateArray(elements), field, _, _, _)
+        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
           if intOrdinal < elements.length =>
           elements(intOrdinal).nullable || field.nullable
         case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 28dc8e9..f79dabf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -41,14 +41,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
         createNamedStruct.valExprs(ordinal)
 
       // Remove redundant array indexing.
-      case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
+      case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
         // Instead of selecting the field on the entire array, select it from each member
         // of the array. Pushing down the operation this way may open other optimizations
         // opportunities (i.e. struct(...,x,...).x)
-        CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))))
+        CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)
 
       // Remove redundant map lookup.
-      case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
+      case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
         // Instead of creating the array and then selecting one row, remove array creation
         // altogether.
         if (idx >= 0 && idx < elems.size) {
@@ -58,7 +58,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
           // out of bounds, mimic the runtime behavior and return null
           Literal(null, ga.dataType)
         }
-      case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems)
+      case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index ea01d9e..5f94af5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -114,11 +114,11 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateNamedStruct(children) =>
       CreateNamedStruct(children.map(normalize))
 
-    case CreateArray(children) =>
-      CreateArray(children.map(normalize))
+    case CreateArray(children, useStringTypeWhenEmpty) =>
+      CreateArray(children.map(normalize), useStringTypeWhenEmpty)
 
-    case CreateMap(children) =>
-      CreateMap(children.map(normalize))
+    case CreateMap(children, useStringTypeWhenEmpty) =>
+      CreateMap(children.map(normalize), useStringTypeWhenEmpty)
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
       KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))

