rluvaton commented on code in PR #19494:
URL: https://github.com/apache/datafusion/pull/19494#discussion_r2650909825
##########
datafusion/physical-plan/src/sorts/sort.rs:
##########
@@ -2402,4 +2454,757 @@ mod tests {
Ok((sorted_batches, metrics))
}
+
+ // ========================================================================
+ // Tests for sort_batch_chunked()
+ // ========================================================================
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_basic() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // 1000 rows in descending order, so the input is guaranteed to be
+        // unsorted (this reverses the values; it does not shuffle them).
+        let values: Vec<i32> = (0..1000).rev().collect();
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        // Sort with batch_size = 250: 1000 rows split into 4 chunks
+        let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
+        assert_eq!(result_batches.len(), 4);
+
+        // Verify each batch has <= 250 rows and no rows are lost overall
+        let mut total_rows = 0;
+        for (i, batch) in result_batches.iter().enumerate() {
+            assert!(
+                batch.num_rows() <= 250,
+                "Batch {} has {} rows, expected <= 250",
+                i,
+                batch.num_rows()
+            );
+            total_rows += batch.num_rows();
+        }
+        assert_eq!(total_rows, 1000);
+
+        // The concatenated output must be exactly 0..1000 in ascending order.
+        // Comparing against the expected values (instead of only checking
+        // adjacent pairs) also catches dropped or duplicated rows.
+        let concatenated = concat_batches(&schema, &result_batches)?;
+        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+        let expected: Vec<i32> = (0..1000).collect();
+        assert_eq!(&array.values()[..], expected.as_slice());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> {
+        // 50 descending rows with a batch_size of 100: everything fits in one chunk
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let input = Int32Array::from((0..50).rev().collect::<Vec<i32>>());
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input)])?;
+        let ordering: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        let chunks = sort_batch_chunked(&batch, &ordering, 100)?;
+
+        // No splitting may happen
+        assert_eq!(chunks.len(), 1);
+        let only = &chunks[0];
+        assert_eq!(only.num_rows(), 50);
+
+        // Ascending order spanning 0..=49
+        let sorted = as_primitive_array::<Int32Type>(only.column(0))?;
+        assert!(sorted.values().windows(2).all(|pair| pair[0] <= pair[1]));
+        assert_eq!(sorted.value(0), 0);
+        assert_eq!(sorted.value(49), 49);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_exact_multiple() -> Result<()> {
+        // 1000 rows split by 100 -> ten full chunks with no remainder
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from((0..1000).rev().collect::<Vec<i32>>()))],
+        )?;
+        let ordering: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        let chunks = sort_batch_chunked(&batch, &ordering, 100)?;
+
+        assert_eq!(chunks.len(), 10);
+        chunks
+            .iter()
+            .for_each(|chunk| assert_eq!(chunk.num_rows(), 100));
+
+        // Ordering must hold across chunk boundaries, not only within a chunk
+        let merged = concat_batches(&schema, &chunks)?;
+        let column = as_primitive_array::<Int32Type>(merged.column(0))?;
+        assert!(column.values().windows(2).all(|w| w[0] <= w[1]));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_with_nulls() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+
+        // Ten values, three of which are null, in no particular order
+        let values = Int32Array::from(vec![
+            Some(5),
+            None,
+            Some(2),
+            Some(8),
+            None,
+            Some(1),
+            Some(9),
+            None,
+            Some(3),
+            Some(7),
+        ]);
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?;
+
+        // Sort ascending in chunks of 4 with the requested null placement, then
+        // stitch the chunks back together for inspection.
+        let sort_with = |nulls_first: bool| -> Result<RecordBatch> {
+            let ordering: LexOrdering = [PhysicalSortExpr {
+                expr: Arc::new(Column::new("a", 0)),
+                options: SortOptions {
+                    descending: false,
+                    nulls_first,
+                },
+            }]
+            .into();
+            let chunks = sort_batch_chunked(&batch, &ordering, 4)?;
+            Ok(concat_batches(&schema, &chunks)?)
+        };
+
+        // nulls_first = true: three leading nulls, then 1, 2, ...
+        let merged = sort_with(true)?;
+        let array = as_primitive_array::<Int32Type>(merged.column(0))?;
+        for i in 0..3 {
+            assert!(array.is_null(i));
+        }
+        assert_eq!(array.value(3), 1);
+        assert_eq!(array.value(4), 2);
+
+        // nulls_first = false: starts at 1 and ends with three nulls
+        let merged = sort_with(false)?;
+        let array = as_primitive_array::<Int32Type>(merged.column(0))?;
+        assert_eq!(array.value(0), 1);
+        for i in 7..10 {
+            assert!(array.is_null(i));
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_multi_column() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]));
+
+        // Many ties on `a`, so the secondary key `b` decides the order
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(vec![3, 1, 2, 1, 3, 2, 1, 3, 2, 1])),
+                Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 1, 3, 3, 2, 4])),
+            ],
+        )?;
+
+        let ordering: LexOrdering = [
+            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
+            PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))),
+        ]
+        .into();
+
+        let chunks = sort_batch_chunked(&batch, &ordering, 3)?;
+        let merged = concat_batches(&schema, &chunks)?;
+        let keys_a = as_primitive_array::<Int32Type>(merged.column(0))?;
+        let keys_b = as_primitive_array::<Int32Type>(merged.column(1))?;
+
+        // (a, b) pairs must be non-decreasing in lexicographic order
+        for i in 0..keys_a.len() - 1 {
+            let (a_curr, b_curr) = (keys_a.value(i), keys_b.value(i));
+            let (a_next, b_next) = (keys_a.value(i + 1), keys_b.value(i + 1));
+            assert!(
+                (a_curr, b_curr) <= (a_next, b_next),
+                "Not properly sorted at position {i}: ({a_curr}, {b_curr}) -> ({a_next}, {b_next})",
+            );
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_empty_batch() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let ordering: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+        let empty = RecordBatch::new_empty(schema);
+
+        let chunks = sort_batch_chunked(&empty, &ordering, 100)?;
+
+        // A zero-row input yields zero chunks rather than one empty chunk
+        assert!(chunks.is_empty());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_single_row() -> Result<()> {
+        // A single row must come back unchanged in a single chunk
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let ordering: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+        let input = Int32Array::from(vec![42]);
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input)])?;
+
+        let chunks = sort_batch_chunked(&batch, &ordering, 100)?;
+        assert_eq!(chunks.len(), 1);
+
+        let only = &chunks[0];
+        assert_eq!(only.num_rows(), 1);
+        assert_eq!(as_primitive_array::<Int32Type>(only.column(0))?.value(0), 42);
+
+        Ok(())
+    }
+
+ // ========================================================================
+ // Tests for get_reserved_byte_for_record_batch()
+ // ========================================================================
+
+    #[tokio::test]
+    async fn test_get_reserved_byte_for_record_batch_normal_batch() -> Result<()> {
+        // A small, unsliced five-row batch must still need a non-zero reservation
+        let column = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(column)],
+        )?;
+
+        assert!(get_reserved_byte_for_record_batch(&batch)? > 0);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_get_reserved_byte_for_record_batch_with_sliced_batches() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // A 50-row window (offset 100) over a 1000-row array, so the batch
+        // references a buffer much larger than its visible data
+        let full = Int32Array::from((0..1000).collect::<Vec<i32>>());
+        let window = full.slice(100, 50);
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(window)])?;
+
+        let reserved = get_reserved_byte_for_record_batch(&batch)?;
+        assert!(reserved > 0);
+
+        // The reservation must be at least the batch's reported memory size
+        // (underlying buffer) plus the 64-rounded sliced size
+        assert!(reserved >= get_record_batch_memory_size(&batch));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_get_reserved_byte_for_record_batch_rounding() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Three Int32 values: the value data (12 bytes) is not a multiple of 64
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )?;
+
+        let reserved = get_reserved_byte_for_record_batch(&batch)?;
+        assert!(reserved > 0);
+
+        // The reservation is
+        //   record_batch_memory_size + round_upto_multiple_of_64(sliced_size)
+        // so it can never be smaller than the batch's own memory size.
+        // NOTE(review): this bounds the result but still does not isolate the
+        // rounding term; an exact assertion needs the sliced-size computation
+        // — TODO confirm and tighten.
+        let record_batch_size = get_record_batch_memory_size(&batch);
+        assert!(reserved >= record_batch_size);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_get_reserved_byte_for_record_batch_with_string_view() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "a",
+            DataType::Utf8View,
+            false,
+        )]));
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(StringViewArray::from(vec!["hello", "world", "test"]))],
+        )?;
+
+        // Variable-length (view) data must still yield a positive reservation
+        assert!(get_reserved_byte_for_record_batch(&batch)? > 0);
+
+        Ok(())
+    }
+
+ // ========================================================================
+ // Tests for ReservationStream (in stream.rs, but we test integration here)
+ // ========================================================================
+
+    #[tokio::test]
+    async fn test_sort_batch_stream_memory_tracking() -> Result<()> {
+        use crate::stream::ReservationStream;
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB limit
+            .build_arc()?;
+
+        // Bound `mut` up front: it is grown below before being moved into the stream
+        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a batch
+        let values: Vec<i32> = (0..1000).collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let batch_size = get_record_batch_memory_size(&batch);
+
+        // Create a simple stream with one batch
+        let stream = futures::stream::iter(vec![Ok(batch)]);
+        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+            as SendableRecordBatchStream;
+
+        // Reserve memory for the batch before handing the reservation to the stream
+        reservation.try_grow(batch_size)?;
+
+        let initial_reserved = runtime.memory_pool.reserved();
+        assert!(initial_reserved > 0);
+
+        let mut res_stream =
+            ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+        // Consume the batch. Propagate a stream error instead of only checking
+        // `is_some()`, which would also accept `Some(Err(_))`.
+        let output = res_stream
+            .next()
+            .await
+            .expect("stream should yield the input batch")?;
+        assert_eq!(output.num_rows(), 1000);
+
+        // Memory should be reduced after consuming
+        let after_consume = runtime.memory_pool.reserved();
+        assert!(after_consume < initial_reserved);
+
+        // Consume until end
+        while res_stream.next().await.is_some() {}
+
+        // Memory should be freed
+        let final_reserved = runtime.memory_pool.reserved();
+        assert_eq!(final_reserved, 0);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_stream_chunked_output() -> Result<()> {
+        // NOTE(review): despite its name, this test calls `sort_batch_chunked`
+        // directly and never constructs a stream.
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let batch_size = 500;
+
+        // 5000 descending rows sorted in chunks of 500 -> ten output batches
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from((0..5000).rev().collect::<Vec<i32>>()))],
+        )?;
+        let ordering: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        let chunks = sort_batch_chunked(&batch, &ordering, batch_size)?;
+        assert_eq!(chunks.len(), 10);
+
+        // Every chunk respects the limit and together they cover all rows
+        assert!(chunks.iter().all(|chunk| chunk.num_rows() <= batch_size));
+        let total_rows: usize = chunks.iter().map(RecordBatch::num_rows).sum();
+        assert_eq!(total_rows, 5000);
+
+        // Ascending across the whole concatenated output
+        let merged = concat_batches(&schema, &chunks)?;
+        let column = as_primitive_array::<Int32Type>(merged.column(0))?;
+        assert!(column.values().windows(2).all(|w| w[0] <= w[1]));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_no_batch_split_stream_metrics() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let input = test::scan_partitioned(2);
+        let schema = input.schema();
+
+        let ordering = [PhysicalSortExpr {
+            expr: col("i", &schema)?,
+            options: SortOptions::default(),
+        }];
+        let sort_exec = Arc::new(SortExec::new(
+            ordering.into(),
+            Arc::new(CoalescePartitionsExec::new(input)),
+        ));
+
+        // Drive the plan to completion so its metrics are populated
+        let _output = collect(Arc::clone(&sort_exec) as _, task_ctx).await?;
+        let metrics = sort_exec.metrics().unwrap();
+
+        // Split metrics must be absent; only baseline and spill metrics are
+        // expected. Checked against the Debug rendering of the metrics set.
+        let rendered = format!("{metrics:?}");
+        for (needle, message) in [
+            ("split_count", "Should not have split_count metric"),
+            ("split_time", "Should not have split_time metric"),
+        ] {
+            assert!(!rendered.contains(needle), "{message}");
+        }
+
+        // Baseline metrics are still reported
+        assert!(metrics.output_rows().is_some());
+        assert!(metrics.elapsed_compute().is_some());
+
+        Ok(())
+    }
+
+ #[tokio::test]
+ async fn test_external_sorter_with_chunked_batches() -> Result<()> {
+ // Test with memory limits that trigger spilling
Review Comment:
We already have tests for that
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]