This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 046a62ca chore: Revise batch pull approach to more closely follow C Data
interface semantics (#893)
046a62ca is described below
commit 046a62ca8ec034246277bd62943ce5bd5e29db31
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Sep 3 08:17:29 2024 -0700
chore: Revise batch pull approach to more closely follow C Data interface semantics
(#893)
* chore: Revise batch pull approach to more closely follow C Data interface
semantics
* fix clippy
* Remove ExportedBatch
---
.../org/apache/comet/vector/ExportedBatch.scala | 44 ---------
.../scala/org/apache/comet/vector/NativeUtil.scala | 69 ++++++--------
native/core/src/execution/operators/scan.rs | 100 ++++++++++++---------
native/core/src/jvm_bridge/batch_iterator.rs | 5 +-
.../java/org/apache/comet/CometBatchIterator.java | 33 ++-----
.../scala/org/apache/comet/CometExecIterator.scala | 2 -
6 files changed, 98 insertions(+), 155 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
b/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
deleted file mode 100644
index 2e97a0dc..00000000
--- a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.comet.vector
-
-import org.apache.arrow.c.ArrowArray
-import org.apache.arrow.c.ArrowSchema
-
-/**
- * A wrapper class to hold the exported Arrow arrays and schemas.
- *
- * @param batch
- * a list containing number of rows + pairs of memory addresses in the
format of (address of
- * Arrow array, address of Arrow schema)
- * @param arrowSchemas
- * the exported Arrow schemas, needs to be deallocated after being moved by
the native executor
- * @param arrowArrays
- * the exported Arrow arrays, needs to be deallocated after being moved by
the native executor
- */
-case class ExportedBatch(
- batch: Array[Long],
- arrowSchemas: Array[ArrowSchema],
- arrowArrays: Array[ArrowArray]) {
- def close(): Unit = {
- arrowSchemas.foreach(_.close())
- arrowArrays.foreach(_.close())
- }
-}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index eed8fd05..5149c734 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,50 +47,39 @@ class NativeUtil {
* an exported batches object containing an array containing number of
rows + pairs of memory
* addresses in the format of (address of Arrow array, address of Arrow
schema)
*/
- def exportBatch(batch: ColumnarBatch): ExportedBatch = {
- val exportedVectors = mutable.ArrayBuffer.empty[Long]
- exportedVectors += batch.numRows()
-
- // Run checks prior to exporting the batch
- (0 until batch.numCols()).foreach { index =>
- val c = batch.column(index)
- if (!c.isInstanceOf[CometVector]) {
- batch.close()
- throw new SparkException(
- "Comet execution only takes Arrow Arrays, but got " +
- s"${c.getClass}")
- }
- }
-
- val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
- val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]
-
+ def exportBatch(
+ arrayAddrs: Array[Long],
+ schemaAddrs: Array[Long],
+ batch: ColumnarBatch): Int = {
(0 until batch.numCols()).foreach { index =>
- val cometVector = batch.column(index).asInstanceOf[CometVector]
- val valueVector = cometVector.getValueVector
-
- val provider = if (valueVector.getField.getDictionary != null) {
- cometVector.getDictionaryProvider
- } else {
- null
+ batch.column(index) match {
+ case a: CometVector =>
+ val valueVector = a.getValueVector
+
+ val provider = if (valueVector.getField.getDictionary != null) {
+ a.getDictionaryProvider
+ } else {
+ null
+ }
+
+ // The array and schema structures are allocated by native side.
+ // Don't need to deallocate them here.
+ val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+ val arrowArray = ArrowArray.wrap(arrayAddrs(index))
+ Data.exportVector(
+ allocator,
+ getFieldVector(valueVector, "export"),
+ provider,
+ arrowArray,
+ arrowSchema)
+ case c =>
+ throw new SparkException(
+ "Comet execution only takes Arrow Arrays, but got " +
+ s"${c.getClass}")
}
-
- val arrowSchema = ArrowSchema.allocateNew(allocator)
- val arrowArray = ArrowArray.allocateNew(allocator)
- arrowSchemas += arrowSchema
- arrowArrays += arrowArray
- Data.exportVector(
- allocator,
- getFieldVector(valueVector, "export"),
- provider,
- arrowArray,
- arrowSchema)
-
- exportedVectors += arrowArray.memoryAddress()
- exportedVectors += arrowSchema.memoryAddress()
}
- ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray,
arrowArrays.toArray)
+ batch.numRows()
}
/**
diff --git a/native/core/src/execution/operators/scan.rs
b/native/core/src/execution/operators/scan.rs
index 59616efb..0816a5c1 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.
+use futures::Stream;
+use itertools::Itertools;
+use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
@@ -22,14 +25,6 @@ use std::{
task::{Context, Poll},
};
-use futures::Stream;
-use itertools::Itertools;
-
-use arrow::compute::{cast_with_options, CastOptions};
-use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
-use arrow_data::ArrayData;
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
-
use crate::{
errors::CometError,
execution::{
@@ -38,6 +33,12 @@ use crate::{
},
jvm_bridge::{jni_call, JVMClasses},
};
+use arrow::compute::{cast_with_options, CastOptions};
+use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
+use arrow_data::ffi::FFI_ArrowArray;
+use arrow_data::ArrayData;
+use arrow_schema::ffi::FFI_ArrowSchema;
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_plan::metrics::{BaselineMetrics,
ExecutionPlanMetricsSet, MetricsSet};
use datafusion::{
execution::TaskContext,
@@ -45,10 +46,9 @@ use datafusion::{
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as
DataFusionResult};
-use jni::{
- objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
- sys::jlongArray,
-};
+use jni::objects::JValueGen;
+use jni::objects::{GlobalRef, JObject};
+use jni::sys::jsize;
/// ScanExec reads batches of data from Spark via JNI. The source of the scan
could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +86,7 @@ impl ScanExec {
// may end up either unpacking dictionary arrays or
dictionary-encoding arrays.
// Dictionary-encoded primitive arrays are always unpacked.
let first_batch = if let Some(input_source) = input_source.as_ref() {
- ScanExec::get_next(exec_context_id, input_source.as_obj())?
+ ScanExec::get_next(exec_context_id, input_source.as_obj(),
data_types.len())?
} else {
InputBatch::EOF
};
@@ -153,6 +153,7 @@ impl ScanExec {
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
+ self.data_types.len(),
)?;
*current_batch = Some(next_batch);
}
@@ -161,7 +162,11 @@ impl ScanExec {
}
/// Invokes JNI call to get next batch.
- fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch,
CometError> {
+ fn get_next(
+ exec_context_id: i64,
+ iter: &JObject,
+ num_cols: usize,
+ ) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
// This is a unit test. We don't need to call JNI.
return Ok(InputBatch::EOF);
@@ -175,49 +180,60 @@ impl ScanExec {
}
let mut env = JVMClasses::get_env()?;
- let batch_object: JObject = unsafe {
- jni_call!(&mut env,
- comet_batch_iterator(iter).next() -> JObject)?
- };
- if batch_object.is_null() {
- return Err(CometError::from(ExecutionError::GeneralError(format!(
- "Null batch object. Plan id: {}",
- exec_context_id
- ))));
+ let mut array_addrs = Vec::with_capacity(num_cols);
+ let mut schema_addrs = Vec::with_capacity(num_cols);
+
+ for _ in 0..num_cols {
+ let arrow_array = Rc::new(FFI_ArrowArray::empty());
+ let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
+ let (array_ptr, schema_ptr) = (
+ Rc::into_raw(arrow_array) as i64,
+ Rc::into_raw(arrow_schema) as i64,
+ );
+
+ array_addrs.push(array_ptr);
+ schema_addrs.push(schema_ptr);
}
- let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw()
as jlongArray) };
+ // Prepare the java array parameters
+ let long_array_addrs = env.new_long_array(num_cols as jsize)?;
+ let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
- let addresses = unsafe { env.get_array_elements(&batch_object,
ReleaseMode::NoCopyBack)? };
+ env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
+ env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
- // First element is the number of rows.
- let num_rows = unsafe { *addresses.as_ptr() as i64 };
+ let array_obj = JObject::from(long_array_addrs);
+ let schema_obj = JObject::from(long_schema_addrs);
- if num_rows < 0 {
- return Ok(InputBatch::EOF);
- }
+ let array_obj = JValueGen::Object(array_obj.as_ref());
+ let schema_obj = JValueGen::Object(schema_obj.as_ref());
+
+ let num_rows: i32 = unsafe {
+ jni_call!(&mut env,
+ comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
+ };
- let array_num = addresses.len() - 1;
- if array_num % 2 != 0 {
- return Err(CometError::Internal(format!(
- "Invalid number of Arrow Array addresses: {}",
- array_num
- )));
+ if num_rows == -1 {
+ return Ok(InputBatch::EOF);
}
- let num_arrays = array_num / 2;
- let array_elements = unsafe { addresses.as_ptr().add(1) };
- let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
+ let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
- for i in 0..num_arrays {
- let array_ptr = unsafe { *(array_elements.add(i * 2)) };
- let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+ for i in 0..num_cols {
+ let array_ptr = array_addrs[i];
+ let schema_ptr = schema_addrs[i];
let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
// TODO: validate array input data
inputs.push(make_array(array_data));
+
+ // Drop the Arcs to avoid memory leak
+ unsafe {
+ Rc::from_raw(array_ptr as *const FFI_ArrowArray);
+ Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+ }
}
Ok(InputBatch::new(inputs, Some(num_rows as usize)))
diff --git a/native/core/src/jvm_bridge/batch_iterator.rs
b/native/core/src/jvm_bridge/batch_iterator.rs
index 06f43a8c..4870624d 100644
--- a/native/core/src/jvm_bridge/batch_iterator.rs
+++ b/native/core/src/jvm_bridge/batch_iterator.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use jni::signature::Primitive;
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {
Ok(CometBatchIterator {
class,
- method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
- method_next_ret: ReturnType::Array,
+ method_next: env.get_method_id(Self::JVM_CLASS, "next",
"([J[J)I")?,
+ method_next_ret: ReturnType::Primitive(Primitive::Int),
})
}
}
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
index eb7506b8..accd57c2 100644
--- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,7 +23,6 @@ import scala.collection.Iterator;
import org.apache.spark.sql.vectorized.ColumnarBatch;
-import org.apache.comet.vector.ExportedBatch;
import org.apache.comet.vector.NativeUtil;
/**
@@ -35,41 +34,25 @@ public class CometBatchIterator {
final Iterator<ColumnarBatch> input;
final NativeUtil nativeUtil;
- private ExportedBatch lastBatch;
-
CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
this.input = input;
this.nativeUtil = nativeUtil;
- this.lastBatch = null;
}
/**
- * Get the next batches of Arrow arrays. It will consume input iterator and
return Arrow arrays by
- * addresses. If the input iterator is done, it will return a one negative
element array
- * indicating the end of the iterator.
+ * Get the next batches of Arrow arrays.
+ *
+ * @param arrayAddrs The addresses of the ArrowArray structures.
+ * @param schemaAddrs The addresses of the ArrowSchema structures.
+ * @return the number of rows of the current batch. -1 if there is no more
batch.
*/
- public long[] next() {
- // Native side already copied the content of ArrowSchema and ArrowArray.
We should deallocate
- // the ArrowSchema and ArrowArray base structures allocated in JVM.
- if (lastBatch != null) {
- lastBatch.close();
- lastBatch = null;
- }
-
+ public int next(long[] arrayAddrs, long[] schemaAddrs) {
boolean hasBatch = input.hasNext();
if (!hasBatch) {
- return new long[] {-1};
+ return -1;
}
- lastBatch = nativeUtil.exportBatch(input.next());
- return lastBatch.batch();
- }
-
- public void close() {
- if (lastBatch != null) {
- lastBatch.close();
- lastBatch = null;
- }
+ return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
}
}
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index f1e77fb5..29eb2f0c 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,8 +159,6 @@ class CometExecIterator(
}
nativeLib.releasePlan(plan)
- cometBatchIterators.foreach(_.close())
-
// The allocator thoughts the exported ArrowArray and ArrowSchema
structs are not released,
// so it will report:
// Caused by: java.lang.IllegalStateException: Memory was leaked by
query.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]