This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 8485558 feat: Support murmur3_hash and sha2 family hash functions (#226)
8485558 is described below
commit 848555818713ca9dff6fedfdf6e407969f1eeca2
Author: advancedxy <[email protected]>
AuthorDate: Fri Apr 26 13:02:47 2024 +0800
feat: Support murmur3_hash and sha2 family hash functions (#226)
* feat: Support murmur3_hash and sha2 family hash functions
* address comments
* apply scalafix
* ensure crypto_expressions feature is enabled
---
core/Cargo.toml | 4 +-
.../datafusion/expressions/scalar_funcs.rs | 206 ++++++++++++++-------
.../org/apache/comet/serde/QueryPlanSerde.scala | 43 ++++-
.../org/apache/comet/CometExpressionSuite.scala | 26 ++-
4 files changed, 205 insertions(+), 74 deletions(-)
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 5d16049..b09b0ea 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -67,8 +67,8 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] }
chrono-tz = { version = "0.8" }
paste = "1.0.14"
datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
-datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions"] }
-datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
+datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["crypto_expressions"]}
datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] }
unicode-segmentation = "^1.10.1"
once_cell = "1.18.0"
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index e6f8de1..2895937 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,8 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use std::{any::Any, cmp::min, fmt::Debug, str::FromStr, sync::Arc};
+use std::{
+ any::Any,
+ cmp::min,
+ fmt::{Debug, Write},
+ str::FromStr,
+ sync::Arc,
+};
+use crate::execution::datafusion::spark_hash::create_hashes;
use arrow::{
array::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
@@ -24,7 +31,7 @@ use arrow::{
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
-use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
+use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
use arrow_schema::DataType;
use datafusion::{
execution::FunctionRegistry,
@@ -35,8 +42,8 @@ use datafusion::{
physical_plan::ColumnarValue,
};
use datafusion_common::{
- cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
- Result as DataFusionResult, ScalarValue,
+ cast::{as_binary_array, as_generic_string_array},
+    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
};
use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
use num::{
@@ -45,89 +52,75 @@ use num::{
};
use unicode_segmentation::UnicodeSegmentation;
+macro_rules! make_comet_scalar_udf {
+ ($name:expr, $func:ident, $data_type:ident) => {{
+ let scalar_func = CometScalarFunction::new(
+ $name.to_string(),
+ Signature::variadic_any(Volatility::Immutable),
+ $data_type.clone(),
+ Arc::new(move |args| $func(args, &$data_type)),
+ );
+ Ok(ScalarFunctionDefinition::UDF(Arc::new(
+ ScalarUDF::new_from_impl(scalar_func),
+ )))
+ }};
+ ($name:expr, $func:expr, without $data_type:ident) => {{
+ let scalar_func = CometScalarFunction::new(
+ $name.to_string(),
+ Signature::variadic_any(Volatility::Immutable),
+ $data_type,
+ $func,
+ );
+ Ok(ScalarFunctionDefinition::UDF(Arc::new(
+ ScalarUDF::new_from_impl(scalar_func),
+ )))
+ }};
+}
+
/// Create a physical scalar function.
pub fn create_comet_physical_fun(
fun_name: &str,
data_type: DataType,
registry: &dyn FunctionRegistry,
) -> Result<ScalarFunctionDefinition, DataFusionError> {
+ let sha2_functions = ["sha224", "sha256", "sha384", "sha512"];
match fun_name {
"ceil" => {
- let scalar_func = CometScalarFunction::new(
- "ceil".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(move |args| spark_ceil(args, &data_type)),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ make_comet_scalar_udf!("ceil", spark_ceil, data_type)
}
"floor" => {
- let scalar_func = CometScalarFunction::new(
- "floor".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(move |args| spark_floor(args, &data_type)),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ make_comet_scalar_udf!("floor", spark_floor, data_type)
}
"rpad" => {
- let scalar_func = CometScalarFunction::new(
- "rpad".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(spark_rpad),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ let func = Arc::new(spark_rpad);
+ make_comet_scalar_udf!("rpad", func, without data_type)
}
"round" => {
- let scalar_func = CometScalarFunction::new(
- "round".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(move |args| spark_round(args, &data_type)),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ make_comet_scalar_udf!("round", spark_round, data_type)
}
"unscaled_value" => {
- let scalar_func = CometScalarFunction::new(
- "unscaled_value".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(spark_unscaled_value),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ let func = Arc::new(spark_unscaled_value);
+ make_comet_scalar_udf!("unscaled_value", func, without data_type)
}
"make_decimal" => {
- let scalar_func = CometScalarFunction::new(
- "make_decimal".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(move |args| spark_make_decimal(args, &data_type)),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
}
"decimal_div" => {
- let scalar_func = CometScalarFunction::new(
- "decimal_div".to_string(),
- Signature::variadic_any(Volatility::Immutable),
- data_type.clone(),
- Arc::new(move |args| spark_decimal_div(args, &data_type)),
- );
- Ok(ScalarFunctionDefinition::UDF(Arc::new(
- ScalarUDF::new_from_impl(scalar_func),
- )))
+ make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+ }
+ "murmur3_hash" => {
+ let func = Arc::new(spark_murmur3_hash);
+ make_comet_scalar_udf!("murmur3_hash", func, without data_type)
+ }
+ sha if sha2_functions.contains(&sha) => {
+            // Spark requires hex string as the result of sha2 functions, we have to wrap the
+ // result of digest functions as hex string
+ let func = registry.udf(sha)?;
+ let wrapped_func = Arc::new(move |args: &[ColumnarValue]| {
+ wrap_digest_result_as_hex_string(args, func.fun())
+ });
+ let spark_func_name = "spark".to_owned() + sha;
+            make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
}
_ => {
let fun = BuiltinScalarFunction::from_str(fun_name);
@@ -629,3 +622,82 @@ fn spark_decimal_div(
let result = result.with_data_type(DataType::Decimal128(p3, s3));
Ok(ColumnarValue::Array(Arc::new(result)))
}
+
+fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+ let length = args.len();
+ let seed = &args[length - 1];
+ match seed {
+ ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
+ // iterate over the arguments to find out the length of the array
+ let num_rows = args[0..args.len() - 1]
+ .iter()
+ .find_map(|arg| match arg {
+ ColumnarValue::Array(array) => Some(array.len()),
+ ColumnarValue::Scalar(_) => None,
+ })
+ .unwrap_or(1);
+ let mut hashes: Vec<u32> = vec![0_u32; num_rows];
+ hashes.fill(*seed as u32);
+ let arrays = args[0..args.len() - 1]
+ .iter()
+ .map(|arg| match arg {
+ ColumnarValue::Array(array) => array.clone(),
+ ColumnarValue::Scalar(scalar) => {
+ scalar.clone().to_array_of_size(num_rows).unwrap()
+ }
+ })
+ .collect::<Vec<ArrayRef>>();
+ create_hashes(&arrays, &mut hashes)?;
+ if num_rows == 1 {
+ Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
+ hashes[0] as i32,
+ ))))
+ } else {
+                let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
+ Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
+ }
+ }
+ _ => {
+ internal_err!(
+                "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
+ seed
+ )
+ }
+ }
+}
+
+#[inline]
+fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
+ let mut s = String::with_capacity(data.as_ref().len() * 2);
+ for b in data.as_ref() {
+ // Writing to a string never errors, so we can unwrap here.
+ write!(&mut s, "{b:02x}").unwrap();
+ }
+ s
+}
+
+fn wrap_digest_result_as_hex_string(
+ args: &[ColumnarValue],
+ digest: ScalarFunctionImplementation,
+) -> Result<ColumnarValue, DataFusionError> {
+ let value = digest(args)?;
+ match value {
+ ColumnarValue::Array(array) => {
+ let binary_array = as_binary_array(&array)?;
+ let string_array: StringArray = binary_array
+ .iter()
+ .map(|opt| opt.map(hex_encode::<_>))
+ .collect();
+ Ok(ColumnarValue::Array(Arc::new(string_array)))
+ }
+        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
+ ScalarValue::Utf8(opt.map(hex_encode::<_>)),
+ )),
+ _ => {
+ exec_err!(
+ "digest function should return binary value, but got: {:?}",
+ value.data_type()
+ )
+ }
+ }
+}
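
For context: DataFusion's digest-based sha2 functions return raw binary, while Spark's sha2 returns a lowercase hex string, which is why the UDF result is wrapped above. A minimal standalone sketch of that hex conversion (illustration only, not part of this patch):

    use std::fmt::Write;

    // Hex-encodes a byte slice the same way the patch's hex_encode helper does.
    fn hex_encode(data: &[u8]) -> String {
        let mut s = String::with_capacity(data.len() * 2);
        for b in data {
            // Writing into a String never fails, so unwrap is safe.
            write!(&mut s, "{b:02x}").unwrap();
        }
        s
    }

    fn main() {
        // A 4-byte digest-like value becomes an 8-character lowercase hex string.
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
    }
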
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index d08fb6b..57b15e2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1613,10 +1613,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
optExprWithInfo(optExpr, expr, castExpr)
case Md5(child) =>
- val castExpr = Cast(child, StringType)
- val childExpr = exprToProtoInternal(castExpr, inputs)
+ val childExpr = exprToProtoInternal(child, inputs)
val optExpr = scalarExprToProto("md5", childExpr)
- optExprWithInfo(optExpr, expr, castExpr)
+ optExprWithInfo(optExpr, expr, child)
case OctetLength(child) =>
val castExpr = Cast(child, StringType)
@@ -1954,6 +1953,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
None
}
+ case Murmur3Hash(children, seed) =>
+        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+ if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
+ return None
+ }
+ val exprs = children.map(exprToProtoInternal(_, inputs))
+ val seedBuilder = ExprOuterClass.Literal
+ .newBuilder()
+ .setDatatype(serializeDataType(IntegerType).get)
+ .setIntVal(seed)
+        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+ // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
+
+ case Sha2(left, numBits) =>
+ if (!numBits.foldable) {
+ withInfo(expr, "non literal numBits is not supported")
+ return None
+ }
+        // it's possible for spark to dynamically compute the number of bits from input
+ // expression, however DataFusion does not support that yet.
+ val childExpr = exprToProtoInternal(left, inputs)
+ val bits = numBits.eval().asInstanceOf[Int]
+ val algorithm = bits match {
+ case 224 => "sha224"
+ case 256 | 0 => "sha256"
+ case 384 => "sha384"
+ case 512 => "sha512"
+ case _ =>
+ null
+ }
+ if (algorithm == null) {
+ exprToProtoInternal(Literal(null, StringType), inputs)
+ } else {
+ scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
+ }
+
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
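
For reference, the Murmur3Hash case above serializes the hash seed as a trailing Int32 literal, and the native spark_murmur3_hash splits that last argument off before hashing the remaining columns. A rough sketch of that calling convention (hypothetical values, not part of the patch):

    // The seed travels as the last argument; everything before it is a column to hash.
    fn split_hash_args<T>(args: &[T]) -> (&[T], &T) {
        let (cols, seed) = args.split_at(args.len() - 1);
        (cols, &seed[0])
    }

    fn main() {
        // Spark's hash(col, a, b) uses the default Murmur3 seed 42, so the serialized
        // argument list arrives as [col, a, b, 42].
        let args = ["col", "a", "b", "42"];
        let (cols, seed) = split_hash_args(&args);
        assert_eq!(cols, &["col", "a", "b"][..]);
        assert_eq!(*seed, "42");
    }
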
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 376baa3..3683c8d 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -981,8 +981,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
- // TODO: enable this when we add md5 function to Comet
- ignore("md5") {
+ test("md5") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
@@ -1405,4 +1404,27 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("hash functions") {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ val table = "test"
+ withTable(table) {
+ sql(s"create table $table(col string, a int, b float) using parquet")
+ sql(s"""
+ |insert into $table values
+          |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+          |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
+ |""".stripMargin)
+ checkSparkAnswerAndOperator("""
+ |select
+ |md5(col), md5(cast(a as string)), md5(cast(b as string)),
+          |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+          |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
+ |from test
+ |""".stripMargin)
+ }
+ }
+ }
+ }
+
}
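
A note on the sha2 cases exercised by the test above: Spark treats numBits 0 as SHA-256 (so sha2(col, 0) and sha2(col, 256) agree), and an unsupported length such as 128 evaluates to NULL, which the Scala serde's fallback to a null string literal preserves. A standalone sketch of that bit-length mapping (illustration only):

    // Maps Spark's sha2(expr, numBits) bit length to a digest algorithm name.
    // Unsupported lengths return None, mirroring Spark's NULL result.
    fn sha2_algorithm(num_bits: i32) -> Option<&'static str> {
        match num_bits {
            224 => Some("sha224"),
            0 | 256 => Some("sha256"),
            384 => Some("sha384"),
            512 => Some("sha512"),
            _ => None,
        }
    }

    fn main() {
        assert_eq!(sha2_algorithm(0), Some("sha256")); // sha2(col, 0) matches sha2(col, 256)
        assert_eq!(sha2_algorithm(128), None);         // sha2(col, 128) evaluates to NULL
    }
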
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]