alamb commented on code in PR #16734:
URL: https://github.com/apache/datafusion/pull/16734#discussion_r2207881274
##########
datafusion/physical-plan/src/stream.rs:
##########
@@ -522,6 +524,139 @@ impl Stream for ObservedStream {
}
}
+pin_project! {
+ /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
+ ///
+ /// This ensures upstream operators receive batches no larger than
+ /// `batch_size`, which can improve parallelism when data sources
+ /// generate very large batches.
+ ///
+ /// # Fields
+ ///
+ /// - `current_batch`: The batch currently being split, if any
+ /// - `offset`: Index of the next row to split from `current_batch`.
+ /// This tracks our position within the current batch being split.
+ ///
+ /// # Invariants
+ ///
+ /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
+ /// - When `current_batch` is `None`, `offset` is always 0
+ /// - `batch_size` is always > 0
+pub struct BatchSplitStream {
+ #[pin]
+ input: SendableRecordBatchStream,
+ schema: SchemaRef,
+ batch_size: usize,
+ metrics: SplitMetrics,
+ current_batch: Option<RecordBatch>,
+ offset: usize,
+ }
+}
+
+impl BatchSplitStream {
+ /// Create a new [`BatchSplitStream`]
+ pub fn new(
+ input: SendableRecordBatchStream,
+ batch_size: usize,
+ metrics: SplitMetrics,
+ ) -> Self {
+ let schema = input.schema();
+ Self {
+ input,
+ schema,
+ batch_size,
+ metrics,
+ current_batch: None,
+ offset: 0,
+ }
+ }
+
+ /// Attempt to produce the next sliced batch from the current batch.
+ ///
+ /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
+ /// is exhausted and we need to poll upstream for more data.
+ fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
+ let batch = self.current_batch.take()?;
+
+ // Assert slice boundary safety - offset should never exceed batch size
+ debug_assert!(
+ self.offset <= batch.num_rows(),
+ "Offset {} exceeds batch size {}",
+ self.offset,
+ batch.num_rows()
+ );
+
+ let remaining = batch.num_rows() - self.offset;
+ let to_take = remaining.min(self.batch_size);
+ let out = batch.slice(self.offset, to_take);
+
+ self.metrics.batches_splitted.add(1);
+ self.offset += to_take;
+ if self.offset < batch.num_rows() {
+ // More data remains in this batch, store it back
+ self.current_batch = Some(batch);
+ } else {
+ // Batch is exhausted, reset offset
+ // Note: current_batch is already None since we took it at the start
+ self.offset = 0;
+ }
+ Some(Ok(out))
+ }
+
+ /// Poll the upstream input for the next batch.
+ ///
+ /// Returns the appropriate `Poll` result based on upstream state.
+ /// Small batches are passed through directly; large batches are stored
+ /// for slicing and their first slice is returned immediately.
+ fn poll_upstream(
+ &mut self,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ match self.input.as_mut().poll_next(cx) {
Review Comment:
You can use the `ready!` macro here to avoid some of the `Poll::Pending` handling:
https://docs.rs/futures/latest/futures/macro.ready.html
```suggestion
match ready!(self.input.as_mut().poll_next(cx)) {
```
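For illustration, here is a self-contained sketch of the `ready!` pattern the suggestion points at: a wrapper stream whose `poll_next` forwards to an inner stream. The `PassThrough` type and its field are hypothetical and only stand in for the shape of `poll_upstream`; this is not the PR's `BatchSplitStream`.
```rust
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures::stream::Stream;

/// Hypothetical wrapper used only to demonstrate `ready!`; not the PR's type.
struct PassThrough<S> {
    inner: Pin<Box<S>>,
}

impl<S: Stream> Stream for PassThrough<S> {
    type Item = S::Item;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `PassThrough` is `Unpin` (it only holds a boxed stream), so we can
        // reach the inner stream without pin projection.
        let this = self.get_mut();
        // `ready!` returns `Poll::Pending` from this function early, so the
        // match only needs to handle `Some(item)` and `None`.
        match ready!(this.inner.as_mut().poll_next(cx)) {
            Some(item) => Poll::Ready(Some(item)),
            None => Poll::Ready(None),
        }
    }
}
```
The same shape should apply to `poll_upstream`: wrap the inner `poll_next` call in `ready!` and keep only the `Some(Ok(batch))`, `Some(Err(e))`, and `None` arms.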
##########
datafusion/sqllogictest/test_files/window.slt:
##########
@@ -4341,6 +4341,9 @@ LIMIT 5;
24 31
14 94
+statement ok
Review Comment:
I think this one might be better off as `batch_size = 1` to avoid the plan changes later in this file.
I see the challenge, however.
##########
datafusion/sqllogictest/test_files/group_by.slt:
##########
@@ -4535,19 +4535,20 @@ LIMIT 5
query ITIPTR rowsort
SELECT r.*
FROM sales_global_with_pk as l, sales_global_with_pk as r
+ORDER BY 1, 2, 3, 4, 5, 6
Review Comment:
now that the input is in multiple batches, I agree the order by is required
##########
datafusion/sqllogictest/test_files/group_by.slt:
##########
@@ -3425,7 +3425,7 @@ physical_plan
06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8
07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[]
08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
+09)----------------DataSourceExec: partitions=1, partition_sizes=[2]
Review Comment:
I believe the reason the partition size increases is that `sales_global_with_pk` is created from the output of a query, and since the batch size is 4, this results in more batches from the input.
I think this is a better plan overall.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]