Github user henryr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21073#discussion_r183559190

--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -116,6 +118,154 @@ case class MapValues(child: Expression)

  override def prettyName: String = "map_values"
}

/**
 * Returns the union of all the given maps.
 */
@ExpressionDescription(
  usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
  examples = """
    Examples:
      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
       [[1 -> "a"], [2 -> "c"], [3 -> "d"]]
  """)
case class MapConcat(children: Seq[Expression]) extends Expression
  with CodegenFallback {

  override def checkInputDataTypes(): TypeCheckResult = {
    // This check currently does not allow valueContainsNull to vary across the
    // inputs; unfortunately, none of the MapType toString methods include
    // valueContainsNull, so it cannot appear in the error message.
    if (children.size < 2) {
      TypeCheckResult.TypeCheckFailure(
        s"$prettyName expects at least two input maps.")
    } else if (children.exists(!_.dataType.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckFailure(
        s"The given inputs of function $prettyName should all be of type map, " +
          "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
    } else if (children.map(_.dataType).distinct.length > 1) {
      TypeCheckResult.TypeCheckFailure(
        s"The given input maps of function $prettyName should all be of the same type, " +
          "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  override def dataType: MapType = {
    children.headOption.map(_.dataType.asInstanceOf[MapType])
      .getOrElse(MapType(keyType = StringType, valueType = StringType))
  }

  override def nullable: Boolean = true

  override def eval(input: InternalRow): Any = {
    // Merge all input maps into a LinkedHashMap so that insertion order is
    // preserved and later maps overwrite earlier entries on duplicate keys.
    val union = new util.LinkedHashMap[Any, Any]()
    children.map(_.eval(input)).foreach { raw =>
      if (raw == null) {
        // A null input map makes the whole result null.
        return null
      }
      val map = raw.asInstanceOf[MapData]
      map.foreach(dataType.keyType, dataType.valueType, (k, v) =>
        union.put(k, v)
      )
    }
    val (keyArray, valueArray) = union.entrySet().toArray().map { e =>
      val e2 = e.asInstanceOf[java.util.Map.Entry[Any, Any]]
      (e2.getKey, e2.getValue)
    }.unzip
    new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val mapCodes = children.map(c => c.genCode(ctx))
    val keyType = children.head.dataType.asInstanceOf[MapType].keyType
    val valueType = children.head.dataType.asInstanceOf[MapType].valueType
    val mapRefArrayName = ctx.freshName("mapRefArray")
    val unionMapName = ctx.freshName("union")

    val mapDataClass = classOf[MapData].getName
    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
    val arrayDataClass = classOf[ArrayData].getName
    val genericArrayDataClass = classOf[GenericArrayData].getName
    val hashMapClass = classOf[util.LinkedHashMap[Any, Any]].getName
    val entryClass = classOf[util.Map.Entry[Any, Any]].getName

    val init =
      s"""
         |$mapDataClass[] $mapRefArrayName = new $mapDataClass[${mapCodes.size}];
         |boolean ${ev.isNull} = false;
         |$mapDataClass ${ev.value} = null;
       """.stripMargin

    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
      val initCode = mapCodes(i).code
      val valueVarName = mapCodes(i).value.code
      s"""
         |$initCode
         |$mapRefArrayName[$i] = $valueVarName;
         |if ($valueVarName == null) {
         |  ${ev.isNull} = true;
         |}
       """.stripMargin
    }.mkString("\n")

    val index1Name = ctx.freshName("idx1")
    val index2Name = ctx.freshName("idx2")
    val mapDataName = ctx.freshName("m")
    val kaName = ctx.freshName("ka")
    val vaName = ctx.freshName("va")
    val keyName = ctx.freshName("key")
    val valueName = ctx.freshName("value")
    val isNullCheckName = ctx.freshName("isNull")

    // Merge all input maps into one map; as in eval(), later maps overwrite
    // earlier entries on duplicate keys.
    val mapMerge =
      s"""
         |$hashMapClass<Object, Object> $unionMapName = new $hashMapClass<Object, Object>();
         |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) {
         |  $mapDataClass $mapDataName = $mapRefArrayName[$index1Name];
         |  $arrayDataClass $kaName = $mapDataName.keyArray();
         |  $arrayDataClass $vaName = $mapDataName.valueArray();
         |  for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) {
         |    Object $keyName = ${CodeGenerator.getValue(kaName, keyType, index2Name)};
         |    Object $valueName = ${CodeGenerator.getValue(vaName, valueType, index2Name)};
         |    $unionMapName.put($keyName, $valueName);
         |  }
         |}
       """.stripMargin

    val mergedKeyArrayName = ctx.freshName("keyArray")
    val mergedValueArrayName = ctx.freshName("valueArray")
    val entrySetName = ctx.freshName("entrySet")

    // Flatten the merged map back into parallel key and value arrays.
    val createMapData =
      s"""
         |Object[] $entrySetName = $unionMapName.entrySet().toArray();
         |Object[] $mergedKeyArrayName = new Object[$entrySetName.length];
         |Object[] $mergedValueArrayName = new Object[$entrySetName.length];
         |for (int $index1Name = 0; $index1Name < $entrySetName.length; $index1Name++) {
         |  $entryClass<Object, Object> entry =
         |    ($entryClass<Object, Object>) $entrySetName[$index1Name];
         |  $mergedKeyArrayName[$index1Name] = (Object) entry.getKey();

--- End diff --

are the casts to `Object` necessary?
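For what it's worth, here is a minimal standalone sketch of the shape of the generated code. The identifiers `union`, `entrySet`, `keyArray`, `valueArray`, and `idx1` are my stand-ins for the `ctx.freshName(...)` outputs, not the actual generated names. In plain Java terms, `Map.Entry<Object, Object>.getKey()` is already declared to return `Object`, so the assignment seems to compile fine without the cast:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Standalone sketch of the code emitted by doGenCode above, with the
// freshName placeholders replaced by fixed identifiers for readability.
public class MapConcatCastSketch {
  public static void main(String[] args) {
    LinkedHashMap<Object, Object> union = new LinkedHashMap<Object, Object>();
    // Simulates merging map(1, 'a', 2, 'b') with map(2, 'c', 3, 'd'):
    // put() overwrites on duplicate keys, so the later map wins for key 2.
    union.put(1, "a");
    union.put(2, "b");
    union.put(2, "c");
    union.put(3, "d");

    Object[] entrySet = union.entrySet().toArray();
    Object[] keyArray = new Object[entrySet.length];
    Object[] valueArray = new Object[entrySet.length];
    for (int idx1 = 0; idx1 < entrySet.length; idx1++) {
      @SuppressWarnings("unchecked")
      Map.Entry<Object, Object> entry = (Map.Entry<Object, Object>) entrySet[idx1];
      // getKey() and getValue() already return Object here, so the
      // (Object) casts in the generated code appear to be no-ops.
      keyArray[idx1] = entry.getKey();
      valueArray[idx1] = entry.getValue();
    }

    System.out.println(java.util.Arrays.toString(keyArray));   // [1, 2, 3]
    System.out.println(java.util.Arrays.toString(valueArray)); // [a, c, d]
  }
}
```

The printed output also matches the last-wins merge semantics of `eval` above, since LinkedHashMap keeps a key's original insertion position when its value is overwritten.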