This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 8485558  feat: Support murmur3_hash and sha2 family hash functions 
(#226)
8485558 is described below

commit 848555818713ca9dff6fedfdf6e407969f1eeca2
Author: advancedxy <[email protected]>
AuthorDate: Fri Apr 26 13:02:47 2024 +0800

    feat: Support murmur3_hash and sha2 family hash functions (#226)
    
    * feat: Support murmur3_hash and sha2 family hash functions
    
    * address comments
    
    * apply scalafix
    
    * ensure crypto_expressions feature is enabled
---
 core/Cargo.toml                                    |   4 +-
 .../datafusion/expressions/scalar_funcs.rs         | 206 ++++++++++++++-------
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  43 ++++-
 .../org/apache/comet/CometExpressionSuite.scala    |  26 ++-
 4 files changed, 205 insertions(+), 74 deletions(-)

diff --git a/core/Cargo.toml b/core/Cargo.toml
index 5d16049..b09b0ea 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -67,8 +67,8 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] }
 chrono-tz = { version = "0.8" }
 paste = "1.0.14"
 datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
-datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions"] }
-datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
+datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["crypto_expressions"]}
 datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] }
 unicode-segmentation = "^1.10.1"
 once_cell = "1.18.0"
 unicode-segmentation = "^1.10.1"
 once_cell = "1.18.0"
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs 
b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index e6f8de1..2895937 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,8 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, cmp::min, fmt::Debug, str::FromStr, sync::Arc};
+use std::{
+    any::Any,
+    cmp::min,
+    fmt::{Debug, Write},
+    str::FromStr,
+    sync::Arc,
+};
 
+use crate::execution::datafusion::spark_hash::create_hashes;
 use arrow::{
     array::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, 
GenericStringArray,
@@ -24,7 +31,7 @@ use arrow::{
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
+use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
 use arrow_schema::DataType;
 use datafusion::{
     execution::FunctionRegistry,
@@ -35,8 +42,8 @@ use datafusion::{
     physical_plan::ColumnarValue,
 };
 use datafusion_common::{
-    cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
-    Result as DataFusionResult, ScalarValue,
+    cast::{as_binary_array, as_generic_string_array},
+    exec_err, internal_err, DataFusionError, Result as DataFusionResult, 
ScalarValue,
 };
 use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
 use num::{
@@ -45,89 +52,84 @@ use num::{
 };
 use unicode_segmentation::UnicodeSegmentation;
 
+/// Builds a `ScalarFunctionDefinition::UDF` around a `CometScalarFunction` whose
+/// signature is `Signature::variadic_any(Volatility::Immutable)`, wrapped in `Ok(..)`.
+///
+/// * `make_comet_scalar_udf!(name, func, data_type)` — `func` takes the arguments plus a
+///   `&DataType`; `data_type` (an identifier) is cloned into the UDF and then moved into
+///   the closure that forwards it to `func` on every call.
+/// * `make_comet_scalar_udf!(name, func, without data_type)` — `func` is already the
+///   complete implementation (e.g. `Arc::new(spark_rpad)`), and `data_type` is only
+///   handed to `CometScalarFunction::new`.
+macro_rules! make_comet_scalar_udf {
+    ($name:expr, $func:ident, $data_type:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type.clone(),
+            Arc::new(move |args| $func(args, &$data_type)),
+        );
+        Ok(ScalarFunctionDefinition::UDF(Arc::new(
+            ScalarUDF::new_from_impl(scalar_func),
+        )))
+    }};
+    ($name:expr, $func:expr, without $data_type:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type,
+            $func,
+        );
+        Ok(ScalarFunctionDefinition::UDF(Arc::new(
+            ScalarUDF::new_from_impl(scalar_func),
+        )))
+    }};
+}
+
 /// Create a physical scalar function.
 pub fn create_comet_physical_fun(
     fun_name: &str,
     data_type: DataType,
     registry: &dyn FunctionRegistry,
 ) -> Result<ScalarFunctionDefinition, DataFusionError> {
+    let sha2_functions = ["sha224", "sha256", "sha384", "sha512"];
     match fun_name {
         "ceil" => {
-            let scalar_func = CometScalarFunction::new(
-                "ceil".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(move |args| spark_ceil(args, &data_type)),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            make_comet_scalar_udf!("ceil", spark_ceil, data_type)
         }
         "floor" => {
-            let scalar_func = CometScalarFunction::new(
-                "floor".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(move |args| spark_floor(args, &data_type)),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
         "rpad" => {
-            let scalar_func = CometScalarFunction::new(
-                "rpad".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(spark_rpad),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            let func = Arc::new(spark_rpad);
+            make_comet_scalar_udf!("rpad", func, without data_type)
         }
         "round" => {
-            let scalar_func = CometScalarFunction::new(
-                "round".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(move |args| spark_round(args, &data_type)),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            make_comet_scalar_udf!("round", spark_round, data_type)
         }
         "unscaled_value" => {
-            let scalar_func = CometScalarFunction::new(
-                "unscaled_value".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(spark_unscaled_value),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            let func = Arc::new(spark_unscaled_value);
+            make_comet_scalar_udf!("unscaled_value", func, without data_type)
         }
         "make_decimal" => {
-            let scalar_func = CometScalarFunction::new(
-                "make_decimal".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(move |args| spark_make_decimal(args, &data_type)),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            make_comet_scalar_udf!("make_decimal", spark_make_decimal, 
data_type)
         }
         "decimal_div" => {
-            let scalar_func = CometScalarFunction::new(
-                "decimal_div".to_string(),
-                Signature::variadic_any(Volatility::Immutable),
-                data_type.clone(),
-                Arc::new(move |args| spark_decimal_div(args, &data_type)),
-            );
-            Ok(ScalarFunctionDefinition::UDF(Arc::new(
-                ScalarUDF::new_from_impl(scalar_func),
-            )))
+            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+        }
+        "murmur3_hash" => {
+            let func = Arc::new(spark_murmur3_hash);
+            make_comet_scalar_udf!("murmur3_hash", func, without data_type)
+        }
+        sha if sha2_functions.contains(&sha) => {
+            // Spark requires hex string as the result of sha2 functions, we 
have to wrap the
+            // result of digest functions as hex string
+            let func = registry.udf(sha)?;
+            let wrapped_func = Arc::new(move |args: &[ColumnarValue]| {
+                wrap_digest_result_as_hex_string(args, func.fun())
+            });
+            let spark_func_name = "spark".to_owned() + sha;
+            make_comet_scalar_udf!(spark_func_name, wrapped_func, without 
data_type)
         }
         _ => {
             let fun = BuiltinScalarFunction::from_str(fun_name);
@@ -629,3 +622,107 @@ fn spark_decimal_div(
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
+
+/// Spark-compatible `murmur3_hash`, the native side of Spark's `hash` expression.
+///
+/// `args` holds the values to hash followed by the seed; the seed must be a non-null
+/// `Int32` scalar. Scalar values are broadcast to the length of the first array
+/// argument (a single row when every value is a scalar), then `create_hashes` updates
+/// one running `u32` hash per row, each initialised to the seed.
+///
+/// Returns an `Int32` scalar when there is exactly one row, otherwise an `Int32Array`.
+///
+/// # Errors
+/// Returns an internal error when `args` is empty or the seed is not a non-null `Int32`
+/// scalar, and propagates errors from broadcasting scalars and from `create_hashes`.
+fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    // The serializer appends the seed as the last argument. Splitting instead of
+    // indexing `args[len - 1]` avoids an underflow panic on an empty argument list.
+    let Some((seed, values)) = args.split_last() else {
+        return internal_err!("Function murmur3_hash requires at least a seed argument.");
+    };
+    match seed {
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
+            // The first array argument determines the number of rows.
+            let num_rows = values
+                .iter()
+                .find_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    ColumnarValue::Scalar(_) => None,
+                })
+                .unwrap_or(1);
+            // Every row starts from the seed, reinterpreted bit-for-bit as u32.
+            let mut hashes: Vec<u32> = vec![*seed as u32; num_rows];
+            let arrays = values
+                .iter()
+                .map(|arg| match arg {
+                    ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+                    ColumnarValue::Scalar(scalar) => scalar.clone().to_array_of_size(num_rows),
+                })
+                .collect::<DataFusionResult<Vec<ArrayRef>>>()?;
+            create_hashes(&arrays, &mut hashes)?;
+            if num_rows == 1 {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
+                    hashes[0] as i32,
+                ))))
+            } else {
+                // Reinterpret the u32 hash bits as the signed Int32 result type.
+                let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
+                Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
+            }
+        }
+        _ => {
+            internal_err!(
+                "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
+                seed
+            )
+        }
+    }
+}
+
+/// Encodes `data` as a lower-case hexadecimal string, two characters per byte.
+#[inline]
+fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
+    let bytes = data.as_ref();
+    let mut s = String::with_capacity(bytes.len() * 2);
+    for b in bytes {
+        // Writing to a string never errors, so we can unwrap here.
+        write!(&mut s, "{b:02x}").unwrap();
+    }
+    s
+}
+
+/// Runs a DataFusion digest function and hex-encodes its binary output.
+///
+/// Spark's sha2 functions return a hex string, while the wrapped `digest` yields binary:
+/// a binary array becomes a `StringArray` (nulls stay null) and a `Binary` scalar becomes
+/// a `Utf8` scalar.
+///
+/// # Errors
+/// Propagates errors from `digest` and from downcasting its array result, and returns an
+/// execution error for any other kind of result.
+fn wrap_digest_result_as_hex_string(
+    args: &[ColumnarValue],
+    digest: ScalarFunctionImplementation,
+) -> Result<ColumnarValue, DataFusionError> {
+    let value = digest(args)?;
+    match value {
+        ColumnarValue::Array(array) => {
+            let binary_array = as_binary_array(&array)?;
+            let string_array: StringArray = binary_array
+                .iter()
+                .map(|opt| opt.map(hex_encode::<_>))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(string_array)))
+        }
+        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
+            ScalarValue::Utf8(opt.map(hex_encode::<_>)),
+        )),
+        _ => {
+            exec_err!(
+                "digest function should return binary value, but got: {:?}",
+                value.data_type()
+            )
+        }
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index d08fb6b..57b15e2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1613,10 +1613,9 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
           optExprWithInfo(optExpr, expr, castExpr)
 
         case Md5(child) =>
-          val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
+          val childExpr = exprToProtoInternal(child, inputs)
           val optExpr = scalarExprToProto("md5", childExpr)
-          optExprWithInfo(optExpr, expr, castExpr)
+          optExprWithInfo(optExpr, expr, child)
 
         case OctetLength(child) =>
           val castExpr = Cast(child, StringType)
@@ -1954,6 +1953,46 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
             None
           }
 
+        case Murmur3Hash(children, seed) =>
+          val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+          if (firstUnSupportedInput.isDefined) {
+            withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
+            return None
+          }
+          val exprs = children.map(exprToProtoInternal(_, inputs))
+          val seedBuilder = ExprOuterClass.Literal
+            .newBuilder()
+            .setDatatype(serializeDataType(IntegerType).get)
+            .setIntVal(seed)
+          val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+          // the seed is put at the end of the arguments
+          scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
+
+        case Sha2(left, numBits) =>
+          if (!numBits.foldable) {
+            withInfo(expr, "non literal numBits is not supported")
+            return None
+          }
+          // it's possible for spark to dynamically compute the number of bits from input
+          // expression, however DataFusion does not support that yet.
+          val childExpr = exprToProtoInternal(left, inputs)
+          // Match on the evaluated value instead of `asInstanceOf[Int]`: unboxing a null
+          // numBits yields 0 (SHA-256), whereas Spark's Sha2 returns null for a null numBits.
+          val algorithm = numBits.eval() match {
+            case 224 => "sha224"
+            case 256 | 0 => "sha256"
+            case 384 => "sha384"
+            case 512 => "sha512"
+            case _ =>
+              // unsupported bit lengths and a null numBits both yield a null result
+              null
+          }
+          if (algorithm == null) {
+            exprToProtoInternal(Literal(null, StringType), inputs)
+          } else {
+            scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
+          }
+
         case _ =>
           withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
           None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 376baa3..3683c8d 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -981,8 +981,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  // TODO: enable this when we add md5 function to Comet
-  ignore("md5") {
+  test("md5") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
         val table = "test"
@@ -1405,4 +1404,29 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  // Compares Comet with Spark for md5, hash and sha2 on Parquet data written with and
+  // without dictionary encoding; sha2(col, 128) exercises an unsupported bit length.
+  test("hash functions") {
+    Seq(true, false).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col string, a int, b float) using parquet")
+          sql(s"""
+             |insert into $table values
+             |('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+             |, ('Spark SQL  ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+             |""".stripMargin)
+          checkSparkAnswerAndOperator(s"""
+               |select
+               |md5(col), md5(cast(a as string)), md5(cast(b as string)),
+               |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+               |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
+               |from $table
+               |""".stripMargin)
+        }
+      }
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to