This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a26b3de91 feat: Add support for rpad (#1470)
a26b3de91 is described below
commit a26b3de9108f4ea68f8c6acab7314d7faf620c78
Author: Andy Grove <[email protected]>
AuthorDate: Wed Mar 5 18:15:55 2025 -0700
feat: Add support for rpad (#1470)
* use stable toolchain
* clippy
* fmt
* add support for lpad and rpad
* tests pass
* enable read-side padding in TPC stability suite
* format
* revert a change
* address feedback
* re-implement
* re-implement
* re-implement
* format
* address feedback
---
docs/source/user-guide/expressions.md | 3 ++-
native/spark-expr/src/comet_scalar_funcs.rs | 6 ++++-
.../src/static_invoke/char_varchar_utils/mod.rs | 2 +-
.../char_varchar_utils/read_side_padding.rs | 30 +++++++++++++++++++---
native/spark-expr/src/static_invoke/mod.rs | 2 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 24 +++++++++++++----
.../org/apache/comet/CometExpressionSuite.scala | 27 +++++++++++++++++++
7 files changed, 82 insertions(+), 12 deletions(-)
diff --git a/docs/source/user-guide/expressions.md
b/docs/source/user-guide/expressions.md
index 853d814c0..e43da50d7 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -68,7 +68,7 @@ The following Spark expressions are currently available. Any
known compatibility
## String Functions
| Expression | Notes
|
-| --------------- |
-----------------------------------------------------------------------------------------------------------
|
+|-----------------|
-----------------------------------------------------------------------------------------------------------
|
| Ascii |
|
| BitLength |
|
| Chr |
|
@@ -85,6 +85,7 @@ The following Spark expressions are currently available. Any
known compatibility
| Replace |
|
| Reverse |
|
| StartsWith |
|
+| StringRPad |
|
| StringSpace |
|
| StringTrim |
|
| StringTrimBoth |
|
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs
b/native/spark-expr/src/comet_scalar_funcs.rs
index 227b6f72e..d77feb8d3 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,7 +19,7 @@ use crate::hash_funcs::*;
use crate::{
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
spark_decimal_integral_div,
spark_floor, spark_hex, spark_isnan, spark_make_decimal,
spark_read_side_padding, spark_round,
- spark_unhex, spark_unscaled_value, SparkChrFunc,
+ spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc,
};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
@@ -69,6 +69,10 @@ pub fn create_comet_physical_fun(
let func = Arc::new(spark_read_side_padding);
make_comet_scalar_udf!("read_side_padding", func, without
data_type)
}
+ "rpad" => {
+ let func = Arc::new(spark_rpad);
+ make_comet_scalar_udf!("rpad", func, without data_type)
+ }
"round" => {
make_comet_scalar_udf!("round", spark_round, data_type)
}
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
index fff6134da..0a8d8f3c5 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -17,4 +17,4 @@
mod read_side_padding;
-pub use read_side_padding::spark_read_side_padding;
+pub use read_side_padding::{spark_read_side_padding, spark_rpad};
diff --git
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 15807bf57..1f9400b35 100644
---
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -26,11 +26,25 @@ use std::sync::Arc;
/// Similar to DataFusion `rpad`, but not to truncate when the string is
already longer than length
pub fn spark_read_side_padding(args: &[ColumnarValue]) ->
Result<ColumnarValue, DataFusionError> {
+ spark_read_side_padding2(args, false)
+}
+
+/// Custom `rpad` because DataFusion's `rpad` has differences in unicode
handling
+pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
+ spark_read_side_padding2(args, true)
+}
+
+fn spark_read_side_padding2(
+ args: &[ColumnarValue],
+ truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
match args {
[ColumnarValue::Array(array),
ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
match array.data_type() {
- DataType::Utf8 =>
spark_read_side_padding_internal::<i32>(array, *length),
- DataType::LargeUtf8 =>
spark_read_side_padding_internal::<i64>(array, *length),
+ DataType::Utf8 =>
spark_read_side_padding_internal::<i32>(array, *length, truncate),
+ DataType::LargeUtf8 => {
+ spark_read_side_padding_internal::<i64>(array, *length,
truncate)
+ }
// TODO: handle Dictionary types
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function
read_side_padding",
@@ -46,6 +60,7 @@ pub fn spark_read_side_padding(args: &[ColumnarValue]) ->
Result<ColumnarValue,
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
+ truncate: bool,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
@@ -61,7 +76,16 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
// https://stackoverflow.com/a/46290728
let char_len = string.chars().count();
if length <= char_len {
- builder.append_value(string);
+ if truncate {
+ let idx = string
+ .char_indices()
+ .nth(length)
+ .map(|(i, _)| i)
+ .unwrap_or(string.len());
+ builder.append_value(&string[..idx]);
+ } else {
+ builder.append_value(string);
+ }
} else {
// write_str updates only the value buffer, not null nor
offset buffer
// This is convenient for concatenating str(s)
diff --git a/native/spark-expr/src/static_invoke/mod.rs
b/native/spark-expr/src/static_invoke/mod.rs
index 4072e13b7..39735f156 100644
--- a/native/spark-expr/src/static_invoke/mod.rs
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,4 @@
mod char_varchar_utils;
-pub use char_varchar_utils::spark_read_side_padding;
+pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8757105ec..f50a6606c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1745,16 +1745,30 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
exprToProtoInternal(s.arguments(1), inputs, binding))
if (argsExpr.forall(_.isDefined)) {
- val builder = ExprOuterClass.ScalarFunc.newBuilder()
- builder.setFunc("read_side_padding")
- argsExpr.foreach(arg => builder.addArgs(arg.get))
-
- Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
+ scalarExprToProto("read_side_padding", argsExpr: _*)
} else {
withInfo(expr, s.arguments: _*)
None
}
+ // read-side padding in Spark 3.5.2+ is represented by rpad function
+ case StringRPad(srcStr, size, chars) =>
+ chars match {
+ case Literal(str, DataTypes.StringType) if str.toString == " " =>
+ val arg0 = exprToProtoInternal(srcStr, inputs, binding)
+ val arg1 = exprToProtoInternal(size, inputs, binding)
+ if (arg0.isDefined && arg1.isDefined) {
+ scalarExprToProto("rpad", arg0, arg1)
+ } else {
+ withInfo(expr, "rpad unsupported arguments", srcStr, size)
+ None
+ }
+
+ case _ =>
+ withInfo(expr, "rpad only supports padding with spaces")
+ None
+ }
+
case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
val dataType = serializeDataType(expr.dataType)
if (dataType.isEmpty) {
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 8b6a0fd75..e9b42b73a 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2122,6 +2122,33 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("rpad") {
+ val table = "rpad"
+ val gen = new DataGenerator(new Random(42))
+ withTable(table) {
+ // generate some data
+ val dataChars = "abc123"
+ sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using
parquet")
+ val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
+ "é", // unicode 'e\\u{301}'
+ "é" // unicode '\\u{e9}'
+ )
+ testData.zipWithIndex.foreach { x =>
+ sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
+ }
+ // test 2-arg version
+ checkSparkAnswerAndOperator(
+ s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
+ // test 3-arg version
+ for (length <- Seq(2, 10)) {
+ checkSparkAnswerAndOperator(
+ s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY
id")
+ checkSparkAnswerAndOperator(
+ s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY
id")
+ }
+ }
+ }
+
test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]