This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 92727b52fd fix: shuffle seed (#18518)
92727b52fd is described below
commit 92727b52fd0f9d3896be0bb8b09ab670c7d31d4d
Author: Chen Chongchen <[email protected]>
AuthorDate: Sat Nov 8 08:17:33 2025 +0800
fix: shuffle seed (#18518)
## Which issue does this PR close?
- Closes #18476.
## Rationale for this change
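The `shuffle` sqllogictests were flaky. Assertions such as `shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL]` fail whenever the random permutation happens to equal the input, which has probability 1/720 for six elements.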
## What changes are included in this PR?
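Add an optional integer seed argument to `shuffle`, used to seed its random number generator. Pass a fixed seed in `shuffle.slt` so the expected output is deterministic and the test can no longer fail at random.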
## Are these changes tested?
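Yes. The sqllogictests in `shuffle.slt` now call `shuffle` with a fixed seed and check the exact shuffled output.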
## Are there any user-facing changes?
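No breaking changes. `shuffle` now accepts an optional second argument, an integer seed, that makes its output reproducible.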
---
datafusion/spark/src/function/array/shuffle.rs | 102 ++++++++++++++++++---
.../test_files/spark/array/shuffle.slt | 46 +++-------
2 files changed, 103 insertions(+), 45 deletions(-)
diff --git a/datafusion/spark/src/function/array/shuffle.rs
b/datafusion/spark/src/function/array/shuffle.rs
index abeafd3a93..9f345b53b8 100644
--- a/datafusion/spark/src/function/array/shuffle.rs
+++ b/datafusion/spark/src/function/array/shuffle.rs
@@ -15,21 +15,25 @@
// specific language governing permissions and limitations
// under the License.
-use crate::function::functions_nested_utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray,
MutableArrayData,
OffsetSizeTrait,
};
use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
-use arrow::datatypes::{DataType, FieldRef};
+use arrow::datatypes::FieldRef;
use datafusion_common::cast::{
as_fixed_size_list_array, as_large_list_array, as_list_array,
};
-use datafusion_common::{exec_err, utils::take_function_args, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_common::{exec_err, utils::take_function_args, Result,
ScalarValue};
+use datafusion_expr::{
+ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue,
ScalarUDFImpl,
+ Signature, TypeSignature, Volatility,
+};
use rand::rng;
-use rand::seq::SliceRandom;
+use rand::rngs::StdRng;
+use rand::{seq::SliceRandom, Rng, SeedableRng};
use std::any::Any;
use std::sync::Arc;
@@ -47,7 +51,25 @@ impl Default for SparkShuffle {
impl SparkShuffle {
pub fn new() -> Self {
Self {
- signature: Signature::arrays(1, None, Volatility::Volatile),
+        signature: Signature::one_of(
+            vec![
+                // Only array argument
+                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+                    arguments: vec![ArrayFunctionArgument::Array],
+                    array_coercion: None,
+                }),
+                // Array + Index (seed) argument
+                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+                    arguments: vec![
+                        ArrayFunctionArgument::Array,
+                        ArrayFunctionArgument::Index,
+                    ],
+                    array_coercion: None,
+                }),
+            ],
+            // Volatile: without a seed, every call draws a fresh random permutation
+            Volatility::Volatile,
+        ),
}
}
}
@@ -73,25 +95,63 @@ impl ScalarUDFImpl for SparkShuffle {
&self,
args: datafusion_expr::ScalarFunctionArgs,
) -> Result<ColumnarValue> {
- make_scalar_function(array_shuffle_inner)(&args.args)
+ if args.args.is_empty() {
+ return exec_err!("shuffle expects at least 1 argument");
+ }
+ if args.args.len() > 2 {
+ return exec_err!("shuffle expects at most 2 arguments");
+ }
+
+ // Extract seed from second argument if present
+ let seed = if args.args.len() == 2 {
+ extract_seed(&args.args[1])?
+ } else {
+ None
+ };
+
+ // Convert arguments to arrays
+ let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
+ array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
+ }
+}
+
+/// Extracts the optional shuffle seed from the second `shuffle` argument.
+///
+/// Returns `Ok(None)` for a NULL seed, whether untyped (`ScalarValue::Null`) or a
+/// typed `Int64` NULL (e.g. `CAST(NULL AS BIGINT)`), so the caller falls back to an
+/// unseeded shuffle. Negative seeds are reinterpreted bit-for-bit as `u64`.
+///
+/// # Errors
+///
+/// Returns an execution error if the seed is a non-`Int64` scalar or an array.
+fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
+    match seed_arg {
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => Ok(Some(*v as u64)),
+        ColumnarValue::Scalar(ScalarValue::Int64(None) | ScalarValue::Null) => Ok(None),
+        ColumnarValue::Scalar(other) => exec_err!(
+            "shuffle seed must be Int64 type, got '{}'",
+            other.data_type()
+        ),
+        ColumnarValue::Array(_) => {
+            exec_err!("shuffle seed must be a scalar value, not an array")
}
}
-/// array_shuffle SQL function
-pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
+/// array_shuffle SQL function with optional seed
+fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) ->
Result<ArrayRef> {
let [input_array] = take_function_args("shuffle", arg)?;
match &input_array.data_type() {
List(field) => {
let array = as_list_array(input_array)?;
- general_array_shuffle::<i32>(array, field)
+ general_array_shuffle::<i32>(array, field, seed)
}
LargeList(field) => {
let array = as_large_list_array(input_array)?;
- general_array_shuffle::<i64>(array, field)
+ general_array_shuffle::<i64>(array, field, seed)
}
FixedSizeList(field, _) => {
let array = as_fixed_size_list_array(input_array)?;
- fixed_size_array_shuffle(array, field)
+ fixed_size_array_shuffle(array, field, seed)
}
Null => Ok(Arc::clone(input_array)),
array_type => exec_err!("shuffle does not support type
'{array_type}'."),
@@ -101,6 +161,7 @@ pub fn array_shuffle_inner(arg: &[ArrayRef]) ->
Result<ArrayRef> {
fn general_array_shuffle<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
field: &FieldRef,
+ seed: Option<u64>,
) -> Result<ArrayRef> {
let values = array.values();
let original_data = values.to_data();
@@ -109,7 +170,13 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
let mut nulls = vec![];
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false,
capacity);
- let mut rng = rng();
+ let mut rng = if let Some(s) = seed {
+ StdRng::seed_from_u64(s)
+ } else {
+ // Use a random seed from the thread-local RNG
+ let seed = rng().random::<u64>();
+ StdRng::seed_from_u64(seed)
+ };
for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
// skip the null value
@@ -149,6 +216,7 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
fn fixed_size_array_shuffle(
array: &FixedSizeListArray,
field: &FieldRef,
+ seed: Option<u64>,
) -> Result<ArrayRef> {
let values = array.values();
let original_data = values.to_data();
@@ -157,7 +225,13 @@ fn fixed_size_array_shuffle(
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false,
capacity);
let value_length = array.value_length() as usize;
- let mut rng = rng();
+ let mut rng = if let Some(s) = seed {
+ StdRng::seed_from_u64(s)
+ } else {
+ // Use a random seed from the thread-local RNG
+ let seed = rng().random::<u64>();
+ StdRng::seed_from_u64(seed)
+ };
for row_index in 0..array.len() {
// skip the null value
diff --git a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
index 7614caef66..35aad58144 100644
--- a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
+++ b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
@@ -16,27 +16,16 @@
# under the License.
# Test shuffle function with simple arrays
-query B
-SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5];
-----
-true
-
-query B
-SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL];
+query ?
+SELECT shuffle([1, 2, 3, 4, 5, NULL], 1);
----
-true
+[1, 4, NULL, 2, 5, 3]
# Test shuffle function with string arrays
-
-query B
-SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c',
'd', 'e', 'f'];
-----
-true
-
-query B
-SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e',
'f'];;
+query ?
+SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f'], 1);
----
-true
+[a, d, f, b, e, c]
# Test shuffle function with empty array
query ?
@@ -57,15 +46,10 @@ SELECT shuffle(NULL);
NULL
# Test shuffle function with fixed size list arrays
-query B
-SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6,
Int64)'))) = [NULL, 1, 2, 3, 4, 5];
-----
-true
-
-query B
-SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'))
!= [1, 2, NULL, 3, 4, 5];
+query ?
+SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'),
1);
----
-true
+[1, 3, 5, 2, 4, NULL]
# Test shuffle on table data with different list types
statement ok
@@ -78,10 +62,10 @@ CREATE TABLE test_shuffle_list_types AS VALUES
# Test shuffle with large list from table
query ?
-SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types;
+SELECT shuffle(column1, 1) FROM test_shuffle_list_types;
----
-[1, 2, 3, 4]
-[5, 6, 7, 8, 9]
+[1, 4, 3, 2]
+[8, 9, 6, 5, 7]
[10]
NULL
[]
@@ -96,11 +80,11 @@ CREATE TABLE test_shuffle_fixed_size AS VALUES
# Test shuffle with fixed size list from table
query ?
-SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size;
+SELECT shuffle(column1, 1) FROM test_shuffle_fixed_size;
----
[1, 2, 3]
-[4, 5, 6]
-[NULL, 8, 9]
+[4, 6, 5]
+[9, NULL, 8]
NULL
# Clean up
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]