Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/21912#discussion_r209574445 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala --- @@ -735,70 +735,100 @@ class CodegenContext { } /** - * Generates code creating a [[UnsafeArrayData]]. + * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] based on + * given parameters. * * @param arrayName name of the array to create + * @param elementType data type of the elements in source array * @param numElements code representing the number of elements the array should contain - * @param elementType data type of the elements in the array * @param additionalErrorMessage string to include in the error message + * @param elementSize optional value which shows the size of an element of the allocated + * [[UnsafeArrayData]] or [[GenericArrayData]] + * + * @return code representing the allocation of [[ArrayData]] + * code representing a setter of an assignment for the generated array */ - def createUnsafeArray( + def createArrayData( arrayName: String, - numElements: String, elementType: DataType, - additionalErrorMessage: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") + numElements: String, + additionalErrorMessage: String, + elementSize: Option[Int] = None): (String, String) = { + val isPrimitiveType = if (elementSize.isDefined) { + false + } else { + CodeGenerator.isPrimitiveType(elementType) + } - s""" - |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | ${elementType.defaultSize}); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." 
+ - | "$additionalErrorMessage"); - |} - |byte[] $arrayBytes = new byte[(int)$arraySize]; - |UnsafeArrayData $arrayName = new UnsafeArrayData(); - |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - """.stripMargin + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + + val elemSize = if (elementSize.isDefined) { + elementSize.get + } else { + elementType.defaultSize + } + val arrayData = classOf[ArrayData].getName + val allocation = + s""" + |ArrayData $arrayName = $arrayData$$.MODULE$$.allocateArrayData( + | $elemSize, $numElements, $isPrimitiveType, "$additionalErrorMessage"); + """.stripMargin + + (allocation, setFunc) } /** - * Generates code creating a [[UnsafeArrayData]]. The generated code executes - * a provided fallback when the size of backing array would exceed the array size limit. - * @param arrayName a name of the array to create - * @param numElements a piece of code representing the number of elements the array should contain - * @param elementSize a size of an element in bytes - * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] - * and getting the backing array as a parameter - * @param fallbackCode a piece of code executed when the array size limit is exceeded + * Generates assignment code for a [[ArrayData]] + * + * @param arrayName name of the array to create + * @param dataType data type of the result array + * @param elementType data type of the elements in source array + * @param srcArray code representing the number of elements the array should contain + * @param setFunc string to include in the error message + * @param rhsValue an optionally specified expression for the right-hand side of the returning + * assignment + * @param checkForNull optional value which shows whether a nullcheck is required for + * the returning assignment 
+ * + * @return code representing an assignment to each element of the [[ArrayData]], which requires + * a pair of destination and source loop index variables */ - def createUnsafeArrayWithFallback( + def createArrayAssignment( arrayName: String, - numElements: String, - elementSize: Int, - bodyCode: String => String, - fallbackCode: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - s""" - |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | $elementSize); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | $fallbackCode - |} else { - | final byte[] $arrayBytes = new byte[(int)$arraySize]; - | UnsafeArrayData $arrayName = new UnsafeArrayData(); - | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - | ${bodyCode(arrayBytes)} - |} - """.stripMargin + dataType: DataType, + elementType: DataType, + srcArray: String, + setFunc: String, + rhsValue: String = null, --- End diff -- If `rhsValue` is optional, we should use `Option[String]` instead of defaulting to `null`.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org