Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21687#discussion_r200801880
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 ---
    @@ -695,6 +695,56 @@ abstract class TernaryExpression extends Expression {
       }
     }
     
    +/**
    + * A trait resolving nullable, containsNull, valueContainsNull flags of 
the output date type.
    + * This logic is usually utilized by expressions combining data from 
multiple child expressions
    + * of non-primitive types (e.g. [[CaseWhen]]).
    + */
    +trait NonPrimitiveTypeMergingExpression extends Expression
    +{
    +  /**
    +   * A collection of data types used for resolution the output type of the 
expression. By default,
    +   * data types of all child expressions. The collection must not be empty.
    +   */
    +  @transient
    +  lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
    +
    +  /**
    +   * A method determining whether the input types are equal ignoring 
nullable, containsNull and
    +   * valueContainsNull flags and thus convenient for resolution of the 
final data type.
    +   */
    +  def areInputTypesForMergingEqual: Boolean = {
    +    inputTypesForMerging.lengthCompare(1) <= 0 || 
inputTypesForMerging.sliding(2, 1).forall {
    +      case Seq(dt1, dt2) => dt1.sameType(dt2)
    +    }
    +  }
    +
    +  private def mergeTwoDataTypes(dt1: DataType, dt2: DataType): DataType = 
(dt1, dt2) match {
    +    case (t1, t2) if t1 == t2 => t1
    +    case (ArrayType(et1, cn1), ArrayType(et2, cn2)) =>
    +      ArrayType(mergeTwoDataTypes(et1, et2), cn1 || cn2)
    +    case (MapType(kt1, vt1, vcn1), MapType(kt2, vt2, vcn2)) =>
    +      MapType(mergeTwoDataTypes(kt1, kt2), mergeTwoDataTypes(vt1, vt2), 
vcn1 || vcn2)
    +    case (StructType(fields1), StructType(fields2)) =>
    +      val newFields = fields1.zip(fields2).map {
    --- End diff --
    
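    We should always check the lengths before calling Scala's `zip`, e.g.
`assert(fields1.length == fields2.length)`, because `zip` silently truncates
to the shorter collection and would drop the extra fields without any error.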


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to