Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r176902161 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { + val arrayCheck = checkInputDataTypesAreArrays + if(arrayCheck.isFailure) arrayCheck + else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { + val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => + s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." 
+ } + + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + } + } + + override def dataType: ArrayType = + children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) + + + override protected def nullSafeEval(inputs: Seq[Any]): Any = { + val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) + new GenericArrayData(elements) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, arrays => { + val elementType = dataType.elementType + if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value) + } else { + genCodeForConcatOfComplexElements(ctx, arrays, ev.value) + } + }) + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + elements: Seq[String] + ) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = elements + .map(el => s"$variableName += $el.numElements();") + .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s) + (code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( + ctx: CodegenContext, + elementType: DataType, + elements: Seq[String], + arrayDataName: String + ): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + + val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + |${elementType.defaultSize} * $numElemName + |); + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) + val assignments = elements.map { el => + s""" + |for(int z = 0; z < $el.numElements(); z++) { --- End diff -- Style: `for (`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org