This is an automated email from the ASF dual-hosted git repository.
kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 457d9d11 fix: Optimize read_side_padding (#772)
457d9d11 is described below
commit 457d9d11b55b3ae923657b91259098397f7b3619
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Thu Aug 8 09:48:00 2024 -0700
fix: Optimize read_side_padding (#772)
## Which issue does this PR close?
## Rationale for this change
This PR improves `read_side_padding`, which is used for the CHAR() schema (padding semantics sketched below).
## What changes are included in this PR?
Optimized `spark_read_side_padding` to avoid per-row allocations by writing into a pre-sized `GenericStringBuilder`.
## How are these changes tested?
Added tests
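As background on the behavior being optimized, here is a minimal standalone sketch (not the Comet implementation; the helper name is illustrative) of the read-side padding rule: lengths are counted in chars rather than graphemes, and strings already at or beyond the target length are returned untouched rather than truncated.

```rust
// Hypothetical helper illustrating the read-side padding semantics.
fn read_side_pad(s: &str, length: i32) -> String {
    let length = 0.max(length) as usize;
    let char_len = s.chars().count(); // chars, not graphemes
    if length <= char_len {
        s.to_string() // never truncate
    } else {
        // append (length - char_len) pad spaces
        let mut out = String::with_capacity(s.len() + (length - char_len));
        out.push_str(s);
        out.extend(std::iter::repeat(' ').take(length - char_len));
        out
    }
}

fn main() {
    assert_eq!(read_side_pad("ab", 5), "ab   ");
    assert_eq!(read_side_pad("abcdef", 3), "abcdef"); // longer than CHAR(3)
    assert_eq!(read_side_pad("e\u{301}", 2), "e\u{301}"); // 2 chars: no pad
    assert_eq!(read_side_pad("\u{e9}", 2), "\u{e9} "); // 1 char: 1 space
}
```

The actual implementation below builds the result directly in an Arrow `GenericStringBuilder` instead of allocating a fresh `String` per row.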
---
native/Cargo.lock | 1 -
.../datafusion/expressions/comet_scalar_funcs.rs | 10 ++--
native/core/src/execution/datafusion/planner.rs | 15 ++++--
native/spark-expr/Cargo.toml | 1 -
native/spark-expr/src/scalar_funcs.rs | 62 +++++++++++-----------
.../org/apache/comet/serde/QueryPlanSerde.scala | 4 +-
.../resources/tpcds-micro-benchmarks/char_type.sql | 7 +++
.../org/apache/comet/CometExpressionSuite.scala | 14 +++++
.../sql/benchmark/CometTPCDSMicroBenchmark.scala | 1 +
9 files changed, 71 insertions(+), 44 deletions(-)
diff --git a/native/Cargo.lock b/native/Cargo.lock
index b34ed54b..8cb39f5b 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -942,7 +942,6 @@ dependencies = [
"regex",
"thiserror",
"twox-hash",
- "unicode-segmentation",
]
[[package]]
diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
index 70cbdeba..1203f90d 100644
--- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
+++ b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
};
use datafusion_comet_spark_expr::scalar_funcs::{
spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
- spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
- SparkChrFunc,
+ spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+ spark_xxhash64, SparkChrFunc,
};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
"floor" => {
make_comet_scalar_udf!("floor", spark_floor, data_type)
}
- "rpad" => {
- let func = Arc::new(spark_rpad);
- make_comet_scalar_udf!("rpad", func, without data_type)
+ "read_side_padding" => {
+ let func = Arc::new(spark_read_side_padding);
+ make_comet_scalar_udf!("read_side_padding", func, without data_type)
}
"round" => {
make_comet_scalar_udf!("round", spark_round, data_type)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index a16ceda8..b604e98b 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {
let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
Some(t) => t,
- None => self
- .session_ctx
- .udf(fun_name)?
- .inner()
- .return_type(&input_expr_types)?,
+ None => {
+ let fun_name = match fun_name.as_str() {
+ "read_side_padding" => "rpad", // use the same return type as rpad
+ other => other,
+ };
+ self.session_ctx
+ .udf(fun_name)?
+ .inner()
+ .return_type(&input_expr_types)?
+ }
};
let fun_expr =
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 96eae39f..1a8c8aeb 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
num = { workspace = true }
regex = { workspace = true }
thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
[dev-dependencies]
arrow-data = {workspace = true}
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index 7cbaf12a..ffd6fd21 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.
-use std::{cmp::min, sync::Arc};
-
use arrow::{
array::{
- ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
- Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+ Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
+use arrow_array::builder::GenericStringBuilder;
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
integer::{div_ceil, div_floor},
BigInt, Signed, ToPrimitive,
};
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
mod unhex;
pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
}
/// Similar to DataFusion `rpad`, but does not truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
- match args[0].data_type() {
- DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
- DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+ match array.data_type() {
+ DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+ DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
// TODO: handle Dictionary types
other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {other:?} for function rpad",
+ "Unsupported data type {other:?} for function read_side_padding",
))),
}
}
other => Err(DataFusionError::Internal(format!(
- "Unsupported arguments {other:?} for function rpad",
+ "Unsupported arguments {other:?} for function read_side_padding",
))),
}
}
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
+ let length = 0.max(length) as usize;
+ let space_string = " ".repeat(length);
+
+ let mut builder =
+ GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
- let result = string_array
- .iter()
- .map(|string| match string {
+ for string in string_array.iter() {
+ match string {
Some(string) => {
- let length = if length < 0 { 0 } else { length as usize };
- if length == 0 {
- Ok(Some("".to_string()))
+ // It looks like Spark's UTF8String is closer to chars than to graphemes
+ // https://stackoverflow.com/a/46290728
+ let char_len = string.chars().count();
+ if length <= char_len {
+ builder.append_value(string);
} else {
- let graphemes = string.graphemes(true).collect::<Vec<&str>>();
- if length < graphemes.len() {
- Ok(Some(string.to_string()))
- } else {
- let mut s = string.to_string();
- s.push_str(" ".repeat(length - graphemes.len()).as_str());
- Ok(Some(s))
- }
+ // write_str updates only the value buffer, not the null or offset buffers
+ // This is convenient for concatenating str(s)
+ builder.write_str(string)?;
+ builder.append_value(&space_string[char_len..]);
}
}
- _ => Ok(None),
- })
- .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
- Ok(ColumnarValue::Array(Arc::new(result)))
+ _ => builder.append_null(),
+ }
+ }
+ Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
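A note on the `GenericStringBuilder` usage in the scalar_funcs.rs hunk above: a minimal sketch, assuming the arrow-rs builder API the diff itself uses, of how `write_str` and `append_value` combine. `write_str` appends bytes to the value buffer without closing the current entry, so the following `append_value` effectively concatenates its argument and then finalizes the offset and validity for the whole entry.

```rust
use arrow_array::builder::GenericStringBuilder;
use arrow_array::Array;
use std::fmt::Write;

fn main() -> Result<(), std::fmt::Error> {
    // Pre-size for 2 entries and ~16 bytes of string data.
    let mut builder = GenericStringBuilder::<i32>::with_capacity(2, 16);
    builder.write_str("ab")?;    // partial value: bytes only, no offset yet
    builder.append_value("   "); // completes the entry as "ab   "
    builder.append_null();
    let array = builder.finish();
    assert_eq!(array.value(0), "ab   ");
    assert!(array.is_null(1));
    Ok(())
}
```

This is why `spark_read_side_padding_internal` can pad each row with a slice of one shared space string instead of building a fresh `String` per row.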
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8f08eeba..5f3cc7a2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
- // char types. Use rpad to achieve the behavior.
+ // char types.
// See https://github.com/apache/spark/pull/38151
case s: StaticInvoke
if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
if (argsExpr.forall(_.isDefined)) {
val builder = ExprOuterClass.ScalarFunc.newBuilder()
- builder.setFunc("rpad")
+ builder.setFunc("read_side_padding")
argsExpr.foreach(arg => builder.addArgs(arg.get))
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
diff --git a/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
new file mode 100644
index 00000000..8a5359d4
--- /dev/null
+++ b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
@@ -0,0 +1,7 @@
+SELECT
+ cd_gender
+FROM customer_demographics
+WHERE
+ cd_gender = 'M' AND
+ cd_marital_status = 'S' AND
+ cd_education_status = 'College'
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index cce48719..ded5bc5c 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("readSidePadding") {
+ // https://stackoverflow.com/a/46290728
+ val table = "test"
+ withTable(table) {
+ sql(s"create table $table(col1 CHAR(2)) using parquet")
+ sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+ sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+ sql(s"insert into $table values('')")
+ sql(s"insert into $table values('ab')")
+
+ checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+ }
+ }
+
test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
index b09e0486..aa0c9115 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
"agg_sum_integers_no_grouping",
"case_when_column_or_null",
"case_when_scalar",
+ "char_type",
"filter_highly_selective",
"filter_less_selective",
"if_column_or_null",
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]