Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21912#discussion_r209574445
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 ---
    @@ -735,70 +735,100 @@ class CodegenContext {
       }
     
       /**
    -   * Generates code creating a [[UnsafeArrayData]].
    +   * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] 
based on
    +   * given parameters.
        *
        * @param arrayName name of the array to create
    +   * @param elementType data type of the elements in source array
        * @param numElements code representing the number of elements the array 
should contain
    -   * @param elementType data type of the elements in the array
        * @param additionalErrorMessage string to include in the error message
    +   * @param elementSize optional value which shows the size of an element 
of the allocated
    +   *                    [[UnsafeArrayData]] or [[GenericArrayData]]
    +   *
    +   * @return code representing the allocation of [[ArrayData]]
    +   *         code representing a setter of an assignment for the generated 
array
        */
    -  def createUnsafeArray(
    +  def createArrayData(
           arrayName: String,
    -      numElements: String,
           elementType: DataType,
    -      additionalErrorMessage: String): String = {
    -    val arraySize = freshName("size")
    -    val arrayBytes = freshName("arrayBytes")
    +      numElements: String,
    +      additionalErrorMessage: String,
    +      elementSize: Option[Int] = None): (String, String) = {
    +    val isPrimitiveType = if (elementSize.isDefined) {
    +      false
    +    } else {
    +      CodeGenerator.isPrimitiveType(elementType)
    +    }
     
    -    s"""
    -       |long $arraySize = 
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
    -       |  $numElements,
    -       |  ${elementType.defaultSize});
    -       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
    -       |  throw new RuntimeException("Unsuccessful try create array with " 
+ $arraySize +
    -       |    " bytes of data due to exceeding the limit " +
    -       |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for 
UnsafeArrayData." +
    -       |    "$additionalErrorMessage");
    -       |}
    -       |byte[] $arrayBytes = new byte[(int)$arraySize];
    -       |UnsafeArrayData $arrayName = new UnsafeArrayData();
    -       |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
$numElements);
    -       |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
(int)$arraySize);
    -      """.stripMargin
    +    val setFunc = if (isPrimitiveType) {
    +      s"set${CodeGenerator.primitiveTypeName(elementType)}"
    +    } else {
    +      "update"
    +    }
    +
    +    val elemSize = if (elementSize.isDefined) {
    +      elementSize.get
    +    } else {
    +      elementType.defaultSize
    +    }
    +    val arrayData = classOf[ArrayData].getName
    +    val allocation =
    +      s"""
    +         |ArrayData $arrayName = $arrayData$$.MODULE$$.allocateArrayData(
    +         |  $elemSize, $numElements, $isPrimitiveType, 
"$additionalErrorMessage");
    +       """.stripMargin
    +
    +    (allocation, setFunc)
       }
     
       /**
    -   * Generates code creating a [[UnsafeArrayData]]. The generated code 
executes
    -   * a provided fallback when the size of backing array would exceed the 
array size limit.
    -   * @param arrayName a name of the array to create
    -   * @param numElements a piece of code representing the number of 
elements the array should contain
    -   * @param elementSize a size of an element in bytes
    -   * @param bodyCode a function generating code that fills up the 
[[UnsafeArrayData]]
    -   *                 and getting the backing array as a parameter
    -   * @param fallbackCode a piece of code executed when the array size 
limit is exceeded
    +   * Generates assignment code for a [[ArrayData]]
    +   *
    +   * @param arrayName name of the array to create
    +   * @param dataType data type of the result array
    +   * @param elementType data type of the elements in source array
    +   * @param srcArray code representing the number of elements the array 
should contain
    +   * @param setFunc string to include in the error message
    +   * @param rhsValue an optionally specified expression for the right-hand 
side of the returning
    +   *                 assignment
    +   * @param checkForNull optional value which shows whether a nullcheck is 
required for
    +   *                     the returning assignment
    +   *
    +   * @return code representing an assignment to each element of the 
[[ArrayData]], which requires
    +   *         a pair of destination and source loop index variables
        */
    -  def createUnsafeArrayWithFallback(
    +  def createArrayAssignment(
           arrayName: String,
    -      numElements: String,
    -      elementSize: Int,
    -      bodyCode: String => String,
    -      fallbackCode: String): String = {
    -    val arraySize = freshName("size")
    -    val arrayBytes = freshName("arrayBytes")
    -    s"""
    -       |final long $arraySize = 
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
    -       |  $numElements,
    -       |  $elementSize);
    -       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
    -       |  $fallbackCode
    -       |} else {
    -       |  final byte[] $arrayBytes = new byte[(int)$arraySize];
    -       |  UnsafeArrayData $arrayName = new UnsafeArrayData();
    -       |  Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
$numElements);
    -       |  $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
(int)$arraySize);
    -       |  ${bodyCode(arrayBytes)}
    -       |}
    -     """.stripMargin
    +      dataType: DataType,
    +      elementType: DataType,
    +      srcArray: String,
    +      setFunc: String,
    +      rhsValue: String = null,
    --- End diff --
    
    If it's optional, we should use `Option[String]` (defaulting to `None`) rather than a `String` that defaults to `null`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to