Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21236#discussion_r186371562
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -118,6 +118,162 @@ case class MapValues(child: Expression)
       override def prettyName: String = "map_values"
     }
     
    +/**
    + * Returns an unordered array of all entries in the given map.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(map) - Returns an unordered array of all entries in the 
given map.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(map(1, 'a', 2, 'b'));
    +       [(1,"a"),(2,"b")]
    +  """,
    +  since = "2.4.0")
    +case class MapEntries(child: Expression) extends UnaryExpression with 
ExpectsInputTypes {
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
    +
    +  lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
    +
    +  override def dataType: DataType = {
    +    ArrayType(
    +      StructType(
    +        StructField("key", childDataType.keyType, false) ::
    +        StructField("value", childDataType.valueType, 
childDataType.valueContainsNull) ::
    +        Nil),
    +      false)
    +  }
    +
    +  override protected def nullSafeEval(input: Any): Any = {
    +    val childMap = input.asInstanceOf[MapData]
    +    val keys = childMap.keyArray()
    +    val values = childMap.valueArray()
    +    val length = childMap.numElements()
    +    val resultData = new Array[AnyRef](length)
    +    var i = 0;
    +    while (i < length) {
    +      val key = keys.get(i, childDataType.keyType)
    +      val value = values.get(i, childDataType.valueType)
    +      val row = new GenericInternalRow(Array[Any](key, value))
    +      resultData.update(i, row)
    +      i += 1
    +    }
    +    new GenericArrayData(resultData)
    +  }
    +
    +  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
    +    nullSafeCodeGen(ctx, ev, c => {
    +      val numElements = ctx.freshName("numElements")
    +      val keys = ctx.freshName("keys")
    +      val values = ctx.freshName("values")
    +      val isKeyPrimitive = 
CodeGenerator.isPrimitiveType(childDataType.keyType)
    +      val isValuePrimitive = 
CodeGenerator.isPrimitiveType(childDataType.valueType)
    +      val code = if (isKeyPrimitive && isValuePrimitive) {
    +        genCodeForPrimitiveElements(ctx, keys, values, ev.value, 
numElements)
    +      } else {
    +        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
    +      }
    +      s"""
    +         |final int $numElements = $c.numElements();
    +         |final ArrayData $keys = $c.keyArray();
    +         |final ArrayData $values = $c.valueArray();
    +         |$code
    +       """.stripMargin
    +    })
    +  }
    +
    +  private def getKey(varName: String) = CodeGenerator.getValue(varName, 
childDataType.keyType, "z")
    +
    +  private def getValue(varName: String) = {
    +    CodeGenerator.getValue(varName, childDataType.valueType, "z")
    +  }
    +
    +  private def genCodeForPrimitiveElements(
    +      ctx: CodegenContext,
    +      keys: String,
    +      values: String,
    +      arrayData: String,
    +      numElements: String): String = {
    +    val byteArraySize = ctx.freshName("byteArraySize")
    +    val data = ctx.freshName("byteArray")
    +    val unsafeRow = ctx.freshName("unsafeRow")
    +    val structSize = ctx.freshName("structSize")
    +    val unsafeArrayData = ctx.freshName("unsafeArrayData")
    +    val structsOffset = ctx.freshName("structsOffset")
    +    val calculateArraySize = 
"UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
    +    val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
    +
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +    val longSize = LongType.defaultSize
    +    val keyTypeName = 
CodeGenerator.primitiveTypeName(childDataType.keyType)
    +    val valueTypeName = 
CodeGenerator.primitiveTypeName(childDataType.keyType)
    +
    +    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, 
${getValue(values)});"
    +    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
    +      s"""
    +         |if ($values.isNullAt(z)) {
    +         |  $unsafeRow.setNullAt(1);
    +         |} else {
    +         |  $valueAssignment
    +         |}
    +       """.stripMargin
    +    } else {
    +      valueAssignment
    +    }
    +
    +    s"""
    +       |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) 
+ longSize * 2};
    --- End diff --
    
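    Since `structSize` is built only from constants (`UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2`), can we calculate it beforehand on the Scala side and inline the value where it is used, instead of declaring it as a local variable in the generated code?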


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to