Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21912#discussion_r213707288 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -385,107 +385,120 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { + val arrayData = ctx.freshName("arrayData") val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val elementSize = if (isKeyPrimitive && isValuePrimitive) { + Some(structSize + wordSize) + } else { + None + } + + val allocation = CodeGenerator.createArrayData(arrayData, childDataType.keyType, numElements, + s" $prettyName failed.", elementSize = elementSize) + val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + val genCodeForPrimitive = genCodeForPrimitiveElements( + ctx, arrayData, keys, values, ev.value, numElements, structSize) + s""" + |if ($arrayData instanceof UnsafeArrayData) { + | $genCodeForPrimitive + |} else { + | ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)} + |} + """.stripMargin } else { - genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}" } + s""" |final int $numElements = $c.numElements(); |final ArrayData $keys = $c.keyArray(); |final ArrayData $values = $c.valueArray(); + |$allocation |$code """.stripMargin }) } - private def getKey(varName: String) = 
CodeGenerator.getValue(varName, childDataType.keyType, "z") + private def getKey(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.keyType, index) - private def getValue(varName: String) = { - CodeGenerator.getValue(varName, childDataType.valueType, "z") - } + private def getValue(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.valueType, index) private def genCodeForPrimitiveElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, - numElements: String): String = { - val unsafeRow = ctx.freshName("unsafeRow") + resultArrayData: String, + numElements: String, + structSize: Int): String = { val unsafeArrayData = ctx.freshName("unsafeArrayData") + val baseObject = ctx.freshName("baseObject") + val unsafeRow = ctx.freshName("unsafeRow") val structsOffset = ctx.freshName("structsOffset") + val offset = ctx.freshName("offset") + val z = ctx.freshName("z") val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET val wordSize = UnsafeRow.WORD_SIZE - val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 - val structSizeAsLong = structSize + "L" - val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType) - - val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" - val valueAssignmentChecked = if (childDataType.valueContainsNull) { - s""" - |if ($values.isNullAt(z)) { - | $unsafeRow.setNullAt(1); - |} else { - | $valueAssignment - |} - """.stripMargin - } else { - valueAssignment - } + val structSizeAsLong = s"${structSize}L" - val assignmentLoop = (byteArray: String) => - s""" - |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; - |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = 
$structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); - | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); - | $valueAssignmentChecked - |} - |$arrayData = $unsafeArrayData; - """.stripMargin + val setKey = + CodeGenerator.setArrayElement(unsafeRow, childDataType.keyType, "0", getKey(keys, z)) --- End diff -- @cloud-fan Good catch. We will use `setColumn` here. When I checked the source files, I found that this is not straightforward, since there are two differences between `setColumn` and `setArrayElement`: 1. `value()` is called for `StructType`, `ArrayType`, and others in `setColumn`. 2. `setDecimal()` is not supported in `Array`. If we add one boolean parameter to distinguish `column` from `array`, we can unify them into one. Should we do this unification? If not, should we update `setColumn` to generate an `if` statement for the null check, like `setArrayElement`? Or should we update `setColumn` in another PR?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org