Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/20938#discussion_r182534094 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,160 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + override def nullable: Boolean = child.nullable || dataType.containsNull + + override def dataType: ArrayType = { + child + .dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[ArrayType] + } + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." 
+ ) + } + + override def nullSafeEval(array: Any): Any = { + val elements = array.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val flattened = elements.flatMap( + _.asInstanceOf[ArrayData].toObjectArray(dataType.elementType) + ) + new GenericArrayData(flattened) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(dataType.elementType)) { + genCodeForConcatOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForConcatOfComplexElements(ctx, c, ev.value) + } + nullElementsProtection(ev, c, code) + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for(int z=0; z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if(!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |int $variableName = 0; + |for(int z=0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val elementType = dataType.elementType + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + 
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + | ${elementType.defaultSize} * $numElemName + |); + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; --- End diff -- Sorry for the late comment. I think that it is fine to use `byte[]` for now.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org