Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/22017#discussion_r209820348 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala --- @@ -496,3 +496,195 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. + """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + @transient lazy val functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType + + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType + + @transient lazy val keyType = + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + // Nothing to check since the data type of the lambda function can be anything. + override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess --- End diff -- I'd call `checkArgumentDataTypes()` here again.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org