Repository: spark
Updated Branches:
  refs/heads/master 2e3abdff2 -> b804ca577


[SPARK-23908][SQL][FOLLOW-UP] Rename inputs to arguments, and add argument type check.

## What changes were proposed in this pull request?

This is a follow-up PR of #21954 to address review comments.

- Rename the ambiguous name `inputs` to `arguments`.
- Add an argument type check and remove the hacky workaround (see the sketch after this list).
- Address other minor review comments.
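
For illustration, a minimal sketch of what the new check changes, assuming a `SparkSession` in scope as `spark` (the DataFrame and column names are hypothetical):

```scala
import spark.implicits._

val df = Seq((1, Seq(1, 2, 3))).toDF("i", "arr")

// Resolves fine: argument 1 is an array, so the lambda can be bound.
df.selectExpr("transform(arr, x -> x + 1)")

// Now fails with an argument type check instead of claiming the lambda
// function is unresolved, e.g.:
//   data type mismatch: argument 1 requires array type,
//   however, '`i`' is of int type.
df.selectExpr("transform(i, x -> x)")
```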

## How was this patch tested?

Existing tests and some additional tests.

Closes #22075 from ueshin/issues/SPARK-23908/fup1.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b804ca57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b804ca57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b804ca57

Branch: refs/heads/master
Commit: b804ca57718ad1568458d8185c8c30118be8275f
Parents: 2e3abdf
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Mon Aug 13 20:58:29 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Mon Aug 13 20:58:29 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  14 ++
 .../analysis/higherOrderFunctions.scala         |  12 +-
 .../expressions/ExpectsInputTypes.scala         |  16 +-
 .../expressions/higherOrderFunctions.scala      | 181 ++++++++++---------
 .../spark/sql/catalyst/plans/PlanTest.scala     |   2 +-
 .../spark/sql/DataFrameFunctionsSuite.scala     |  25 +++
 6 files changed, 152 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 4addc83..6a91d55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -90,6 +90,20 @@ trait CheckAnalysis extends PredicateHelper {
         u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")
 
       case operator: LogicalPlan =>
+        // Check argument data types of higher-order functions downwards first.
+        // If the arguments of the higher-order functions are resolved but the type check fails,
+        // the argument functions will not get resolved, but we should report the argument type
+        // check failure instead of claiming the argument functions are unresolved.
+        operator transformExpressionsDown {
+          case hof: HigherOrderFunction
+              if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>
+            hof.checkArgumentDataTypes() match {
+              case TypeCheckResult.TypeCheckFailure(message) =>
+                hof.failAnalysis(
+                  s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
+            }
+        }
+
         operator transformExpressionsUp {
           case a: Attribute if !a.resolved =>
             val from = operator.inputSet.map(_.qualifiedName).mkString(", ")

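The single-case match above is safe because the guard `hof.checkArgumentDataTypes().isFailure` guarantees the second call cannot return a success. A standalone sketch of the same shape, using simplified stand-ins rather than the real Catalyst classes:

```scala
sealed trait TypeCheckResult { def isFailure: Boolean }
case object TypeCheckSuccess extends TypeCheckResult { val isFailure = false }
case class TypeCheckFailure(message: String) extends TypeCheckResult { val isFailure = true }

// Mirrors the new CheckAnalysis block: the result is only pattern-matched
// once the guard has already established that it is a failure.
def reportIfFailed(result: TypeCheckResult): Unit =
  if (result.isFailure) {
    result match {
      case TypeCheckFailure(message) =>
        println(s"due to argument data type mismatch: $message")
    }
  }
```
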
http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index 5e2029c..dd08190 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -95,15 +95,15 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
    */
   private def createLambda(
       e: Expression,
-      partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
+      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
     case f: LambdaFunction if f.bound => f
 
     case LambdaFunction(function, names, _) =>
-      if (names.size != partialArguments.size) {
+      if (names.size != argInfo.size) {
         e.failAnalysis(
           s"The number of lambda function arguments '${names.size}' does not " +
             "match the number of arguments expected by the higher order function " +
-            s"'${partialArguments.size}'.")
+            s"'${argInfo.size}'.")
       }
 
       if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
@@ -111,7 +111,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
           "Lambda function arguments should not have names that are semantically the same.")
       }
 
-      val arguments = partialArguments.zip(names).map {
+      val arguments = argInfo.zip(names).map {
         case ((dataType, nullable), ne) =>
           NamedLambdaVariable(ne.name, dataType, nullable)
       }
@@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
       // create a lambda function with default parameters because this is expected by the higher
       // order function. Note that we hide the lambda variables produced by this function in order
       // to prevent accidental naming collisions.
-      val arguments = partialArguments.zipWithIndex.map {
+      val arguments = argInfo.zipWithIndex.map {
         case ((dataType, nullable), i) =>
           NamedLambdaVariable(s"col$i", dataType, nullable)
       }
@@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
   private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match {
     case _ if e.resolved => e
 
-    case h: HigherOrderFunction if h.inputResolved =>
+    case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
       h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
 
     case l: LambdaFunction if !l.bound =>

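Beyond the rename, the tightened guard defers lambda binding until the argument types are known to be valid. The arity check in `createLambda` can be seen from SQL, reusing the hypothetical `df` from the sketch in the description: `transform` hands the lambda either one variable (the element) or two (element and index), so any other arity trips the "number of lambda function arguments" error above.

```scala
df.selectExpr("transform(arr, x -> x + 1)")       // one argument: element
df.selectExpr("transform(arr, (x, i) -> x + i)")  // two arguments: element and index
df.selectExpr("transform(arr, (x, y, z) -> x)")   // AnalysisException: arity mismatch
```
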
http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index d8f046c..981ce0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression {
   def inputTypes: Seq[AbstractDataType]
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
-      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
+    ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
+  }
+}
+
+object ExpectsInputTypes {
+
+  def checkInputDataTypes(
+      inputs: Seq[Expression],
+      inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
+    val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
+      case ((input, expected), idx) if !expected.acceptsType(input.dataType) =>
         s"argument ${idx + 1} requires ${expected.simpleString} type, " +
-          s"however, '${child.sql}' is of ${child.dataType.catalogString} type."
+          s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
     }
 
     if (mismatches.isEmpty) {
@@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression {
   }
 }
 
-
 /**
  * A mixin for the analyzer to perform implicit type casting using
  * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].

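A self-contained sketch of what the extracted helper computes, with simplified stand-ins for Catalyst's expression and type classes (the names below are illustrative, not the real API):

```scala
// Stand-in for an expression that knows its SQL rendering and type name.
case class Input(sql: String, dataType: String)

def checkInputDataTypes(inputs: Seq[Input], expected: Seq[String]): Option[String] = {
  val mismatches = inputs.zip(expected).zipWithIndex.collect {
    case ((input, exp), idx) if input.dataType != exp =>
      s"argument ${idx + 1} requires $exp type, " +
        s"however, '${input.sql}' is of ${input.dataType} type."
  }
  if (mismatches.isEmpty) None else Some(mismatches.mkString(" "))
}

// checkInputDataTypes(Seq(Input("`i`", "int")), Seq("array"))
//   => Some("argument 1 requires array type, however, '`i`' is of int type.")
```

Because `HigherOrderFunction.checkArgumentDataTypes()` delegates to the same helper over `arguments` and `argumentTypes`, argument-only checks and full `checkInputDataTypes()` checks share one message format.
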
http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 7f8203a..5d1b8c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -35,8 +35,8 @@ case class NamedLambdaVariable(
     name: String,
     dataType: DataType,
     nullable: Boolean,
-    value: AtomicReference[Any] = new AtomicReference(),
-    exprId: ExprId = NamedExpression.newExprId)
+    exprId: ExprId = NamedExpression.newExprId,
+    value: AtomicReference[Any] = new AtomicReference())
   extends LeafExpression
   with NamedExpression
   with CodegenFallback {
@@ -44,7 +44,7 @@ case class NamedLambdaVariable(
   override def qualifier: Seq[String] = Seq.empty
 
   override def newInstance(): NamedExpression =
-    copy(value = new AtomicReference(), exprId = NamedExpression.newExprId)
+    copy(exprId = NamedExpression.newExprId, value = new AtomicReference())
 
   override def toAttribute: Attribute = {
     AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty)
@@ -88,30 +88,45 @@ object LambdaFunction {
 * A higher order function takes one or more (lambda) functions and applies these to some objects.
 * The function produces a number of variables which can be consumed by some lambda function.
  */
-trait HigherOrderFunction extends Expression {
+trait HigherOrderFunction extends Expression with ExpectsInputTypes {
 
-  override def children: Seq[Expression] = inputs ++ functions
+  override def children: Seq[Expression] = arguments ++ functions
 
   /**
-   * Inputs to the higher ordered function.
+   * Arguments of the higher-order function.
    */
-  def inputs: Seq[Expression]
+  def arguments: Seq[Expression]
+
+  def argumentTypes: Seq[AbstractDataType]
 
   /**
-   * All inputs have been resolved. This means that the types and nullabilty of (most of) the
+   * All arguments have been resolved. This means that the types and nullability of (most of) the
    * lambda function arguments is known, and that we can start binding the lambda functions.
    */
-  lazy val inputResolved: Boolean = inputs.forall(_.resolved)
+  lazy val argumentsResolved: Boolean = arguments.forall(_.resolved)
+
+  /**
+   * Checks the argument data types, returns `TypeCheckResult.TypeCheckSuccess` if it's valid,
+   * or returns a `TypeCheckResult` with an error message if invalid.
+   * Note: it's not valid to call this method until `argumentsResolved == true`.
+   */
+  def checkArgumentDataTypes(): TypeCheckResult = {
+    ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes)
+  }
 
   /**
    * Functions applied by the higher order function.
    */
   def functions: Seq[Expression]
 
+  def functionTypes: Seq[AbstractDataType]
+
+  override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes
+
   /**
    * All inputs must be resolved and all functions must be resolved lambda functions.
    */
-  override lazy val resolved: Boolean = inputResolved && functions.forall {
+  override lazy val resolved: Boolean = argumentsResolved && functions.forall {
     case l: LambdaFunction => l.resolved
     case _ => false
   }
@@ -123,6 +138,8 @@ trait HigherOrderFunction extends Expression {
    */
   def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction
 
+  // Make sure the lambda variables refer to the same instances as the arguments, in case the
+  // variables are instantiated separately during serialization or for some other reason.
   @transient lazy val functionsForEval: Seq[Expression] = functions.map {
     case LambdaFunction(function, arguments, hidden) =>
       val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap
@@ -133,51 +150,38 @@ trait HigherOrderFunction extends Expression {
   }
 }
 
-object HigherOrderFunction {
-
-  def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
-    dt match {
-      case ArrayType(elementType, containsNull) => (elementType, containsNull)
-      case _ =>
-        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
-        (elementType, containsNull)
-    }
-  }
-
-  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
-    case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
-    case _ =>
-      val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
-      (kType, vType, vContainsNull)
-  }
-}
-
 /**
  * Trait for functions having as input one argument and one function.
  */
-trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
+trait SimpleHigherOrderFunction extends HigherOrderFunction {
+
+  def argument: Expression
 
-  def input: Expression
+  override def arguments: Seq[Expression] = argument :: Nil
 
-  override def inputs: Seq[Expression] = input :: Nil
+  def argumentType: AbstractDataType
+
+  override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil
 
   def function: Expression
 
   override def functions: Seq[Expression] = function :: Nil
 
-  def expectingFunctionType: AbstractDataType = AnyDataType
+  def functionType: AbstractDataType = AnyDataType
+
+  override def functionTypes: Seq[AbstractDataType] = functionType :: Nil
 
-  @transient lazy val functionForEval: Expression = functionsForEval.head
+  def functionForEval: Expression = functionsForEval.head
 
   /**
    * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
    * in order to save null-check code.
    */
-  protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
+  protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any =
     sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")
 
   override def eval(inputRow: InternalRow): Any = {
-    val value = input.eval(inputRow)
+    val value = argument.eval(inputRow)
     if (value == null) {
       null
     } else {
@@ -187,11 +191,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp
 }
 
 trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
-  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
+  override def argumentType: AbstractDataType = ArrayType
 }
 
 trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
-  override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
+  override def argumentType: AbstractDataType = MapType
 }
 
 /**
@@ -209,21 +213,21 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
   """,
   since = "2.4.0")
 case class ArrayTransform(
-    input: Expression,
+    argument: Expression,
     function: Expression)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
-  override def nullable: Boolean = input.nullable
+  override def nullable: Boolean = argument.nullable
 
   override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
-    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
+    val ArrayType(elementType, containsNull) = argument.dataType
     function match {
       case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
-        copy(function = f(function, elem :: (IntegerType, false) :: Nil))
+        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
       case _ =>
-        copy(function = f(function, elem :: Nil))
+        copy(function = f(function, (elementType, containsNull) :: Nil))
     }
   }
 
@@ -237,8 +241,8 @@ case class ArrayTransform(
     (elementVar, indexVar)
   }
 
-  override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
-    val arr = inputValue.asInstanceOf[ArrayData]
+  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+    val arr = argumentValue.asInstanceOf[ArrayData]
     val f = functionForEval
     val result = new GenericArrayData(new Array[Any](arr.numElements))
     var i = 0
@@ -268,7 +272,7 @@ examples = """
   """,
 since = "2.4.0")
 case class MapFilter(
-    input: Expression,
+    argument: Expression,
     function: Expression)
   extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
 
@@ -277,17 +281,16 @@ case class MapFilter(
     (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
   }
 
-  @transient val (keyType, valueType, valueContainsNull) =
-    HigherOrderFunction.mapKeyValueArgumentType(input.dataType)
+  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
     copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
   }
 
-  override def nullable: Boolean = input.nullable
+  override def nullable: Boolean = argument.nullable
 
-  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
-    val m = value.asInstanceOf[MapData]
+  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+    val m = argumentValue.asInstanceOf[MapData]
     val f = functionForEval
     val retKeys = new mutable.ListBuffer[Any]
     val retValues = new mutable.ListBuffer[Any]
@@ -302,9 +305,9 @@ case class MapFilter(
     ArrayBasedMapData(retKeys.toArray, retValues.toArray)
   }
 
-  override def dataType: DataType = input.dataType
+  override def dataType: DataType = argument.dataType
 
-  override def expectingFunctionType: AbstractDataType = BooleanType
+  override def functionType: AbstractDataType = BooleanType
 
   override def prettyName: String = "map_filter"
 }
@@ -321,25 +324,25 @@ case class MapFilter(
   """,
   since = "2.4.0")
 case class ArrayFilter(
-    input: Expression,
+    argument: Expression,
     function: Expression)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
-  override def nullable: Boolean = input.nullable
+  override def nullable: Boolean = argument.nullable
 
-  override def dataType: DataType = input.dataType
+  override def dataType: DataType = argument.dataType
 
-  override def expectingFunctionType: AbstractDataType = BooleanType
+  override def functionType: AbstractDataType = BooleanType
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
-    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
-    copy(function = f(function, elem :: Nil))
+    val ArrayType(elementType, containsNull) = argument.dataType
+    copy(function = f(function, (elementType, containsNull) :: Nil))
   }
 
   @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
 
-  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
-    val arr = value.asInstanceOf[ArrayData]
+  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+    val arr = argumentValue.asInstanceOf[ArrayData]
     val f = functionForEval
     val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
     var i = 0
@@ -368,25 +371,25 @@ case class ArrayFilter(
   """,
   since = "2.4.0")
 case class ArrayExists(
-    input: Expression,
+    argument: Expression,
     function: Expression)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
-  override def nullable: Boolean = input.nullable
+  override def nullable: Boolean = argument.nullable
 
   override def dataType: DataType = BooleanType
 
-  override def expectingFunctionType: AbstractDataType = BooleanType
+  override def functionType: AbstractDataType = BooleanType
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
-    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
-    copy(function = f(function, elem :: Nil))
+    val ArrayType(elementType, containsNull) = argument.dataType
+    copy(function = f(function, (elementType, containsNull) :: Nil))
   }
 
   @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
 
-  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
-    val arr = value.asInstanceOf[ArrayData]
+  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+    val arr = argumentValue.asInstanceOf[ArrayData]
     val f = functionForEval
     var exists = false
     var i = 0
@@ -422,45 +425,49 @@ case class ArrayExists(
   """,
   since = "2.4.0")
 case class ArrayAggregate(
-    input: Expression,
+    argument: Expression,
     zero: Expression,
     merge: Expression,
     finish: Expression)
   extends HigherOrderFunction with CodegenFallback {
 
-  def this(input: Expression, zero: Expression, merge: Expression) = {
-    this(input, zero, merge, LambdaFunction.identity)
+  def this(argument: Expression, zero: Expression, merge: Expression) = {
+    this(argument, zero, merge, LambdaFunction.identity)
   }
 
-  override def inputs: Seq[Expression] = input :: zero :: Nil
+  override def arguments: Seq[Expression] = argument :: zero :: Nil
+
+  override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil
 
   override def functions: Seq[Expression] = merge :: finish :: Nil
 
-  override def nullable: Boolean = input.nullable || finish.nullable
+  override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil
+
+  override def nullable: Boolean = argument.nullable || finish.nullable
 
   override def dataType: DataType = finish.dataType
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (!ArrayType.acceptsType(input.dataType)) {
-      TypeCheckResult.TypeCheckFailure(
-        s"argument 1 requires ${ArrayType.simpleString} type, " +
-          s"however, '${input.sql}' is of ${input.dataType.catalogString} type.")
-    } else if (!DataType.equalsStructurally(
-        zero.dataType, merge.dataType, ignoreNullability = true)) {
-      TypeCheckResult.TypeCheckFailure(
-        s"argument 3 requires ${zero.dataType.simpleString} type, " +
-          s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
+    checkArgumentDataTypes() match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        if (!DataType.equalsStructurally(
+            zero.dataType, merge.dataType, ignoreNullability = true)) {
+          TypeCheckResult.TypeCheckFailure(
+            s"argument 3 requires ${zero.dataType.simpleString} type, " +
+              s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
+      case failure => failure
     }
   }
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
    // Be very conservative with nullable. We cannot be sure that the accumulator does not
     // evaluate to null. So we always set nullable to true here.
-    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
+    val ArrayType(elementType, containsNull) = argument.dataType
     val acc = zero.dataType -> true
-    val newMerge = f(merge, acc :: elem :: Nil)
+    val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil)
     val newFinish = f(finish, acc :: Nil)
     copy(merge = newMerge, finish = newFinish)
   }
@@ -470,7 +477,7 @@ case class ArrayAggregate(
   @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish
 
   override def eval(input: InternalRow): Any = {
-    val arr = this.input.eval(input).asInstanceOf[ArrayData]
+    val arr = argument.eval(input).asInstanceOf[ArrayData]
     if (arr == null) {
       null
     } else {

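The deleted `arrayArgumentType`/`mapKeyValueArgumentType` helpers were the hacky workaround mentioned in the description: they silently fell back to `ArrayType.defaultConcreteType`/`MapType.defaultConcreteType` when the argument was not actually an array or map. Since `bind` is now only reached after `checkArgumentDataTypes()` succeeds, a plain pattern-match extractor is safe. A minimal sketch of the idea, assuming the type check has already passed:

```scala
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType}

// Safe once the argument type check has succeeded: dt is known to be an
// ArrayType, so the destructuring pattern cannot fail at runtime.
def elementInfo(dt: DataType): (DataType, Boolean) = {
  val ArrayType(elementType, containsNull) = dt
  (elementType, containsNull)
}

assert(elementInfo(ArrayType(IntegerType, containsNull = false)) == (IntegerType, false))
```
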
http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 9e95b19..67740c3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -81,7 +81,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
       case ae: AggregateExpression =>
         ae.copy(resultId = ExprId(0))
       case lv: NamedLambdaVariable =>
-        lv.copy(value = null, exprId = ExprId(0))
+        lv.copy(exprId = ExprId(0), value = null)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 2c4238e..6401e3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1852,6 +1852,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("transform(i, x -> x)")
     }
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("transform(a, x -> x)")
+    }
+    assert(ex3.getMessage.contains("cannot resolve '`a`'"))
   }
 
   test("map_filter") {
@@ -1898,6 +1903,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("map_filter(i, (k, v) -> k > v)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("map_filter(a, (k, v) -> k > v)")
+    }
+    assert(ex4.getMessage.contains("cannot resolve '`a`'"))
   }
 
   test("filter function - array for primitive type not containing null") {
@@ -1994,6 +2004,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("filter(s, x -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("filter(a, x -> x)")
+    }
+    assert(ex4.getMessage.contains("cannot resolve '`a`'"))
   }
 
   test("exists function - array for primitive type not containing null") {
@@ -2090,6 +2105,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("exists(s, x -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("exists(a, x -> x)")
+    }
+    assert(ex4.getMessage.contains("cannot resolve '`a`'"))
   }
 
   test("aggregate function - array for primitive type not containing null") {
@@ -2211,6 +2231,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
     }
     assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
+
+    val ex5 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(a, 0, (acc, x) -> x)")
+    }
+    assert(ex5.getMessage.contains("cannot resolve '`a`'"))
   }
 
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {

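The new negative cases also pin down error precedence: when the argument itself cannot be resolved, the analyzer reports the unresolved attribute rather than a type mismatch, because `checkArgumentDataTypes()` is only consulted once `argumentsResolved` is true. A hedged sketch of one such case, assuming a ScalaTest suite (as above) and a DataFrame `df` with no column named `a`:

```scala
import org.apache.spark.sql.AnalysisException

val ex = intercept[AnalysisException] {
  df.selectExpr("transform(a, x -> x)")
}
// The unresolved column is reported, not an argument type mismatch:
assert(ex.getMessage.contains("cannot resolve '`a`'"))
```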
