This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 046a62ca chore: Revise batch pull approach to more follow C Data interface semantics (#893)
046a62ca is described below

commit 046a62ca8ec034246277bd62943ce5bd5e29db31
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Sep 3 08:17:29 2024 -0700

    chore: Revise batch pull approach to more follow C Data interface semantics (#893)
    
    * chore: Revise batch pull approach to more follow C Data interface semantics
    
    * fix clippy
    
    * Remove ExportedBatch
---
 .../org/apache/comet/vector/ExportedBatch.scala    |  44 ---------
 .../scala/org/apache/comet/vector/NativeUtil.scala |  69 ++++++--------
 native/core/src/execution/operators/scan.rs        | 100 ++++++++++++---------
 native/core/src/jvm_bridge/batch_iterator.rs       |   5 +-
 .../java/org/apache/comet/CometBatchIterator.java  |  33 ++-----
 .../scala/org/apache/comet/CometExecIterator.scala |   2 -
 6 files changed, 98 insertions(+), 155 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala b/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
deleted file mode 100644
index 2e97a0dc..00000000
--- a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.comet.vector
-
-import org.apache.arrow.c.ArrowArray
-import org.apache.arrow.c.ArrowSchema
-
-/**
- * A wrapper class to hold the exported Arrow arrays and schemas.
- *
- * @param batch
- *   a list containing number of rows + pairs of memory addresses in the 
format of (address of
- *   Arrow array, address of Arrow schema)
- * @param arrowSchemas
- *   the exported Arrow schemas, needs to be deallocated after being moved by 
the native executor
- * @param arrowArrays
- *   the exported Arrow arrays, needs to be deallocated after being moved by 
the native executor
- */
-case class ExportedBatch(
-    batch: Array[Long],
-    arrowSchemas: Array[ArrowSchema],
-    arrowArrays: Array[ArrowArray]) {
-  def close(): Unit = {
-    arrowSchemas.foreach(_.close())
-    arrowArrays.foreach(_.close())
-  }
-}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index eed8fd05..5149c734 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,50 +47,39 @@ class NativeUtil {
    *   an exported batches object containing an array containing number of 
rows + pairs of memory
    *   addresses in the format of (address of Arrow array, address of Arrow 
schema)
    */
-  def exportBatch(batch: ColumnarBatch): ExportedBatch = {
-    val exportedVectors = mutable.ArrayBuffer.empty[Long]
-    exportedVectors += batch.numRows()
-
-    // Run checks prior to exporting the batch
-    (0 until batch.numCols()).foreach { index =>
-      val c = batch.column(index)
-      if (!c.isInstanceOf[CometVector]) {
-        batch.close()
-        throw new SparkException(
-          "Comet execution only takes Arrow Arrays, but got " +
-            s"${c.getClass}")
-      }
-    }
-
-    val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
-    val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]
-
+  def exportBatch(
+      arrayAddrs: Array[Long],
+      schemaAddrs: Array[Long],
+      batch: ColumnarBatch): Int = {
     (0 until batch.numCols()).foreach { index =>
-      val cometVector = batch.column(index).asInstanceOf[CometVector]
-      val valueVector = cometVector.getValueVector
-
-      val provider = if (valueVector.getField.getDictionary != null) {
-        cometVector.getDictionaryProvider
-      } else {
-        null
+      batch.column(index) match {
+        case a: CometVector =>
+          val valueVector = a.getValueVector
+
+          val provider = if (valueVector.getField.getDictionary != null) {
+            a.getDictionaryProvider
+          } else {
+            null
+          }
+
+          // The array and schema structures are allocated by native side.
+          // Don't need to deallocate them here.
+          val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+          val arrowArray = ArrowArray.wrap(arrayAddrs(index))
+          Data.exportVector(
+            allocator,
+            getFieldVector(valueVector, "export"),
+            provider,
+            arrowArray,
+            arrowSchema)
+        case c =>
+          throw new SparkException(
+            "Comet execution only takes Arrow Arrays, but got " +
+              s"${c.getClass}")
       }
-
-      val arrowSchema = ArrowSchema.allocateNew(allocator)
-      val arrowArray = ArrowArray.allocateNew(allocator)
-      arrowSchemas += arrowSchema
-      arrowArrays += arrowArray
-      Data.exportVector(
-        allocator,
-        getFieldVector(valueVector, "export"),
-        provider,
-        arrowArray,
-        arrowSchema)
-
-      exportedVectors += arrowArray.memoryAddress()
-      exportedVectors += arrowSchema.memoryAddress()
     }
 
-    ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray, 
arrowArrays.toArray)
+    batch.numRows()
   }
 
   /**
diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 59616efb..0816a5c1 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use futures::Stream;
+use itertools::Itertools;
+use std::rc::Rc;
 use std::{
     any::Any,
     pin::Pin,
@@ -22,14 +25,6 @@ use std::{
     task::{Context, Poll},
 };
 
-use futures::Stream;
-use itertools::Itertools;
-
-use arrow::compute::{cast_with_options, CastOptions};
-use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
-use arrow_data::ArrayData;
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
-
 use crate::{
     errors::CometError,
     execution::{
@@ -38,6 +33,12 @@ use crate::{
     },
     jvm_bridge::{jni_call, JVMClasses},
 };
+use arrow::compute::{cast_with_options, CastOptions};
+use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
+use arrow_data::ffi::FFI_ArrowArray;
+use arrow_data::ArrayData;
+use arrow_schema::ffi::FFI_ArrowSchema;
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use datafusion::physical_plan::metrics::{BaselineMetrics, 
ExecutionPlanMetricsSet, MetricsSet};
 use datafusion::{
     execution::TaskContext,
@@ -45,10 +46,9 @@ use datafusion::{
     physical_plan::{ExecutionPlan, *},
 };
 use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as 
DataFusionResult};
-use jni::{
-    objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
-    sys::jlongArray,
-};
+use jni::objects::JValueGen;
+use jni::objects::{GlobalRef, JObject};
+use jni::sys::jsize;
 
 /// ScanExec reads batches of data from Spark via JNI. The source of the scan 
could be a file
 /// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +86,7 @@ impl ScanExec {
         // may end up either unpacking dictionary arrays or 
dictionary-encoding arrays.
         // Dictionary-encoded primitive arrays are always unpacked.
         let first_batch = if let Some(input_source) = input_source.as_ref() {
-            ScanExec::get_next(exec_context_id, input_source.as_obj())?
+            ScanExec::get_next(exec_context_id, input_source.as_obj(), 
data_types.len())?
         } else {
             InputBatch::EOF
         };
@@ -153,6 +153,7 @@ impl ScanExec {
             let next_batch = ScanExec::get_next(
                 self.exec_context_id,
                 self.input_source.as_ref().unwrap().as_obj(),
+                self.data_types.len(),
             )?;
             *current_batch = Some(next_batch);
         }
@@ -161,7 +162,11 @@ impl ScanExec {
     }
 
     /// Invokes JNI call to get next batch.
-    fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, 
CometError> {
+    fn get_next(
+        exec_context_id: i64,
+        iter: &JObject,
+        num_cols: usize,
+    ) -> Result<InputBatch, CometError> {
         if exec_context_id == TEST_EXEC_CONTEXT_ID {
             // This is a unit test. We don't need to call JNI.
             return Ok(InputBatch::EOF);
@@ -175,49 +180,60 @@ impl ScanExec {
         }
 
         let mut env = JVMClasses::get_env()?;
-        let batch_object: JObject = unsafe {
-            jni_call!(&mut env,
-            comet_batch_iterator(iter).next() -> JObject)?
-        };
 
-        if batch_object.is_null() {
-            return Err(CometError::from(ExecutionError::GeneralError(format!(
-                "Null batch object. Plan id: {}",
-                exec_context_id
-            ))));
+        let mut array_addrs = Vec::with_capacity(num_cols);
+        let mut schema_addrs = Vec::with_capacity(num_cols);
+
+        for _ in 0..num_cols {
+            let arrow_array = Rc::new(FFI_ArrowArray::empty());
+            let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
+            let (array_ptr, schema_ptr) = (
+                Rc::into_raw(arrow_array) as i64,
+                Rc::into_raw(arrow_schema) as i64,
+            );
+
+            array_addrs.push(array_ptr);
+            schema_addrs.push(schema_ptr);
         }
 
-        let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() 
as jlongArray) };
+        // Prepare the java array parameters
+        let long_array_addrs = env.new_long_array(num_cols as jsize)?;
+        let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
 
-        let addresses = unsafe { env.get_array_elements(&batch_object, 
ReleaseMode::NoCopyBack)? };
+        env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
+        env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
 
-        // First element is the number of rows.
-        let num_rows = unsafe { *addresses.as_ptr() as i64 };
+        let array_obj = JObject::from(long_array_addrs);
+        let schema_obj = JObject::from(long_schema_addrs);
 
-        if num_rows < 0 {
-            return Ok(InputBatch::EOF);
-        }
+        let array_obj = JValueGen::Object(array_obj.as_ref());
+        let schema_obj = JValueGen::Object(schema_obj.as_ref());
+
+        let num_rows: i32 = unsafe {
+            jni_call!(&mut env,
+        comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
+        };
 
-        let array_num = addresses.len() - 1;
-        if array_num % 2 != 0 {
-            return Err(CometError::Internal(format!(
-                "Invalid number of Arrow Array addresses: {}",
-                array_num
-            )));
+        if num_rows == -1 {
+            return Ok(InputBatch::EOF);
         }
 
-        let num_arrays = array_num / 2;
-        let array_elements = unsafe { addresses.as_ptr().add(1) };
-        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
+        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
 
-        for i in 0..num_arrays {
-            let array_ptr = unsafe { *(array_elements.add(i * 2)) };
-            let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+        for i in 0..num_cols {
+            let array_ptr = array_addrs[i];
+            let schema_ptr = schema_addrs[i];
             let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
 
             // TODO: validate array input data
 
             inputs.push(make_array(array_data));
+
+            // Drop the Arcs to avoid memory leak
+            unsafe {
+                Rc::from_raw(array_ptr as *const FFI_ArrowArray);
+                Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+            }
         }
 
         Ok(InputBatch::new(inputs, Some(num_rows as usize)))
diff --git a/native/core/src/jvm_bridge/batch_iterator.rs b/native/core/src/jvm_bridge/batch_iterator.rs
index 06f43a8c..4870624d 100644
--- a/native/core/src/jvm_bridge/batch_iterator.rs
+++ b/native/core/src/jvm_bridge/batch_iterator.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use jni::signature::Primitive;
 use jni::{
     errors::Result as JniResult,
     objects::{JClass, JMethodID},
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {
 
         Ok(CometBatchIterator {
             class,
-            method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
-            method_next_ret: ReturnType::Array,
+            method_next: env.get_method_id(Self::JVM_CLASS, "next", 
"([J[J)I")?,
+            method_next_ret: ReturnType::Primitive(Primitive::Int),
         })
     }
 }
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
index eb7506b8..accd57c2 100644
--- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,7 +23,6 @@ import scala.collection.Iterator;
 
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
-import org.apache.comet.vector.ExportedBatch;
 import org.apache.comet.vector.NativeUtil;
 
 /**
@@ -35,41 +34,25 @@ public class CometBatchIterator {
   final Iterator<ColumnarBatch> input;
   final NativeUtil nativeUtil;
 
-  private ExportedBatch lastBatch;
-
   CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
     this.input = input;
     this.nativeUtil = nativeUtil;
-    this.lastBatch = null;
   }
 
   /**
-   * Get the next batches of Arrow arrays. It will consume input iterator and 
return Arrow arrays by
-   * addresses. If the input iterator is done, it will return a one negative 
element array
-   * indicating the end of the iterator.
+   * Get the next batches of Arrow arrays.
+   *
+   * @param arrayAddrs The addresses of the ArrowArray structures.
+   * @param schemaAddrs The addresses of the ArrowSchema structures.
+   * @return the number of rows of the current batch. -1 if there is no more 
batch.
    */
-  public long[] next() {
-    // Native side already copied the content of ArrowSchema and ArrowArray. 
We should deallocate
-    // the ArrowSchema and ArrowArray base structures allocated in JVM.
-    if (lastBatch != null) {
-      lastBatch.close();
-      lastBatch = null;
-    }
-
+  public int next(long[] arrayAddrs, long[] schemaAddrs) {
     boolean hasBatch = input.hasNext();
 
     if (!hasBatch) {
-      return new long[] {-1};
+      return -1;
     }
 
-    lastBatch = nativeUtil.exportBatch(input.next());
-    return lastBatch.batch();
-  }
-
-  public void close() {
-    if (lastBatch != null) {
-      lastBatch.close();
-      lastBatch = null;
-    }
+    return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
   }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index f1e77fb5..29eb2f0c 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,8 +159,6 @@ class CometExecIterator(
       }
       nativeLib.releasePlan(plan)
 
-      cometBatchIterators.foreach(_.close())
-
       // The allocator thoughts the exported ArrowArray and ArrowSchema 
structs are not released,
       // so it will report:
       // Caused by: java.lang.IllegalStateException: Memory was leaked by 
query.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to