This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8cba15ed30ea [SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings
8cba15ed30ea is described below

commit 8cba15ed30ea55185ebbc8d3601852381a4bfd97
Author: Nikola Mandic <nikola.man...@databricks.com>
AuthorDate: Fri Mar 22 12:15:05 2024 +0500

    [SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings
    
    ### What changes were proposed in this pull request?
    
    Example of aggregation sequence:
    ```
    create table t(a array<string collate utf8_binary_lcase>) using parquet;
    
    insert into t(a) values(array('a' collate utf8_binary_lcase));
    insert into t(a) values(array('A' collate utf8_binary_lcase));
    
    select distinct a from t;
    ```
    Example of join sequence:
    ```
    create table l(a array<string collate utf8_binary_lcase>) using parquet;
    create table r(a array<string collate utf8_binary_lcase>) using parquet;
    
    insert into l(a) values(array('a' collate utf8_binary_lcase));
    insert into r(a) values(array('A' collate utf8_binary_lcase));
    
    select * from l join r where l.a = r.a;
    ```
    Both queries should yield one row, since the arrays are considered equal.
    
    The problem is in the `isBinaryStable` function, which should return false if **any** of its subtypes is a non-binary collated string.
    
    ### Why are the changes needed?
    
    To properly support aggregations and joins on arrays of collated strings.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes the described scenarios.
    
    ### How was this patch tested?
    
    Added new checks to `CollationSuite` and `UnsafeRowUtilsSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45611 from nikolamand-db/SPARK-47483.
    
    Authored-by: Nikola Mandic <nikola.man...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../spark/sql/catalyst/util/UnsafeRowUtils.scala   |  6 +-
 .../sql/catalyst/util/UnsafeRowUtilsSuite.scala    | 68 ++++++++++++++++-
 .../org/apache/spark/sql/CollationSuite.scala      | 87 +++++++++++++++++++++-
 3 files changed, 156 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index 0718cf110f75..0c1ce5ffa8b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -204,8 +204,8 @@ object UnsafeRowUtils {
    * e.g. this is not true for non-binary collations (any case/accent insensitive collation
    * can lead to rows being semantically equal even though their binary representations differ).
    */
-  def isBinaryStable(dataType: DataType): Boolean = dataType.existsRecursively {
-    case st: StringType => CollationFactory.fetchCollation(st.collationId).isBinaryCollation
-    case _ => true
+  def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively {
+    case st: StringType => !CollationFactory.fetchCollation(st.collationId).isBinaryCollation
+    case _ => false
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
index c7a8bc74f4dd..b6e87c456de0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
@@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType}
 
 class UnsafeRowUtilsSuite extends SparkFunSuite {
 
@@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite {
         "fieldStatus:\n" +
         "[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,"))
   }
+
+  test("isBinaryStable on complex types containing collated strings") {
+    val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))
+
+    // simple checks
+    assert(UnsafeRowUtils.isBinaryStable(IntegerType))
+    assert(UnsafeRowUtils.isBinaryStable(StringType))
+    assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType))
+
+    assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType)))
+    assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType)))
+    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType)))
+
+    assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType)))
+    assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType)))
+    assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType)))
+    assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType)))
+    assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType)))
+    assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType)))
+
+    assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil)))
+    assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", nonBinaryStringType) :: Nil)))
+
+    // nested complex types
+    assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType))))
+    assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType))))
+    assert(UnsafeRowUtils.isBinaryStable(
+      ArrayType(StructType(StructField("field", StringType) :: Nil))))
+    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType))))
+    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType))))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType)))))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil))))
+    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType(
+      Seq(StructField("second", IntegerType), StructField("second", nonBinaryStringType))))))
+
+    assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType))))
+    assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType)))
+    assert(UnsafeRowUtils.isBinaryStable(
+      MapType(StructType(StructField("field", StringType) :: Nil), IntegerType)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType))))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      MapType(IntegerType, ArrayType(nonBinaryStringType))))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      MapType(MapType(IntegerType, nonBinaryStringType), IntegerType)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType)))
+
+    assert(UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", ArrayType(IntegerType)) :: Nil)))
+    assert(UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil)))
+    assert(UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil)))
+    assert(!UnsafeRowUtils.isBinaryStable(
+      StructType(StructField("field",
+        StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index efb3c2f8ba8e..146ba63cf402 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -27,10 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
 
 class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
@@ -640,6 +641,90 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "reason" -> "generation expression cannot contain non-default collated string type"))
   }
 
+  test("Aggregation on complex containing collated strings") {
+    val table = "table_agg"
+    // array
+    withTable(table) {
+      sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
+      sql(s"insert into $table values (array('aaa')), (array('AAA'))")
+      checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa"))))
+    }
+    // map doesn't support aggregation
+    withTable(table) {
+      sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet")
+      val query = s"select distinct m from $table"
+      checkError(
+        exception = intercept[ExtendedAnalysisException](sql(query)),
+        errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
+        parameters = Map(
+          "colName" -> "`m`",
+          "dataType" -> toSQLType(MapType(
+            StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
+            StringType))),
+        context = ExpectedContext(query, 0, query.length - 1)
+      )
+    }
+    // struct
+    withTable(table) {
+      sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet")
+      sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))")
+      checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa")))
+    }
+  }
+
+  test("Joins on complex types containing collated strings") {
+    val tableLeft = "table_join_le"
+    val tableRight = "table_join_ri"
+    // array
+    withTable(tableLeft, tableRight) {
+      Seq(tableLeft, tableRight).map(tab =>
+        sql(s"create table $tab (a array<string collate utf8_binary_lcase>) using parquet"))
+      Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).map{
+        case (tab, data) => sql(s"insert into $tab values ($data)")
+      }
+      checkAnswer(sql(
+        s"""
+           |select $tableLeft.a from $tableLeft
+           |join $tableRight on $tableLeft.a = $tableRight.a
+           |""".stripMargin), Seq(Row(Seq("aaa"))))
+    }
+    // map doesn't support joins
+    withTable(tableLeft, tableRight) {
+      Seq(tableLeft, tableRight).map(tab =>
+        sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet"))
+      val query =
+        s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m"
+      val ctx = s"$tableLeft.m = $tableRight.m"
+      checkError(
+        exception = intercept[AnalysisException](sql(query)),
+        errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
+        parameters = Map(
+          "functionName" -> "`=`",
+          "dataType" -> toSQLType(MapType(
+            StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
+            StringType
+          )),
+          "sqlExpr" -> "\"(m = m)\""),
+        context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
+    }
+    // struct
+    withTable(tableLeft, tableRight) {
+      Seq(tableLeft, tableRight).map(tab =>
+        sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet"))
+      Seq(
+        (tableLeft, "named_struct('fld', 'aaa')"),
+        (tableRight, "named_struct('fld', 'AAA')")
+      ).map {
+        case (tab, data) => sql(s"insert into $tab values ($data)")
+      }
+      checkAnswer(sql(
+        s"""
+           |select $tableLeft.s.fld from $tableLeft
+           |join $tableRight on $tableLeft.s = $tableRight.s
+           |""".stripMargin), Seq(Row("aaa")))
+    }
+  }
+
   test("window aggregates should respect collation") {
     val t1 = "T_NON_BINARY"
     val t2 = "T_BINARY"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to