panbingkun commented on code in PR #47984:
URL: https://github.com/apache/spark/pull/47984#discussion_r1770945052


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -1525,6 +1526,126 @@ case class ArrayContains(left: Expression, right: 
Expression)
     copy(left = newLeft, right = newRight)
 }
 
+/**
+ * This expression converts data of `ArrayData` to an array of java type.
+ *
+ * NOTE: When the data type of expression is `ArrayType`, and the expression 
is foldable,
+ * the `ConstantFolding` can do constant folding optimization automatically,
+ * (avoiding frequent calls to `ArrayData.to{XXX}Array()`).
+ */
+case class ToJavaArray(array: Expression)
+  extends UnaryExpression
+  with ImplicitCastInputTypes
+  with NullIntolerant
+  with QueryErrorsBase {
+
+  override def checkInputDataTypes(): TypeCheckResult = array.dataType match {
+    case ArrayType(_, _) =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> toSQLType(ArrayType),
+          "inputSql" -> toSQLExpr(array),
+          "inputType" -> toSQLType(array.dataType))
+      )
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(array.dataType)
+  override def dataType: DataType = {
+    if (canPerformFast) {
+      elementType match {
+        case ByteType => ObjectType(classOf[Array[Byte]])
+        case ShortType => ObjectType(classOf[Array[Short]])
+        case IntegerType => ObjectType(classOf[Array[Int]])
+        case LongType => ObjectType(classOf[Array[Long]])
+        case FloatType => ObjectType(classOf[Array[Float]])
+        case DoubleType => ObjectType(classOf[Array[Double]])
+      }
+    } else if (isPrimitiveType) {
+      elementType match {
+        case BooleanType => ObjectType(classOf[Array[java.lang.Boolean]])
+        case ByteType => ObjectType(classOf[Array[java.lang.Byte]])
+        case ShortType => ObjectType(classOf[Array[java.lang.Short]])
+        case IntegerType => ObjectType(classOf[Array[java.lang.Integer]])
+        case LongType => ObjectType(classOf[Array[java.lang.Long]])
+        case FloatType => ObjectType(classOf[Array[java.lang.Float]])
+        case DoubleType => ObjectType(classOf[Array[java.lang.Double]])
+      }
+    } else {
+      ObjectType(classOf[Array[Object]])
+    }
+  }
+
+  override def child: Expression = array
+  override def prettyName: String = "to_java_array"
+
+  @transient lazy val elementType: DataType =
+    array.dataType.asInstanceOf[ArrayType].elementType
+  private def resultArrayElementNullable: Boolean =
+    array.dataType.asInstanceOf[ArrayType].containsNull
+  private def isPrimitiveType: Boolean = 
CodeGenerator.isPrimitiveType(elementType)
+  private def canPerformFast: Boolean =
+    isPrimitiveType && elementType != BooleanType && 
!resultArrayElementNullable
+
+  private def toJavaArray(array: Any): Any = {
+    val arrayData = array.asInstanceOf[ArrayData]
+    if (canPerformFast) {
+      elementType match {
+        case ByteType => arrayData.toByteArray()
+        case ShortType => arrayData.toShortArray()
+        case IntegerType => arrayData.toIntArray()
+        case LongType => arrayData.toLongArray()
+        case FloatType => arrayData.toFloatArray()
+        case DoubleType => arrayData.toDoubleArray()
+      }
+    } else if (isPrimitiveType) {
+      elementType match {
+        case BooleanType => arrayData.toArray[java.lang.Boolean](BooleanType)
+        case ByteType => arrayData.toArray[java.lang.Byte](ByteType)
+        case ShortType => arrayData.toArray[java.lang.Short](ShortType)
+        case IntegerType => arrayData.toArray[java.lang.Integer](IntegerType)
+        case LongType => arrayData.toArray[java.lang.Long](LongType)
+        case FloatType => arrayData.toArray[java.lang.Float](FloatType)
+        case DoubleType => arrayData.toArray[java.lang.Double](DoubleType)
+      }
+    } else {
+      arrayData.toObjectArray(elementType)
+    }
+  }
+
+  private def toJavaArrayCodegen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      array: String): String = {
+    val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
+    if (canPerformFast) {
+      val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
+      s"""${ev.value} = $array.to${primitiveTypeName}Array();"""
+    } else if (isPrimitiveType) {
+      val boxedJavaType = CodeGenerator.boxedType(elementType)
+      val classTagTerm = ctx.addReferenceObj("classTagTerm",
+        ClassTag(SparkClassUtils.classForName(s"java.lang.$boxedJavaType")))
+      s"""${ev.value} = ($boxedJavaType[]) $array.toArray($elementTypeTerm, 
$classTagTerm);"""
+    } else {
+      s"""${ev.value} = $array.toObjectArray($elementTypeTerm);"""
+    }
+  }
+
+  override def nullSafeEval(array: Any): Any = {

Review Comment:
   Based on the currently submitted version (which uses `Invoke` to implement `ToJavaArray`), I ran the benchmark 3 times; the results are as follows:
   
   ```shell
   Running benchmark: array binary search
     Running case: has foldable optimize
     Stopped after 100 iterations, 23205 ms
   
   OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0
   Apple M2
   array binary search:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
   ------------------------------------------------------------------------------------------------------------------------
   has foldable optimize                               223            232          18         44.8          22.3       1.0X
   
   
   Running benchmark: array binary search
     Running case: has foldable optimize
     Stopped after 100 iterations, 20181 ms
   
   OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0
   Apple M2
   array binary search:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
   ------------------------------------------------------------------------------------------------------------------------
   has foldable optimize                               195            202           8         51.2          19.5       1.0X
   
   Running benchmark: array binary search
     Running case: has foldable optimize
     Stopped after 100 iterations, 20819 ms
   
   OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 15.0
   Apple M2
   array binary search:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
   ------------------------------------------------------------------------------------------------------------------------
   has foldable optimize                               202            208           8         49.4          20.2       1.0X
   
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to