Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/22017#discussion_r208878022 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala --- @@ -442,3 +442,191 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. + """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + @transient lazy val functionForEval: Expression = functionsForEval.head + + @transient lazy val (keyType, leftValueType, _) = + HigherOrderFunction.mapKeyValueArgumentType(left.dataType) + + @transient lazy val (_, rightValueType, _) = + HigherOrderFunction.mapKeyValueArgumentType(right.dataType) + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def inputs: Seq[Expression] = left :: right :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) => + 
TypeUtils.checkForOrderingExpr(k1, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with the same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + private def keyTypeSupportsEquals = keyType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = { + if (keyTypeSupportsEquals) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] --- End diff -- @mn-mikke we can use `LinkedHashMap` in order to preserve key order.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org