Jefffrey commented on code in PR #19475:
URL: https://github.com/apache/datafusion/pull/19475#discussion_r2645099878


##########
datafusion/spark/src/function/hash/sha2.rs:
##########
@@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types[1].is_null() {
-            return Ok(DataType::Null);
-        }
-        Ok(match arg_types[0] {
-            DataType::Utf8View
-            | DataType::LargeUtf8
-            | DataType::Utf8
-            | DataType::Binary
-            | DataType::BinaryView
-            | DataType::LargeBinary => DataType::Utf8,
-            DataType::Null => DataType::Null,
-            _ => {
-                return exec_err!(
-                    "{} function can only accept strings or binary arrays.",
-                    self.name()
-                );
-            }
-        })
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
-            internal_datafusion_err!("Expected 2 arguments for function sha2")
-        })?;
-
-        sha2(args)
+        make_scalar_function(sha2_impl, vec![])(&args.args)
     }
+}
 
-    fn aliases(&self) -> &[String] {
-        &self.aliases
-    }
+fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [values, bit_lengths] = take_function_args("sha2", args)?;
 
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        if arg_types.len() != 2 {
-            return Err(invalid_arg_count_exec_err(
-                self.name(),
-                (2, 2),
-                arg_types.len(),
-            ));
+    let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
+    let output = match values.data_type() {
+        DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), 
bit_lengths),
+        DataType::LargeBinary => {
+            sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
         }
-        let expr_type = match &arg_types[0] {
-            DataType::Utf8View
-            | DataType::LargeUtf8
-            | DataType::Utf8
-            | DataType::Binary
-            | DataType::BinaryView
-            | DataType::LargeBinary
-            | DataType::Null => Ok(arg_types[0].clone()),
-            _ => Err(unsupported_data_type_exec_err(
-                self.name(),
-                "String, Binary",
-                &arg_types[0],
-            )),
-        }?;
-        let bit_length_type = if arg_types[1].is_numeric() {
-            Ok(DataType::Int32)
-        } else if arg_types[1].is_null() {
-            Ok(DataType::Null)
-        } else {
-            Err(unsupported_data_type_exec_err(
-                self.name(),
-                "Numeric Type",
-                &arg_types[1],
-            ))
-        }?;
-
-        Ok(vec![expr_type, bit_length_type])
-    }
+        DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), 
bit_lengths),
+        dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
+    };
+    Ok(output)
 }
 
-pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
-    match args {
-        [
-            ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
-            ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
-        ] => compute_sha2(
-            bit_length_arg,
-            &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
-        ),
-        [
-            ColumnarValue::Array(expr_arg),
-            ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
-        ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]),
-        [
-            ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
-            ColumnarValue::Array(bit_length_arg),
-        ] => {
-            let arr: StringArray = bit_length_arg
-                .as_primitive::<Int32Type>()
-                .iter()
-                .map(|bit_length| {
-                    match sha2([
-                        
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
-                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
-                    ])
-                    .unwrap()
-                    {
-                        ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
-                        ColumnarValue::Array(arr) => arr
-                            .as_string::<i32>()
-                            .iter()
-                            .map(|str| str.unwrap().to_string())
-                            .next(), // first element
-                        _ => unreachable!(),
-                    }
-                })
-                .collect();
-            Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
-        }
-        [
-            ColumnarValue::Array(expr_arg),
-            ColumnarValue::Array(bit_length_arg),
-        ] => {
-            let expr_iter = expr_arg.as_string::<i32>().iter();
-            let bit_length_iter = 
bit_length_arg.as_primitive::<Int32Type>().iter();
-            let arr: StringArray = expr_iter
-                .zip(bit_length_iter)
-                .map(|(expr, bit_length)| {
-                    match sha2([
-                        ColumnarValue::Scalar(ScalarValue::Utf8(Some(
-                            expr.unwrap().to_string(),
-                        ))),
-                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
-                    ])
-                    .unwrap()
-                    {
-                        ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
-                        ColumnarValue::Array(arr) => arr
-                            .as_string::<i32>()
-                            .iter()
-                            .map(|str| str.unwrap().to_string())
-                            .next(), // first element
-                        _ => unreachable!(),
-                    }
-                })
-                .collect();
-            Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
-        }
-        _ => exec_err!("Unsupported argument types for sha2 function"),
-    }
+fn sha2_binary_impl<'a, BinaryArrType>(
+    values: &BinaryArrType,
+    bit_lengths: &Int32Array,
+) -> ArrayRef
+where
+    BinaryArrType: BinaryArrayType<'a>,
+{
+    let array = values
+        .iter()
+        .zip(bit_lengths.iter())
+        .map(|(value, bit_length)| match (value, bit_length) {
+            (Some(value), Some(224)) => {
+                let mut digest = sha2::Sha224::default();
+                digest.update(value);
+                Some(hex_encode(digest.finalize()))
+            }
+            (Some(value), Some(0 | 256)) => {
+                let mut digest = sha2::Sha256::default();
+                digest.update(value);
+                Some(hex_encode(digest.finalize()))
+            }
+            (Some(value), Some(384)) => {
+                let mut digest = sha2::Sha384::default();
+                digest.update(value);
+                Some(hex_encode(digest.finalize()))
+            }

Review Comment:
   We now use the `sha2` crate directly for hashing. Previously we called 
functions exposed by `datafusion-functions`, but that was unnecessary 
indirection.



##########
datafusion/spark/src/function/hash/sha2.rs:
##########
@@ -46,8 +50,21 @@ impl Default for SparkSha2 {
 impl SparkSha2 {
     pub fn new() -> Self {
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
-            aliases: vec![],
+            signature: Signature::coercible(
+                vec![
+                    Coercion::new_implicit(
+                        TypeSignatureClass::Native(logical_binary()),
+                        vec![TypeSignatureClass::Native(logical_string())],
+                        NativeType::Binary,
+                    ),
+                    Coercion::new_implicit(
+                        TypeSignatureClass::Native(logical_int32()),
+                        vec![TypeSignatureClass::Integer],
+                        NativeType::Int32,

Review Comment:
   Moving away from the `user_defined` signature. We also cast strings to binary 
to simplify the implementation, since we only need the raw bytes either way.



##########
Cargo.toml:
##########
@@ -181,6 +181,7 @@ recursive = "0.1.1"
 regex = "1.12"
 rstest = "0.26.1"
 serde_json = "1"
+sha2 = "^0.10.9"

Review Comment:
   Because `datafusion-spark` now uses `sha2` directly, we move it into the 
shared workspace dependencies so both crates use the same version.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to