Github user mn-mikke commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20858#discussion_r177084478
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: 
Expression)
     
       override def prettyName: String = "array_contains"
     }
    +
    +/**
    + * Concatenates multiple arrays into one.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
    +       [1,2,3,4,5,6]
    +  """)
    +case class ConcatArrays(children: Seq[Expression]) extends Expression with 
NullSafeEvaluation {
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val arrayCheck = checkInputDataTypesAreArrays
    +    if(arrayCheck.isFailure) arrayCheck
    +    else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), 
s"function $prettyName")
    +  }
    +
    +  private def checkInputDataTypesAreArrays(): TypeCheckResult =
    +  {
    +    val mismatches = children.zipWithIndex.collect {
    +      case (child, idx) if !ArrayType.acceptsType(child.dataType) =>
    +        s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " +
    +          s"however, '${child.sql}' is of ${child.dataType.simpleString} 
type."
    +    }
    +
    +    if (mismatches.isEmpty) {
    +      TypeCheckResult.TypeCheckSuccess
    +    } else {
    +      TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
    +    }
    +  }
    +
    +  override def dataType: ArrayType =
    +    children
    +      .headOption.map(_.dataType.asInstanceOf[ArrayType])
    +      .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType])
    +
    +
    +  override protected def nullSafeEval(inputs: Seq[Any]): Any = {
    +    val elements = 
inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType))
    +    new GenericArrayData(elements)
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    nullSafeCodeGen(ctx, ev, arrays => {
    +      val elementType = dataType.elementType
    +      if (CodeGenerator.isPrimitiveType(elementType)) {
    +        genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, 
ev.value)
    +      } else {
    +        genCodeForConcatOfComplexElements(ctx, arrays, ev.value)
    +      }
    +    })
    +  }
    +
    +  private def genCodeForNumberOfElements(
    +    ctx: CodegenContext,
    +    elements: Seq[String]
    +  ) : (String, String) = {
    +    val variableName = ctx.freshName("numElements")
    +    val code = elements
    +      .map(el => s"$variableName += $el.numElements();")
    +      .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s)
    +    (code, variableName)
    +  }
    +
    +  private def genCodeForConcatOfPrimitiveElements(
    +    ctx: CodegenContext,
    +    elementType: DataType,
    +    elements: Seq[String],
    +    arrayDataName: String
    +  ): String = {
    +    val arrayName = ctx.freshName("array")
    +    val arraySizeName = ctx.freshName("size")
    +    val counter = ctx.freshName("counter")
    +    val tempArrayDataName = ctx.freshName("tempArrayData")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
elements)
    +
    +    val unsafeArraySizeInBytes = s"""
    +      |int $arraySizeName = 
UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
    +      
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
    +      |${elementType.defaultSize} * $numElemName
    +      |);
    +      """.stripMargin
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType)
    +    val assignments = elements.map { el =>
    +      s"""
    +        |for(int z = 0; z < $el.numElements(); z++) {
    +        | if($el.isNullAt(z)) {
    +        |   $tempArrayDataName.setNullAt($counter);
    +        | } else {
    +        |   $tempArrayDataName.set$primitiveValueTypeName(
    +        |     $counter,
    +        |     $el.get$primitiveValueTypeName(z)
    +        |   );
    +        | }
    +        | $counter++;
    +        |}
    +        """.stripMargin
    +    }.mkString("\n")
    +
    +    s"""
    +      |$numElemCode
    +      |$unsafeArraySizeInBytes
    +      |byte[] $arrayName = new byte[$arraySizeName];
    +      |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
    +      |Platform.putLong($arrayName, $baseOffset, $numElemName);
    +      |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName);
    +      |int $counter = 0;
    +      |$assignments
    +      |$arrayDataName = $tempArrayDataName;
    +    """.stripMargin
    +
    +  }
    +
    +  private def genCodeForConcatOfComplexElements(
    +   ctx: CodegenContext,
    +   elements: Seq[String],
    +   arrayDataName: String
    +  ): String = {
    +    val genericArrayClass = classOf[GenericArrayData].getName
    +    val arrayName = ctx.freshName("arrayObject")
    +    val counter = ctx.freshName("counter")
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
elements)
    +
    +    val assignments = elements.map { el =>
    +      s"""
    +        |for(int z = 0; z < $el.numElements(); z++) {
    +        |  $arrayName[$counter] = $el.array()[z];
    +        |  $counter++;
    +        |}
    +     """.stripMargin
    +    }.mkString("\n")
    +
    +    s"""
    +      |$numElemCode
    +      |Object[] $arrayName = new Object[$numElemName];
    +      |int $counter = 0;
    +      |$assignments
    +      |$arrayDataName = new $genericArrayClass($arrayName);
    --- End diff --
    
    I really like this idea! I think it would require moving the complex-type 
insertion logic from `InterpretedUnsafeProjection` directly into 
`UnsafeDataWriter`, thereby introducing write methods for complex-type 
fields. I'm not sure whether such a big refactoring task is still within the 
scope of this PR.
    
    I also see that we could improve the codegen of `CreateArray` in the same way.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to