panbingkun commented on code in PR #47984: URL: https://github.com/apache/spark/pull/47984#discussion_r1770945052
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala: ########## @@ -1525,6 +1526,126 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +/** + * This expression converts data of `ArrayData` to an array of java type. + * + * NOTE: When the data type of expression is `ArrayType`, and the expression is foldable, + * the `ConstantFolding` can do constant folding optimization automatically, + * (avoiding frequent calls to `ArrayData.to{XXX}Array()`). + */ +case class ToJavaArray(array: Expression) + extends UnaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def checkInputDataTypes(): TypeCheckResult = array.dataType match { + case ArrayType(_, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(0), + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(array), + "inputType" -> toSQLType(array.dataType)) + ) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(array.dataType) + override def dataType: DataType = { + if (canPerformFast) { + elementType match { + case ByteType => ObjectType(classOf[Array[Byte]]) + case ShortType => ObjectType(classOf[Array[Short]]) + case IntegerType => ObjectType(classOf[Array[Int]]) + case LongType => ObjectType(classOf[Array[Long]]) + case FloatType => ObjectType(classOf[Array[Float]]) + case DoubleType => ObjectType(classOf[Array[Double]]) + } + } else if (isPrimitiveType) { + elementType match { + case BooleanType => ObjectType(classOf[Array[java.lang.Boolean]]) + case ByteType => ObjectType(classOf[Array[java.lang.Byte]]) + case ShortType => ObjectType(classOf[Array[java.lang.Short]]) + case IntegerType => ObjectType(classOf[Array[java.lang.Integer]]) + case LongType => ObjectType(classOf[Array[java.lang.Long]]) + case FloatType 
=> ObjectType(classOf[Array[java.lang.Float]]) + case DoubleType => ObjectType(classOf[Array[java.lang.Double]]) + } + } else { + ObjectType(classOf[Array[Object]]) + } + } + + override def child: Expression = array + override def prettyName: String = "to_java_array" + + @transient lazy val elementType: DataType = + array.dataType.asInstanceOf[ArrayType].elementType + private def resultArrayElementNullable: Boolean = + array.dataType.asInstanceOf[ArrayType].containsNull + private def isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType) + private def canPerformFast: Boolean = + isPrimitiveType && elementType != BooleanType && !resultArrayElementNullable + + private def toJavaArray(array: Any): Any = { + val arrayData = array.asInstanceOf[ArrayData] + if (canPerformFast) { + elementType match { + case ByteType => arrayData.toByteArray() + case ShortType => arrayData.toShortArray() + case IntegerType => arrayData.toIntArray() + case LongType => arrayData.toLongArray() + case FloatType => arrayData.toFloatArray() + case DoubleType => arrayData.toDoubleArray() + } + } else if (isPrimitiveType) { + elementType match { + case BooleanType => arrayData.toArray[java.lang.Boolean](BooleanType) + case ByteType => arrayData.toArray[java.lang.Byte](ByteType) + case ShortType => arrayData.toArray[java.lang.Short](ShortType) + case IntegerType => arrayData.toArray[java.lang.Integer](IntegerType) + case LongType => arrayData.toArray[java.lang.Long](LongType) + case FloatType => arrayData.toArray[java.lang.Float](FloatType) + case DoubleType => arrayData.toArray[java.lang.Double](DoubleType) + } + } else { + arrayData.toObjectArray(elementType) + } + } + + private def toJavaArrayCodegen( + ctx: CodegenContext, + ev: ExprCode, + array: String): String = { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + if (canPerformFast) { + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s"""${ev.value} = 
$array.to${primitiveTypeName}Array();""" + } else if (isPrimitiveType) { + val boxedJavaType = CodeGenerator.boxedType(elementType) + val classTagTerm = ctx.addReferenceObj("classTagTerm", + ClassTag(SparkClassUtils.classForName(s"java.lang.$boxedJavaType"))) + s"""${ev.value} = ($boxedJavaType[]) $array.toArray($elementTypeTerm, $classTagTerm);""" + } else { + s"""${ev.value} = $array.toObjectArray($elementTypeTerm);""" + } + } + + override def nullSafeEval(array: Any): Any = { Review Comment: Based on the currently submitted version (which uses `Invoke` to implement `ToJavaArray`), I ran the benchmark 3 times, and the performance is as follows: ```shell Running benchmark: array binary search Running case: has foldable optimize Stopped after 100 iterations, 23205 ms OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0 Apple M2 array binary search: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ has foldable optimize 223 232 18 44.8 22.3 1.0X Running benchmark: array binary search Running case: has foldable optimize Stopped after 100 iterations, 20181 ms OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0 Apple M2 array binary search: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ has foldable optimize 195 202 8 51.2 19.5 1.0X Running benchmark: array binary search Running case: has foldable optimize Stopped after 100 iterations, 20819 ms OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0 Apple M2 array binary search: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ has foldable optimize 202 208 8 49.4 20.2 1.0X ``` -- This is an automated message 
from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org