This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 92727b52fd fix: shuffle seed (#18518)
92727b52fd is described below

commit 92727b52fd0f9d3896be0bb8b09ab670c7d31d4d
Author: Chen Chongchen <[email protected]>
AuthorDate: Sat Nov 8 08:17:33 2025 +0800

    fix: shuffle seed (#18518)
    
    ## Which issue does this PR close?
    
    - Closes #18476.
    
    ## Rationale for this change
    
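    The `shuffle` sqllogictests are flaky. Assertions such as
    `shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL]` fail whenever
    the random permutation happens to leave the array in its original order.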
    
    ## What changes are included in this PR?
    
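    Add an optional seed argument to `shuffle`, and call it with a fixed seed
    in the sqllogictests so that their output is deterministic and the tests
    no longer fail intermittently.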
    
    ## Are these changes tested?
    
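    Yes. The sqllogictests in `spark/array/shuffle.slt` now pass a fixed seed
    to `shuffle` and assert the exact shuffled output.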
    
    ## Are there any user-facing changes?
    
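    No breaking changes. `shuffle` now also accepts an optional second
    argument, an Int64 seed, which makes its output reproducible.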
---
 datafusion/spark/src/function/array/shuffle.rs     | 102 ++++++++++++++++++---
 .../test_files/spark/array/shuffle.slt             |  46 +++-------
 2 files changed, 103 insertions(+), 45 deletions(-)

diff --git a/datafusion/spark/src/function/array/shuffle.rs 
b/datafusion/spark/src/function/array/shuffle.rs
index abeafd3a93..9f345b53b8 100644
--- a/datafusion/spark/src/function/array/shuffle.rs
+++ b/datafusion/spark/src/function/array/shuffle.rs
@@ -15,21 +15,25 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::function::functions_nested_utils::make_scalar_function;
 use arrow::array::{
     Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, 
MutableArrayData,
     OffsetSizeTrait,
 };
 use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
-use arrow::datatypes::{DataType, FieldRef};
+use arrow::datatypes::FieldRef;
 use datafusion_common::cast::{
     as_fixed_size_list_array, as_large_list_array, as_list_array,
 };
-use datafusion_common::{exec_err, utils::take_function_args, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_common::{exec_err, utils::take_function_args, Result, 
ScalarValue};
+use datafusion_expr::{
+    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, 
ScalarUDFImpl,
+    Signature, TypeSignature, Volatility,
+};
 use rand::rng;
-use rand::seq::SliceRandom;
+use rand::rngs::StdRng;
+use rand::{seq::SliceRandom, Rng, SeedableRng};
 use std::any::Any;
 use std::sync::Arc;
 
@@ -47,7 +51,25 @@ impl Default for SparkShuffle {
 impl SparkShuffle {
     pub fn new() -> Self {
         Self {
-            signature: Signature::arrays(1, None, Volatility::Volatile),
+            signature: Signature {
+                type_signature: TypeSignature::OneOf(vec![
+                    // Only array argument
+                    
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+                        arguments: vec![ArrayFunctionArgument::Array],
+                        array_coercion: None,
+                    }),
+                    // Array + Index (seed) argument
+                    
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+                        arguments: vec![
+                            ArrayFunctionArgument::Array,
+                            ArrayFunctionArgument::Index,
+                        ],
+                        array_coercion: None,
+                    }),
+                ]),
+                volatility: Volatility::Volatile,
+                parameter_names: None,
+            },
         }
     }
 }
@@ -73,25 +95,63 @@ impl ScalarUDFImpl for SparkShuffle {
         &self,
         args: datafusion_expr::ScalarFunctionArgs,
     ) -> Result<ColumnarValue> {
-        make_scalar_function(array_shuffle_inner)(&args.args)
+        if args.args.is_empty() {
+            return exec_err!("shuffle expects at least 1 argument");
+        }
+        if args.args.len() > 2 {
+            return exec_err!("shuffle expects at most 2 arguments");
+        }
+
+        // Extract seed from second argument if present
+        let seed = if args.args.len() == 2 {
+            extract_seed(&args.args[1])?
+        } else {
+            None
+        };
+
+        // Convert arguments to arrays
+        let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
+        array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
+    }
+}
+
+/// Extract seed value from ColumnarValue
+fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
+    match seed_arg {
+        ColumnarValue::Scalar(scalar) => {
+            let seed = match scalar {
+                ScalarValue::Int64(Some(v)) => Some(*v as u64),
+                ScalarValue::Null => None,
+                _ => {
+                    return exec_err!(
+                        "shuffle seed must be Int64 type, got '{}'",
+                        scalar.data_type()
+                    );
+                }
+            };
+            Ok(seed)
+        }
+        ColumnarValue::Array(_) => {
+            exec_err!("shuffle seed must be a scalar value, not an array")
+        }
     }
 }
 
-/// array_shuffle SQL function
-pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
+/// array_shuffle SQL function with optional seed
+fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> 
Result<ArrayRef> {
     let [input_array] = take_function_args("shuffle", arg)?;
     match &input_array.data_type() {
         List(field) => {
             let array = as_list_array(input_array)?;
-            general_array_shuffle::<i32>(array, field)
+            general_array_shuffle::<i32>(array, field, seed)
         }
         LargeList(field) => {
             let array = as_large_list_array(input_array)?;
-            general_array_shuffle::<i64>(array, field)
+            general_array_shuffle::<i64>(array, field, seed)
         }
         FixedSizeList(field, _) => {
             let array = as_fixed_size_list_array(input_array)?;
-            fixed_size_array_shuffle(array, field)
+            fixed_size_array_shuffle(array, field, seed)
         }
         Null => Ok(Arc::clone(input_array)),
         array_type => exec_err!("shuffle does not support type 
'{array_type}'."),
@@ -101,6 +161,7 @@ pub fn array_shuffle_inner(arg: &[ArrayRef]) -> 
Result<ArrayRef> {
 fn general_array_shuffle<O: OffsetSizeTrait>(
     array: &GenericListArray<O>,
     field: &FieldRef,
+    seed: Option<u64>,
 ) -> Result<ArrayRef> {
     let values = array.values();
     let original_data = values.to_data();
@@ -109,7 +170,13 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
     let mut nulls = vec![];
     let mut mutable =
         MutableArrayData::with_capacities(vec![&original_data], false, 
capacity);
-    let mut rng = rng();
+    let mut rng = if let Some(s) = seed {
+        StdRng::seed_from_u64(s)
+    } else {
+        // Use a random seed from the thread-local RNG
+        let seed = rng().random::<u64>();
+        StdRng::seed_from_u64(seed)
+    };
 
     for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
         // skip the null value
@@ -149,6 +216,7 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
 fn fixed_size_array_shuffle(
     array: &FixedSizeListArray,
     field: &FieldRef,
+    seed: Option<u64>,
 ) -> Result<ArrayRef> {
     let values = array.values();
     let original_data = values.to_data();
@@ -157,7 +225,13 @@ fn fixed_size_array_shuffle(
     let mut mutable =
         MutableArrayData::with_capacities(vec![&original_data], false, 
capacity);
     let value_length = array.value_length() as usize;
-    let mut rng = rng();
+    let mut rng = if let Some(s) = seed {
+        StdRng::seed_from_u64(s)
+    } else {
+        // Use a random seed from the thread-local RNG
+        let seed = rng().random::<u64>();
+        StdRng::seed_from_u64(seed)
+    };
 
     for row_index in 0..array.len() {
         // skip the null value
diff --git a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt 
b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
index 7614caef66..35aad58144 100644
--- a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
+++ b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt
@@ -16,27 +16,16 @@
 # under the License.
 
 # Test shuffle function with simple arrays
-query B
-SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5];
-----
-true
-
-query B
-SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL];
+query ?
+SELECT shuffle([1, 2, 3, 4, 5, NULL], 1);
 ----
-true
+[1, 4, NULL, 2, 5, 3]
 
 # Test shuffle function with string arrays
-
-query B
-SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c', 
'd', 'e', 'f'];
-----
-true
-
-query B
-SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e', 
'f'];;
+query ?
+SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f'], 1);
 ----
-true
+[a, d, f, b, e, c]
 
 # Test shuffle function with empty array
 query ?
@@ -57,15 +46,10 @@ SELECT shuffle(NULL);
 NULL
 
 # Test shuffle function with fixed size list arrays
-query B
-SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, 
Int64)'))) = [NULL, 1, 2, 3, 4, 5];
-----
-true
-
-query B
-SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)')) 
!= [1, 2, NULL, 3, 4, 5];
+query ?
+SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'), 
1);
 ----
-true
+[1, 3, 5, 2, 4, NULL]
 
 # Test shuffle on table data with different list types
 statement ok
@@ -78,10 +62,10 @@ CREATE TABLE test_shuffle_list_types AS VALUES
 
 # Test shuffle with large list from table
 query ?
-SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types;
+SELECT shuffle(column1, 1) FROM test_shuffle_list_types;
 ----
-[1, 2, 3, 4]
-[5, 6, 7, 8, 9]
+[1, 4, 3, 2]
+[8, 9, 6, 5, 7]
 [10]
 NULL
 []
@@ -96,11 +80,11 @@ CREATE TABLE test_shuffle_fixed_size AS VALUES
 
 # Test shuffle with fixed size list from table
 query ?
-SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size;
+SELECT shuffle(column1, 1) FROM test_shuffle_fixed_size;
 ----
 [1, 2, 3]
-[4, 5, 6]
-[NULL, 8, 9]
+[4, 6, 5]
+[9, NULL, 8]
 NULL
 
 # Clean up


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to