alamb commented on code in PR #16734:
URL: https://github.com/apache/datafusion/pull/16734#discussion_r2207881274
##########
datafusion/physical-plan/src/stream.rs:
##########
@@ -522,6 +524,139 @@ impl Stream for ObservedStream {
     }
 }
 
+pin_project! {
+    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
+    ///
+    /// This ensures upstream operators receive batches no larger than
+    /// `batch_size`, which can improve parallelism when data sources
+    /// generate very large batches.
+    ///
+    /// # Fields
+    ///
+    /// - `current_batch`: The batch currently being split, if any
+    /// - `offset`: Index of the next row to split from `current_batch`.
+    ///   This tracks our position within the current batch being split.
+    ///
+    /// # Invariants
+    ///
+    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
+    /// - When `current_batch` is `None`, `offset` is always 0
+    /// - `batch_size` is always > 0
+    pub struct BatchSplitStream {
+        #[pin]
+        input: SendableRecordBatchStream,
+        schema: SchemaRef,
+        batch_size: usize,
+        metrics: SplitMetrics,
+        current_batch: Option<RecordBatch>,
+        offset: usize,
+    }
+}
+
+impl BatchSplitStream {
+    /// Create a new [`BatchSplitStream`]
+    pub fn new(
+        input: SendableRecordBatchStream,
+        batch_size: usize,
+        metrics: SplitMetrics,
+    ) -> Self {
+        let schema = input.schema();
+        Self {
+            input,
+            schema,
+            batch_size,
+            metrics,
+            current_batch: None,
+            offset: 0,
+        }
+    }
+
+    /// Attempt to produce the next sliced batch from the current batch.
+    ///
+    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
+    /// is exhausted and we need to poll upstream for more data.
+    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
+        let batch = self.current_batch.take()?;
+
+        // Assert slice boundary safety - offset should never exceed batch size
+        debug_assert!(
+            self.offset <= batch.num_rows(),
+            "Offset {} exceeds batch size {}",
+            self.offset,
+            batch.num_rows()
+        );
+
+        let remaining = batch.num_rows() - self.offset;
+        let to_take = remaining.min(self.batch_size);
+        let out = batch.slice(self.offset, to_take);
+
+        self.metrics.batches_splitted.add(1);
+        self.offset += to_take;
+        if self.offset < batch.num_rows() {
+            // More data remains in this batch, store it back
+            self.current_batch = Some(batch);
+        } else {
+            // Batch is exhausted, reset offset
+            // Note: current_batch is already None since we took it at the start
+            self.offset = 0;
+        }
+        Some(Ok(out))
+    }
+
+    /// Poll the upstream input for the next batch.
+    ///
+    /// Returns the appropriate `Poll` result based on upstream state.
+    /// Small batches are passed through directly, large batches are stored
+    /// for slicing and return the first slice immediately.
+    fn poll_upstream(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        match self.input.as_mut().poll_next(cx) {

Review Comment:
   You can use the `ready!` macro here to avoid some of the handling:
   https://docs.rs/futures/latest/futures/macro.ready.html
   ```suggestion
           match ready!(self.input.as_mut().poll_next(cx)) {
   ```

##########
datafusion/sqllogictest/test_files/window.slt:
##########
@@ -4341,6 +4341,9 @@ LIMIT 5;
 24
 31
 14
 94
 
+statement ok

Review Comment:
   I think this one might be better off as batch_size = 1 to avoid the plan changes later in this file.
   
   I see the challenge however.
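Regarding the `ready!` suggestion on `poll_upstream` above: below is a minimal sketch of how the full match could read once `ready!` absorbs the `Poll::Pending` case. It assumes the `BatchSplitStream` struct and helpers from the diff; the match-arm bodies are not visible in the hunk (it is cut off at the `match` line), so they are written here only to mirror the behaviour the doc comment describes, not the PR's actual implementation.

```rust
use std::task::{Context, Poll};

use futures::ready; // `ready!` is also available as `std::task::ready!` since Rust 1.64

impl BatchSplitStream {
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // `ready!` early-returns `Poll::Pending`, so only the `Poll::Ready(..)`
        // cases need explicit arms below.
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) if batch.num_rows() <= self.batch_size => {
                // Small batch: pass it through unchanged
                Poll::Ready(Some(Ok(batch)))
            }
            Some(Ok(batch)) => {
                // Large batch: stash it and emit the first slice immediately
                self.current_batch = Some(batch);
                Poll::Ready(self.next_sliced_batch())
            }
            // Errors and end-of-stream are forwarded as-is
            other => Poll::Ready(other),
        }
    }
}
```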
##########
datafusion/sqllogictest/test_files/group_by.slt:
##########
@@ -4535,19 +4535,20 @@ LIMIT 5
 
 query ITIPTR rowsort
 SELECT r.*
 FROM sales_global_with_pk as l, sales_global_with_pk as r
+ORDER BY 1, 2, 3, 4, 5, 6

Review Comment:
   Now that the input is in multiple batches, I agree the ORDER BY is required.

##########
datafusion/sqllogictest/test_files/group_by.slt:
##########
@@ -3425,7 +3425,7 @@ physical_plan
 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8
 07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[]
 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
+09)----------------DataSourceExec: partitions=1, partition_sizes=[2]

Review Comment:
   I believe the reason the partition sizes increase is that `sales_global_with_pk` is created from the output of a query, and since the batch size is 4, this results in more batches from the input.
   
   I think this is a better plan overall.
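To illustrate the batch-size reasoning in the last comment (why a source that previously reported `partition_sizes=[1]` can now report `[2]`): once a materialized input has more rows than `batch_size`, slicing it yields more than one batch. The sketch below reproduces the same slice arithmetic that `next_sliced_batch` uses, but with made-up data (a 6-row single-column batch and `batch_size = 4`); the column name and values are illustrative only and are not taken from the test tables.

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("sn", DataType::Int32, false)]));
    // A 6-row "source" batch, standing in for the output of the defining query
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5]))],
    )?;

    let batch_size = 4;
    let mut offset = 0;
    let mut slices = Vec::new();
    // Same arithmetic as `next_sliced_batch`: emit at most `batch_size` rows per slice
    while offset < batch.num_rows() {
        let to_take = (batch.num_rows() - offset).min(batch_size);
        slices.push(batch.slice(offset, to_take));
        offset += to_take;
    }

    // 6 rows with batch_size = 4 produce two batches of 4 and 2 rows,
    // i.e. the source now holds two batches instead of one
    assert_eq!(slices.len(), 2);
    assert_eq!(slices[0].num_rows(), 4);
    assert_eq!(slices[1].num_rows(), 2);
    Ok(())
}
```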