Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21912#discussion_r213541643
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -385,107 +385,120 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
     
       override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
         nullSafeCodeGen(ctx, ev, c => {
    +      val arrayData = ctx.freshName("arrayData")
           val numElements = ctx.freshName("numElements")
           val keys = ctx.freshName("keys")
           val values = ctx.freshName("values")
           val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
           val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
    +
    +      val wordSize = UnsafeRow.WORD_SIZE
    +      val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
    +      val elementSize = if (isKeyPrimitive && isValuePrimitive) {
    +        Some(structSize + wordSize)
    +      } else {
    +        None
    +      }
    +
    +      val allocation = CodeGenerator.createArrayData(arrayData, childDataType.keyType, numElements,
    +        s" $prettyName failed.", elementSize = elementSize)
    +
           val code = if (isKeyPrimitive && isValuePrimitive) {
    -        genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
    +        val genCodeForPrimitive = genCodeForPrimitiveElements(
    +          ctx, arrayData, keys, values, ev.value, numElements, structSize)
    +        s"""
    +           |if ($arrayData instanceof UnsafeArrayData) {
    +           |  $genCodeForPrimitive
    +           |} else {
    +           |  ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}
    +           |}
    +         """.stripMargin
           } else {
    -        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
    +        s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}"
           }
    +
           s"""
              |final int $numElements = $c.numElements();
              |final ArrayData $keys = $c.keyArray();
              |final ArrayData $values = $c.valueArray();
    +         |$allocation
              |$code
            """.stripMargin
         })
       }
     
    -  private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
    +  private def getKey(varName: String, index: String) =
    +    CodeGenerator.getValue(varName, childDataType.keyType, index)
     
    -  private def getValue(varName: String) = {
    -    CodeGenerator.getValue(varName, childDataType.valueType, "z")
    -  }
    +  private def getValue(varName: String, index: String) =
    +    CodeGenerator.getValue(varName, childDataType.valueType, index)
     
       private def genCodeForPrimitiveElements(
           ctx: CodegenContext,
    +      arrayData: String,
           keys: String,
           values: String,
    -      arrayData: String,
    -      numElements: String): String = {
    -    val unsafeRow = ctx.freshName("unsafeRow")
    +      resultArrayData: String,
    +      numElements: String,
    +      structSize: Int): String = {
         val unsafeArrayData = ctx.freshName("unsafeArrayData")
    +    val baseObject = ctx.freshName("baseObject")
    +    val unsafeRow = ctx.freshName("unsafeRow")
         val structsOffset = ctx.freshName("structsOffset")
    +    val offset = ctx.freshName("offset")
    +    val z = ctx.freshName("z")
         val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
     
         val baseOffset = Platform.BYTE_ARRAY_OFFSET
         val wordSize = UnsafeRow.WORD_SIZE
    -    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
    -    val structSizeAsLong = structSize + "L"
    -    val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
    -    val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
    -
    -    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
    -    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
    -      s"""
    -         |if ($values.isNullAt(z)) {
    -         |  $unsafeRow.setNullAt(1);
    -         |} else {
    -         |  $valueAssignment
    -         |}
    -       """.stripMargin
    -    } else {
    -      valueAssignment
    -    }
    +    val structSizeAsLong = s"${structSize}L"
     
    -    val assignmentLoop = (byteArray: String) =>
    -      s"""
    -         |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
    -         |UnsafeRow $unsafeRow = new UnsafeRow(2);
    -         |for (int z = 0; z < $numElements; z++) {
    -         |  long offset = $structsOffset + z * $structSizeAsLong;
    -         |  $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
    -         |  $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
    -         |  $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
    -         |  $valueAssignmentChecked
    -         |}
    -         |$arrayData = $unsafeArrayData;
    -       """.stripMargin
    +    val setKey =
    +      CodeGenerator.setArrayElement(unsafeRow, childDataType.keyType, "0", getKey(keys, z))
    --- End diff --
    
    Here we are setting a value on a row (`unsafeRow`), not an array, yet the call is
    `setArrayElement`. Can we unify `setColumn` and `setArrayElement`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to