This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 457d9d11 fix: Optimize read_side_padding (#772)
457d9d11 is described below

commit 457d9d11b55b3ae923657b91259098397f7b3619
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Thu Aug 8 09:48:00 2024 -0700

    fix: Optimize read_side_padding (#772)
    
    ## Which issue does this PR close?
    
    ## Rationale for this change
    
    This PR optimizes read_side_padding, which pads values read from CHAR() columns. It also aligns the padding with Spark's UTF8String semantics by counting code points (chars) rather than graphemes.
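
    A quick illustration of the chars-vs-graphemes distinction (a standalone
    Rust sketch; the two spellings of 'é' mirror the new CometExpressionSuite
    test):

    ```rust
    // "e\u{301}" (e + combining acute) renders as 'é' but is two code points,
    // while "\u{e9}" is the single precomposed code point.
    fn main() {
        let combining = "e\u{301}"; // two chars, one grapheme
        let precomposed = "\u{e9}"; // one char, one grapheme
        assert_eq!(combining.chars().count(), 2);
        assert_eq!(precomposed.chars().count(), 1);
        // Spark's UTF8String counts code points, so CHAR(2) read-side padding
        // should leave "e\u{301}" unchanged and append one space to "\u{e9}".
    }
    ```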
    
    ## What changes are included in this PR?
    
    Optimized spark_read_side_padding: replaced the per-row grapheme segmentation and String allocation with char-based length counting and a single pre-allocated GenericStringBuilder, and dropped the unicode-segmentation dependency.
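
    For context, a minimal sketch of the new loop (a hypothetical standalone
    pad_to_width helper over Utf8 arrays, assuming the arrow_array crate; the
    real implementation is the generic spark_read_side_padding_internal):

    ```rust
    use std::fmt::Write; // GenericStringBuilder implements fmt::Write
    use std::sync::Arc;

    use arrow_array::builder::GenericStringBuilder;
    use arrow_array::{Array, ArrayRef, StringArray};

    // Hypothetical helper mirroring spark_read_side_padding_internal for Utf8.
    fn pad_to_width(array: &StringArray, length: i32) -> ArrayRef {
        let length = 0.max(length) as usize;
        // One shared pad buffer instead of a fresh String per row.
        let space_string = " ".repeat(length);

        let mut builder =
            GenericStringBuilder::<i32>::with_capacity(array.len(), array.len() * length);
        for value in array.iter() {
            match value {
                Some(s) => {
                    // Count code points to match Spark's UTF8String semantics.
                    let char_len = s.chars().count();
                    if length <= char_len {
                        builder.append_value(s);
                    } else {
                        // write_str fills the value buffer without closing the
                        // row; append_value then completes the padded value.
                        builder.write_str(s).expect("writing to builder");
                        // Byte slicing is safe here: the pad buffer is ASCII.
                        builder.append_value(&space_string[char_len..]);
                    }
                }
                None => builder.append_null(),
            }
        }
        Arc::new(builder.finish())
    }
    ```

    Pre-sizing the builder to rows * padded-width should avoid repeated buffer
    growth; together with dropping the per-row String and Vec allocations, that
    is the intended optimization.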
    
    ## How are these changes tested?
    
    Added a readSidePadding test to CometExpressionSuite and a char_type TPC-DS micro-benchmark.
---
 native/Cargo.lock                                  |  1 -
 .../datafusion/expressions/comet_scalar_funcs.rs   | 10 ++--
 native/core/src/execution/datafusion/planner.rs    | 15 ++++--
 native/spark-expr/Cargo.toml                       |  1 -
 native/spark-expr/src/scalar_funcs.rs              | 62 +++++++++++-----------
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  4 +-
 .../resources/tpcds-micro-benchmarks/char_type.sql |  7 +++
 .../org/apache/comet/CometExpressionSuite.scala    | 14 +++++
 .../sql/benchmark/CometTPCDSMicroBenchmark.scala   |  1 +
 9 files changed, 71 insertions(+), 44 deletions(-)

diff --git a/native/Cargo.lock b/native/Cargo.lock
index b34ed54b..8cb39f5b 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -942,7 +942,6 @@ dependencies = [
  "regex",
  "thiserror",
  "twox-hash",
- "unicode-segmentation",
 ]
 
 [[package]]
diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
index 70cbdeba..1203f90d 100644
--- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
+++ b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
 };
 use datafusion_comet_spark_expr::scalar_funcs::{
     spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
-    SparkChrFunc,
+    spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+    spark_xxhash64, SparkChrFunc,
 };
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
         "floor" => {
             make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
-        "rpad" => {
-            let func = Arc::new(spark_rpad);
-            make_comet_scalar_udf!("rpad", func, without data_type)
+        "read_side_padding" => {
+            let func = Arc::new(spark_read_side_padding);
+            make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index a16ceda8..b604e98b 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {
 
         let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
             Some(t) => t,
-            None => self
-                .session_ctx
-                .udf(fun_name)?
-                .inner()
-                .return_type(&input_expr_types)?,
+            None => {
+                let fun_name = match fun_name.as_str() {
+                    "read_side_padding" => "rpad", // use the same return type as rpad
+                    other => other,
+                };
+                self.session_ctx
+                    .udf(fun_name)?
+                    .inner()
+                    .return_type(&input_expr_types)?
+            }
         };
 
         let fun_expr =
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 96eae39f..1a8c8aeb 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
 
 [dev-dependencies]
 arrow-data = {workspace = true}
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index 7cbaf12a..ffd6fd21 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
 
 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }
 
 /// Similar to DataFusion `rpad`, but does not truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function 
read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }
 
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
 
-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks like Spark's UTF8String is closer to chars than to graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not the null or offset buffers
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8f08eeba..5f3cc7a2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           }
 
         // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-        // char types. Use rpad to achieve the behavior.
+        // char types.
         // See https://github.com/apache/spark/pull/38151
         case s: StaticInvoke
             if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
           if (argsExpr.forall(_.isDefined)) {
             val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("rpad")
+            builder.setFunc("read_side_padding")
             argsExpr.foreach(arg => builder.addArgs(arg.get))

             Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
diff --git a/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
new file mode 100644
index 00000000..8a5359d4
--- /dev/null
+++ b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
@@ -0,0 +1,7 @@
+SELECT
+    cd_gender
+FROM customer_demographics
+WHERE
+    cd_gender = 'M' AND
+    cd_marital_status = 'S' AND
+    cd_education_status = 'College'
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index cce48719..ded5bc5c 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("readSidePadding") {
+    // https://stackoverflow.com/a/46290728
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 CHAR(2)) using parquet")
+      sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+      sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+      sql(s"insert into $table values('')")
+      sql(s"insert into $table values('ab')")
+
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
index b09e0486..aa0c9115 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
     "agg_sum_integers_no_grouping",
     "case_when_column_or_null",
     "case_when_scalar",
+    "char_type",
     "filter_highly_selective",
     "filter_less_selective",
     "if_column_or_null",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
