Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21073#discussion_r197056077

--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -474,6 +473,221 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
   override def prettyName: String = "map_entries"
 }
+/**
+ * Returns the union of all the given maps.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
+       [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
+  """, since = "2.4.0")
+case class MapConcat(children: Seq[Expression]) extends Expression {
+
+  private val MAX_MAP_SIZE: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // check key types and value types separately to allow valueContainsNull to vary
+    if (children.exists(!_.dataType.isInstanceOf[MapType])) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The given input of function $prettyName should all be of type map, " +
+          "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else if (children.map(_.dataType.asInstanceOf[MapType].keyType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The given input maps of function $prettyName should all be the same type, " +
+          "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else if (children.map(_.dataType.asInstanceOf[MapType].valueType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The given input maps of function $prettyName should all be the same type, " +
+          "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def dataType: MapType = {
+    MapType(
+      keyType = children.headOption
+        .map(_.dataType.asInstanceOf[MapType].keyType).getOrElse(StringType),
+      valueType = children.headOption
+        .map(_.dataType.asInstanceOf[MapType].valueType).getOrElse(StringType),
+      valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
+        .exists(_.valueContainsNull)
+    )
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def eval(input: InternalRow): Any = {
+    val maps = children.map(_.eval(input))
+    if (maps.contains(null)) {
+      return null
+    }
+    val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
+    val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
+
+    val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+    if (numElements > MAX_MAP_SIZE) {
+      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements" +
+        s" elements due to exceeding the map size limit" +
+        s" $MAX_MAP_SIZE.")
+    }
+    val finalKeyArray = new Array[AnyRef](numElements.toInt)
+    val finalValueArray = new Array[AnyRef](numElements.toInt)
+    var position = 0
+    for (i <- keyArrayDatas.indices) {
+      val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
+      val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
+      Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
+      Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
+      position += keyArray.length
+    }
+
+    new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
+      new GenericArrayData(finalValueArray))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val mapCodes = children.map(_.genCode(ctx))
+    val keyType = dataType.keyType
+    val valueType = dataType.valueType
+    val argsName = ctx.freshName("args")
+    val keyArgsName = ctx.freshName("keyArgs")
+    val valArgsName = ctx.freshName("valArgs")
+
+    val mapDataClass = classOf[MapData].getName
+    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
+    val arrayDataClass = classOf[ArrayData].getName
+
+    val init =
+      s"""
+        |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
+        |$arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
+        |$arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
+        |boolean ${ev.isNull} = false;
+        |$mapDataClass ${ev.value} = null;
+      """.stripMargin
+
+    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
+      s"""
+        |${m.code}
+        |$argsName[$i] = ${m.value.code};
--- End diff --

nit: `${m.value}` instead of `${m.value.code}`?
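
For reference, a minimal sketch of what the suggested change would look like. This is not code from the PR; it reuses the `mapCodes`, `argsName`, and `m` identifiers from the diff above and assumes that interpolating `m.value` (an `ExprValue`) renders the generated Java variable name, which is why the explicit `.code` should be unnecessary:

```scala
// Hedged sketch of the nit: interpolate the ExprValue directly instead of its .code field.
// `m` is the ExprCode returned by child.genCode(ctx); `m.code` evaluates the child map,
// and `${m.value}` expands to the generated variable holding the resulting MapData.
val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
  s"""
     |${m.code}
     |$argsName[$i] = ${m.value};
   """.stripMargin
}
```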