cloud-fan commented on code in PR #46599:
URL: https://github.com/apache/spark/pull/46599#discussion_r1617637460


##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("hash join should be used for collated strings") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
+    case class HashJoinTestCase[R](collation: String, result: R)
+    val testCases = Seq(
+      HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), 
Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, 
"aa", 2)))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING 
PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING 
PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: BroadcastHashJoinExec => ()

Review Comment:
   we don't care about shuffle vs. broadcast here; matching the `HashJoin` trait should be better.
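   
   A possible shape for that assertion, as a minimal sketch (assuming the `HashJoin` trait from `org.apache.spark.sql.execution.joins`, which both the broadcast and shuffled hash join nodes extend, and reusing `queryPlan` / `collectFirst` from the surrounding test):
   
   ```scala
   import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec}
   
   // Match the HashJoin trait so the check passes whether the planner picks
   // BroadcastHashJoinExec or ShuffledHashJoinExec.
   assert(
     collectFirst(queryPlan) {
       case _: HashJoin => ()
     }.nonEmpty
   )
   // And still make sure sort merge join was not used.
   assert(
     collectFirst(queryPlan) {
       case _: SortMergeJoinExec => ()
     }.isEmpty
   )
   ```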



##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("hash join should be used for collated strings") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
+    case class HashJoinTestCase[R](collation: String, result: R)
+    val testCases = Seq(
+      HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), 
Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, 
"aa", 2)))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING 
PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING 
PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: BroadcastHashJoinExec => ()
+          }.nonEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.isEmpty
+        )
+
+        // if collation doesn't support binary equality, collation key should be injected
+        if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
+          assert(collectFirst(queryPlan) {
+            case b: BroadcastHashJoinExec => b.leftKeys.head
+          }.head.isInstanceOf[CollationKey])
+        }
+      }
+    })
+  }
+
+  test("rewrite with collationkey should be an excludable rule") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+    val collation = "UTF8_BINARY_LCASE"
+    val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
+    withTable(t1, t2) {
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING 
PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING 
PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that shuffle join is used instead of hash join
+        assert(
+          collectFirst(queryPlan) {
+            case _: BroadcastHashJoinExec => ()

Review Comment:
   ditto: match the `HashJoin` trait here as well.
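   
   Applying the same suggestion here, a sketch (assuming the intent is that, with `RewriteCollationJoin` excluded and a collation lacking binary equality, the plan falls back to sort merge join rather than any hash-based join; `queryPlan` / `collectFirst` again come from the surrounding test):
   
   ```scala
   import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec}
   
   // With the rewrite rule excluded, no collation key is injected, so no
   // hash-based join should appear in the plan.
   assert(
     collectFirst(queryPlan) {
       case _: HashJoin => ()
     }.isEmpty
   )
   // The join should instead be planned as a sort merge join.
   assert(
     collectFirst(queryPlan) {
       case _: SortMergeJoinExec => ()
     }.nonEmpty
   )
   ```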


