This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch dev-v6.0.0-decimal-cast
in repository https://gitbox.apache.org/repos/asf/auron.git

commit bcbe6736a7ade7185ef6896eba3aa511f5464a23
Author: zhangli20 <[email protected]>
AuthorDate: Thu Jan 22 15:29:41 2026 +0800

    optimize UDAF wrapper
---
 native-engine/blaze-jni-bridge/src/jni_bridge.rs   |  24 ++--
 .../datafusion-ext-plans/src/agg/count.rs          | 132 ++++++++++++++++----
 .../src/agg/spark_udaf_wrapper.rs                  |  65 +++-------
 .../spark/sql/blaze/SparkUDAFWrapperContext.scala  | 137 +++++++++------------
 4 files changed, 199 insertions(+), 159 deletions(-)

diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs
index 44b53508..f20cd17a 100644
--- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs
+++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs
@@ -1228,10 +1228,10 @@ pub struct SparkUDAFWrapperContext<'a> {
     pub method_merge_ret: ReturnType,
     pub method_eval: JMethodID,
     pub method_eval_ret: ReturnType,
-    pub method_serializeRows: JMethodID,
-    pub method_serializeRows_ret: ReturnType,
-    pub method_deserializeRows: JMethodID,
-    pub method_deserializeRows_ret: ReturnType,
+    pub method_exportRows: JMethodID,
+    pub method_exportRows_ret: ReturnType,
+    pub method_importRows: JMethodID,
+    pub method_importRows_ret: ReturnType,
     pub method_spill: JMethodID,
     pub method_spill_ret: ReturnType,
     pub method_unspill: JMethodID,
@@ -1281,18 +1281,18 @@ impl<'a> SparkUDAFWrapperContext<'a> {
                 "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V",
             )?,
             method_eval_ret: ReturnType::Primitive(Primitive::Void),
-            method_serializeRows: env.get_method_id(
+            method_exportRows: env.get_method_id(
                 class,
-                "serializeRows",
-                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[I)[B",
+                "exportRows",
+                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V",
             )?,
-            method_serializeRows_ret: ReturnType::Array,
-            method_deserializeRows: env.get_method_id(
+            method_exportRows_ret: ReturnType::Primitive(Primitive::Void),
+            method_importRows: env.get_method_id(
                 class,
-                "deserializeRows",
-                "(Ljava/nio/ByteBuffer;)Lorg/apache/spark/sql/blaze/BufferRowsColumn;",
+                "importRows",
+                "(J)Lorg/apache/spark/sql/blaze/BufferRowsColumn;",
             )?,
-            method_deserializeRows_ret: ReturnType::Object,
+            method_importRows_ret: ReturnType::Object,
             method_spill: env.get_method_id(
                 class,
                 "spill",
diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs
index 61508c77..ad4b71cf 100644
--- a/native-engine/datafusion-ext-plans/src/agg/count.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/count.rs
@@ -20,14 +20,18 @@ use std::{
 
 use arrow::{array::*, datatypes::*};
 use datafusion::{common::Result, physical_expr::PhysicalExprRef};
-use datafusion_ext_commons::downcast_any;
+use datafusion_ext_commons::{
+    downcast_any,
+    io::{read_len, write_len},
+};
 
 use crate::{
     agg::{
-        acc::{AccColumn, AccColumnRef, AccPrimColumn},
+        acc::{AccColumn, AccColumnRef},
         agg::{Agg, IdxSelection},
     },
-    idx_for_zipped,
+    idx_for, idx_for_zipped,
+    memmgr::spill::{SpillCompressedReader, SpillCompressedWriter},
 };
 
 pub struct AggCount {
@@ -76,11 +80,9 @@ impl Agg for AggCount {
     }
 
     fn create_acc_column(&self, num_rows: usize) -> Box<dyn AccColumn> {
-        Box::new(AccPrimColumn::<i64>::new(num_rows, DataType::Int64))
-    }
-
-    fn acc_array_data_types(&self) -> &[DataType] {
-        &[DataType::Int64]
+        Box::new(AccCountColumn {
+            values: vec![0; num_rows],
+        })
     }
 
     fn partial_update(
@@ -90,15 +92,32 @@ impl Agg for AggCount {
         partial_args: &[ArrayRef],
         partial_arg_idx: IdxSelection<'_>,
     ) -> Result<()> {
-        let accs = downcast_any!(accs, mut AccPrimColumn<i64>)?;
+        let accs = downcast_any!(accs, mut AccCountColumn)?;
         accs.ensure_size(acc_idx);
 
-        idx_for_zipped! {
-            ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
-                let add = partial_args
-                    .iter()
-                    .all(|arg| arg.is_valid(partial_arg_idx)) as i64;
-                accs.set_value(acc_idx, Some(accs.value(acc_idx).unwrap_or(0) + add));
+        if partial_args.is_empty() {
+            idx_for_zipped! {
+                ((acc_idx, _partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                    if acc_idx >= accs.values.len() {
+                        accs.values.push(1);
+                    } else {
+                        accs.values[acc_idx] += 1;
+                    }
+                }
+            }
+        } else {
+            idx_for_zipped! {
+                ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                    let add = partial_args
+                        .iter()
+                        .all(|arg| arg.is_valid(partial_arg_idx)) as i64;
+
+                    if acc_idx >= accs.values.len() {
+                        accs.values.push(add);
+                    } else {
+                        accs.values[acc_idx] += add;
+                    }
+                }
             }
         }
         Ok(())
@@ -111,19 +130,17 @@ impl Agg for AggCount {
         merging_accs: &mut AccColumnRef,
         merging_acc_idx: IdxSelection<'_>,
     ) -> Result<()> {
-        let accs = downcast_any!(accs, mut AccPrimColumn<i64>)?;
-        let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<i64>)?;
+        let accs = downcast_any!(accs, mut AccCountColumn)?;
+        let merging_accs = downcast_any!(merging_accs, mut AccCountColumn)?;
         accs.ensure_size(acc_idx);
 
         idx_for_zipped! {
             ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
-                let v = match (accs.value(acc_idx), merging_accs.value(merging_acc_idx)) {
-                    (Some(a), Some(b)) => Some(a + b),
-                    (Some(a), _) => Some(a),
-                    (_, Some(b)) => Some(b),
-                    _ => Some(0),
-                };
-                accs.set_value(acc_idx, v);
+                if acc_idx < accs.values.len() {
+                    accs.values[acc_idx] += merging_accs.values[merging_acc_idx];
+                } else {
+                    accs.values.push(merging_accs.values[merging_acc_idx]);
+                }
             }
         }
         Ok(())
@@ -132,4 +149,73 @@ impl Agg for AggCount {
     fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
         Ok(accs.freeze_to_arrays(acc_idx)?[0].clone())
     }
+
+    fn acc_array_data_types(&self) -> &[DataType] {
+        &[DataType::Int64]
+    }
+}
+
+/// Accumulator column for [`AggCount`]: one non-nullable `i64` counter per accumulator slot.
+pub struct AccCountColumn {
+    /// Counter values, indexed by accumulator index.
+    pub values: Vec<i64>,
+}
+
+impl AccColumn for AccCountColumn {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn resize(&mut self, num_accs: usize) {
+        self.values.resize(num_accs, 0);
+    }
+
+    fn shrink_to_fit(&mut self) {
+        self.values.shrink_to_fit();
+    }
+
+    fn num_records(&self) -> usize {
+        self.values.len()
+    }
+
+    fn mem_used(&self) -> usize {
+        self.values.capacity() * 2 * size_of::<i64>()
+    }
+
+    fn freeze_to_arrays(&mut self, idx: IdxSelection<'_>) -> Result<Vec<ArrayRef>> {
+        let mut values = Vec::with_capacity(idx.len());
+        idx_for! {
+            (idx in idx) => {
+                values.push(self.values[idx]);
+            }
+        }
+        Ok(vec![Arc::new(Int64Array::from(values))])
+    }
+
+    fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> {
+        let array = downcast_any!(arrays[0], Int64Array)?;
+        self.values = array.iter().map(|v| v.unwrap_or(0)).collect();
+        Ok(())
+    }
+
+    fn spill(&mut self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> {
+        idx_for! {
+            (idx in idx) => {
+                write_len(self.values[idx] as usize, w)?;
+            }
+        }
+        Ok(())
+    }
+
+    fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
+        assert_eq!(self.num_records(), 0, "expect empty AccColumn");
+        for _ in 0..num_rows {
+            self.values.push(read_len(r)? as i64);
+        }
+        Ok(())
+    }
 }
diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
index c891285f..1dada28f 100644
--- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
@@ -15,21 +15,21 @@
 use std::{
     any::Any,
     fmt::{Debug, Display, Formatter},
-    io::{Cursor, Read, Write},
     sync::Arc,
 };
 
 use arrow::{
     array::{
-        Array, ArrayAccessor, ArrayRef, BinaryArray, BinaryBuilder, StructArray, as_struct_array,
+        Array, ArrayRef, StructArray, as_struct_array,
         make_array,
     },
     datatypes::{DataType, Field, Schema, SchemaRef},
     ffi::{FFI_ArrowArray, FFI_ArrowSchema, from_ffi},
     record_batch::{RecordBatch, RecordBatchOptions},
 };
+use arrow::ffi::from_ffi_and_data_type;
 use blaze_jni_bridge::{
-    jni_bridge::LocalRef, jni_call, jni_get_byte_array_len, jni_get_byte_array_region,
+    jni_bridge::LocalRef, jni_call,
     jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, jni_new_prim_array,
 };
 use datafusion::{
@@ -37,7 +37,7 @@ use datafusion::{
     physical_expr::PhysicalExprRef,
 };
 use datafusion_ext_commons::{
-    UninitializedInit, downcast_any,
+    downcast_any,
     io::{read_len, write_len},
 };
 use jni::objects::{GlobalRef, JObject};
@@ -300,30 +300,22 @@ impl AccUDAFBufferRowsColumn {
         idx: IdxSelection<'_>,
         cache: &OnceCell<LocalRef>,
     ) -> Result<ArrayRef> {
+        let mut ffi_exported_rows = FFI_ArrowArray::empty();
         let idx_array =
             cache.get_or_try_init(move || jni_new_prim_array!(int, &idx.to_int32_vec()[..]))?;
-        let serialized = jni_call!(
-            SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows(
+        jni_call!(
+            SparkUDAFWrapperContext(self.jcontext.as_obj()).exportRows(
                 self.obj.as_obj(),
                 idx_array.as_obj(),
-            ) -> JObject)?;
-        let serialized_len = jni_get_byte_array_len!(serialized.as_obj())?;
-        let mut serialized_bytes = Vec::uninitialized_init(serialized_len);
-        jni_get_byte_array_region!(serialized.as_obj(), 0, &mut serialized_bytes[..])?;
-
-        // UnsafeRow is serialized with big-endian i32 length prefix
-        let mut serialized_pos = 0;
-        let mut binary_builder = BinaryBuilder::with_capacity(idx.len(), 0);
-        for i in 0..idx.len() {
-            let mut bytes_len_buf = [0u8; 4];
-            bytes_len_buf.copy_from_slice(&serialized_bytes[serialized_pos..][..4]);
-            let bytes_len = i32::from_be_bytes(bytes_len_buf) as usize;
-            serialized_pos += 4;
-
-            binary_builder.append_value(&serialized_bytes[serialized_pos..][..bytes_len]);
-            serialized_pos += bytes_len;
-        }
-        Ok(Arc::new(binary_builder.finish()))
+                &mut ffi_exported_rows as *mut FFI_ArrowArray as i64,
+            ) -> ())?;
+        let exported_rows_data = unsafe {
+            // SAFETY: filled with a binary array by SparkUDAFWrapperContext.exportRows()
+            from_ffi_and_data_type(ffi_exported_rows, DataType::Binary)?
+        };
+        let exported_rows = make_array(exported_rows_data);
+        assert_eq!(exported_rows.len(), idx.len(), "export rows count mismatch");
+        Ok(exported_rows)
     }
 
     pub fn spill_with_indices_cache(
@@ -404,29 +396,12 @@ impl AccColumn for AccUDAFBufferRowsColumn {
 
     fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> {
         assert_eq!(self.num_records(), 0, "expect empty AccColumn");
-        let array = downcast_any!(arrays[0], BinaryArray)?;
-
-        let mut cursors = vec![];
-        for i in 0..array.len() {
-            cursors.push(Cursor::new(array.value(i)));
-        }
-
-        let mut data = vec![];
-        for (i, cursor) in cursors.iter_mut().enumerate() {
-            let bytes_len = array.value(i).len();
-            data.write_all((bytes_len as i32).to_be_bytes().as_ref())?;
-            std::io::copy(&mut cursor.take(bytes_len as u64), &mut data)?;
-        }
-
-        let data_buffer = jni_new_direct_byte_buffer!(data)?;
+        let num_rows = arrays[0].len();
+        let ffi_imported_rows = FFI_ArrowArray::new(&arrays[0].to_data());
         let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj())
-            .deserializeRows(data_buffer.as_obj()) -> JObject)?;
+            .importRows(&ffi_imported_rows as *const FFI_ArrowArray as i64) -> JObject)?;
         self.obj = jni_new_global_ref!(rows.as_obj())?;
-        assert_eq!(
-            self.num_records(),
-            cursors.len(),
-            "unfreeze rows count mismatch"
-        );
+        assert_eq!(self.num_records(), num_rows, "unfreeze rows count mismatch");
         Ok(())
     }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
index ee11feba..4a44ad48 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
@@ -22,15 +22,15 @@ import java.io.EOFException
 import java.io.InputStream
 import java.io.OutputStream
 import java.nio.ByteBuffer
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.arrow.c.ArrowArray
 import org.apache.arrow.c.Data
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.{VarBinaryVector, VectorSchemaRoot}
 import org.apache.arrow.vector.dictionary.DictionaryProvider
 import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.blaze.memory.OnHeapSpillManager
@@ -167,12 +167,15 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
     }
   }
 
-  def serializeRows(rows: BufferRowsColumn[B], indices: Array[Int]): Array[Byte] = {
-    aggEvaluator.get.serializeRows(rows, indices.iterator)
+  def exportRows(
+      rows: BufferRowsColumn[B],
+      indices: Array[Int],
+      outputArrowBinaryArrayPtr: Long): Unit = {
+    aggEvaluator.get.exportRows(rows, indices.iterator, outputArrowBinaryArrayPtr)
   }
 
-  def deserializeRows(dataBuffer: ByteBuffer): BufferRowsColumn[B] = {
-    aggEvaluator.get.deserializeRows(dataBuffer)
+  def importRows(inputArrowBinaryArrayPtr: Long): BufferRowsColumn[B] = {
+    aggEvaluator.get.importRows(inputArrowBinaryArrayPtr)
   }
 
   def spill(
@@ -205,14 +208,66 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
 
   def createEmptyColumn(): R
 
-  def serializeRows(
-      rows: R,
-      indices: Iterator[Int],
-      streamWrapper: OutputStream => OutputStream = { s => s }): Array[Byte]
-
-  def deserializeRows(
-      dataBuffer: ByteBuffer,
-      streamWrapper: InputStream => InputStream = { s => s }): R
+  /**
+   * Exports the rows at `indices` as an Arrow binary array (one value per row) into the Arrow C
+   * Data Interface `ArrowArray` struct at `outputArrowBinaryArrayPtr`. Exported rows are
+   * replaced with `releasedRow` in `rows`.
+   */
+  def exportRows(rows: R, indices: Iterator[Int], outputArrowBinaryArrayPtr: Long): Unit
+
+  /**
+   * Imports rows from the Arrow binary array held by the Arrow C Data Interface `ArrowArray`
+   * struct at `inputArrowBinaryArrayPtr`, in the format written by [[exportRows]].
+   */
+  def importRows(inputArrowBinaryArrayPtr: Long): R
+
+  // Spill blocks hold the exported binary values, each prefixed with its length (Int) and
+  // compressed with spillCodec.
+  private def serializeSpillRows(rows: R, indices: Iterator[Int]): Array[Byte] = {
+    val outputStream = new ByteArrayOutputStream()
+    Using.resource(ArrowArray.allocateNew(ROOT_ALLOCATOR)) { ffiArray =>
+      exportRows(rows, indices, ffiArray.memoryAddress())
+      Using.resource(new VarBinaryVector("spill", ROOT_ALLOCATOR)) { binaryVector =>
+        Data.importIntoVector(ROOT_ALLOCATOR, ffiArray, binaryVector, new MapDictionaryProvider)
+        Using.resource(new DataOutputStream(spillCodec.compressedOutputStream(outputStream))) {
+          dataOut =>
+            for (rowIdx <- 0 until binaryVector.getValueCount) {
+              val rowData = binaryVector.get(rowIdx)
+              dataOut.writeInt(rowData.length)
+              dataOut.write(rowData)
+            }
+        }
+      }
+    }
+    outputStream.toByteArray
+  }
+
+  private def deserializeSpillRows(dataBuffer: ByteBuffer): R = {
+    Using.resource(new VarBinaryVector("spill", ROOT_ALLOCATOR)) { binaryVector =>
+      val inputStream = spillCodec.compressedInputStream(new ByteBufferInputStream(dataBuffer))
+      Using.resource(new DataInputStream(inputStream)) { dataIn =>
+        // a clean end of stream (EOF before the next length prefix) ends the block
+        def readRowLength(): Int =
+          try dataIn.readInt()
+          catch { case _: EOFException => -1 }
+
+        var numRows = 0
+        var rowLength = readRowLength()
+        while (rowLength >= 0) {
+          val rowData = new Array[Byte](rowLength)
+          dataIn.readFully(rowData)
+          binaryVector.setSafe(numRows, rowData)
+          numRows += 1
+          rowLength = readRowLength()
+        }
+        binaryVector.setValueCount(numRows)
+      }
+      Using.resource(ArrowArray.allocateNew(ROOT_ALLOCATOR)) { ffiArray =>
+        Data.exportVector(ROOT_ALLOCATOR, binaryVector, new MapDictionaryProvider, ffiArray)
+        importRows(ffiArray.memoryAddress())
+      }
+    }
+  }
 
   def spill(
       memTracker: SparkUDAFMemTracker,
@@ -222,7 +277,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
     val hsm = OnHeapSpillManager.current
     val spillId = memTracker.getSpill(spillIdx)
     val byteBuffer =
-      ByteBuffer.wrap(serializeRows(rows, indices, spillCodec.compressedOutputStream))
+      ByteBuffer.wrap(serializeSpillRows(rows, indices))
     val spillBlockSize = byteBuffer.limit()
     hsm.writeSpill(spillId, byteBuffer)
     spillBlockSize
@@ -238,7 +293,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
     val readSize = hsm.readSpill(spillId, byteBuffer).toLong
     assert(readSize == spillBlockSize)
     byteBuffer.flip()
-    deserializeRows(byteBuffer, spillCodec.compressedInputStream)
+    deserializeSpillRows(byteBuffer)
   }
 }
 
@@ -267,42 +322,49 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A
     DeclarativeAggRowsColumn(this, ArrayBuffer())
   }
 
-  override def serializeRows(
+  override def exportRows(
       rows: DeclarativeAggRowsColumn,
       indices: Iterator[Int],
-      streamWrapper: OutputStream => OutputStream): Array[Byte] = {
+      outputArrowBinaryArrayPtr: Long): Unit = {
 
-    val numFields = agg.aggBufferSchema.length
-    val outputDataStream = new ByteArrayOutputStream()
-    val wrappedStream = streamWrapper(outputDataStream)
-    val serializer = new UnsafeRowSerializer(numFields).newInstance()
+    Using.resource(new VarBinaryVector("output", ROOT_ALLOCATOR)) { binaryVector =>
+      val rowDataStream = new ByteArrayOutputStream()
+      val rowDataBuffer = new Array[Byte](1024)
+      var numOutputRows = 0
 
-    Using(serializer.serializeStream(wrappedStream)) { ser =>
-      for (i <- indices) {
-        ser.writeValue(rows.rows(i))
-        rows.rows(i) = releasedRow
+      for (rowIdx <- indices) {
+        rows.rows(rowIdx).writeToStream(rowDataStream, rowDataBuffer)
+        rows.rows(rowIdx) = releasedRow
+        binaryVector.setSafe(numOutputRows, rowDataStream.toByteArray)
+        rowDataStream.reset()
+        numOutputRows += 1
+      }
+      // setSafe() does not update the value count; without this the exported array is empty
+      binaryVector.setValueCount(numOutputRows)
+
+      Using.resource(ArrowArray.wrap(outputArrowBinaryArrayPtr)) { outputArray =>
+        Data.exportVector(ROOT_ALLOCATOR, binaryVector, new MapDictionaryProvider, outputArray)
       }
     }
-    wrappedStream.close()
-    outputDataStream.toByteArray
   }
 
-  override def deserializeRows(
-      dataBuffer: ByteBuffer,
-      streamWrapper: InputStream => InputStream): DeclarativeAggRowsColumn = {
-    val numFields = agg.aggBufferSchema.length
-    val deserializer = new UnsafeRowSerializer(numFields).newInstance()
-    val inputDataStream = new ByteBufferInputStream(dataBuffer)
-    val wrappedStream = streamWrapper(inputDataStream)
-    val rows = new ArrayBuffer[UnsafeRow]()
-
-    Using.resource(deserializer.deserializeStream(wrappedStream)) { deser =>
-      for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) {
+  override def importRows(inputArrowBinaryArrayPtr: Long): DeclarativeAggRowsColumn = {
+    Using.resource(new VarBinaryVector("input", ROOT_ALLOCATOR)) { binaryVector =>
+      Using.resource(ArrowArray.wrap(inputArrowBinaryArrayPtr)) { inputArray =>
+        Data.importIntoVector(ROOT_ALLOCATOR, inputArray, binaryVector, new MapDictionaryProvider)
+      }
+      val numRows = binaryVector.getValueCount
+      val numFields = agg.aggBufferSchema.length
+      val rows = new ArrayBuffer[UnsafeRow]()
+
+      for (rowIdx <- 0 until numRows) {
+        val row = new UnsafeRow(numFields)
+        val rowData = binaryVector.get(rowIdx)
+        row.pointTo(rowData, rowData.length)
         rows.append(row)
       }
+      DeclarativeAggRowsColumn(this, rows)
     }
-    wrappedStream.close()
-    DeclarativeAggRowsColumn(this, rows)
   }
 }
 
@@ -378,51 +440,41 @@ class TypedImperativeEvaluator[B](val agg: TypedImperativeAggregate[B])
     new TypedImperativeAggRowsColumn[B](this, ArrayBuffer())
   }
 
-  override def serializeRows(
+  override def exportRows(
       rows: TypedImperativeAggRowsColumn[B],
       indices: Iterator[Int],
-      streamWrapper: OutputStream => OutputStream): Array[Byte] = {
+      outputArrowBinaryArrayPtr: Long): Unit = {
 
-    val outputStream = new ByteArrayOutputStream()
-    val wrappedStream = streamWrapper(outputStream)
-    val dataOut = new DataOutputStream(wrappedStream)
+    Using.resource(new VarBinaryVector("output", ROOT_ALLOCATOR)) { binaryVector =>
+      var numOutputRows = 0
+      for (rowIdx <- indices) {
+        binaryVector.setSafe(numOutputRows, rows.serializedRow(rowIdx))
+        rows.rows(rowIdx) = releasedRow
+        numOutputRows += 1
+      }
+      // setSafe() does not update the value count; without this the exported array is empty
+      binaryVector.setValueCount(numOutputRows)
 
-    for (i <- indices) {
-      val bytes = rows.serializedRow(i)
-      dataOut.writeInt(bytes.length)
-      dataOut.write(bytes)
-      rows.rows(i) = releasedRow
+      Using.resource(ArrowArray.wrap(outputArrowBinaryArrayPtr)) { outputArray =>
+        Data.exportVector(ROOT_ALLOCATOR, binaryVector, new MapDictionaryProvider, outputArray)
+      }
     }
-    dataOut.close()
-    outputStream.toByteArray
   }
 
-  override def deserializeRows(
-      dataBuffer: ByteBuffer,
-      streamWrapper: InputStream => InputStream): TypedImperativeAggRowsColumn[B] = {
-    val rows = ArrayBuffer[RowType]()
-    val inputStream = new ByteBufferInputStream(dataBuffer)
-    val wrappedStream = streamWrapper(inputStream)
-    val dataIn = new DataInputStream(wrappedStream)
-    var finished = false
-
-    while (!finished) {
-      var length = -1
-      try {
-        length = dataIn.readInt()
-      } catch {
-        case _: EOFException =>
-          finished = true
+  override def importRows(inputArrowBinaryArrayPtr: Long): TypedImperativeAggRowsColumn[B] = {
+    Using.resource(new VarBinaryVector("input", ROOT_ALLOCATOR)) { binaryVector =>
+      Using.resource(ArrowArray.wrap(inputArrowBinaryArrayPtr)) { inputArray =>
+        Data.importIntoVector(ROOT_ALLOCATOR, inputArray, binaryVector, new MapDictionaryProvider)
       }
+      val numRows = binaryVector.getValueCount
+      val rows = ArrayBuffer[RowType]()
 
-      if (!finished) {
-        val bytes = new Array[Byte](length)
-        dataIn.read(bytes)
-        rows.append(SerializedRowType(bytes))
+      for (rowIdx <- 0 until numRows) {
+        val rowData = binaryVector.get(rowIdx)
+        rows.append(SerializedRowType(rowData))
       }
+      TypedImperativeAggRowsColumn(this, rows)
     }
-    dataIn.close()
-    TypedImperativeAggRowsColumn(this, rows)
   }
 }
 

Reply via email to