This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e386b472981 [SPARK-48172][SQL] Fix escaping issues in JDBCDialects
9e386b472981 is described below

commit 9e386b472981979e368a5921c58da5bfefe3acfe
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Wed May 15 22:15:52 2024 +0800

    [SPARK-48172][SQL] Fix escaping issues in JDBCDialects
    
    This PR is a fix of https://github.com/apache/spark/pull/46437. The previous PR was reverted as `LONGTEXT` is not supported by all dialects.
    
    ### What changes were proposed in this pull request?
    Special-case escaping for MySQL and fix redundant escaping of the ' character.
    New changes introduced in the fix include replacing `LONGTEXT` with `VARCHAR(50)` (for every dialect except MySQL, which keeps `LONGTEXT`), as well as a fix for table naming in the tests.
    
    ### Why are the changes needed?
    When startsWith, endsWith and contains are pushed down, they are converted to LIKE, which requires adding escape characters to these expressions. Unfortunately, MySQL uses the ESCAPE '\\' syntax instead of ESCAPE '\', which caused errors when trying to push down.
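    
    To illustrate (a minimal sketch, not code from this patch; it mirrors the
    escaping in V2ExpressionSQLBuilder and the MySQL overrides below):
    
        // Escape the LIKE wildcards the way the shared builder does; quotes
        // are handled separately, by doubling them inside the SQL literal.
        def escapeSpecialCharsForLikePattern(str: String): String =
          str.flatMap {
            case '_' => "\\_"   // keep underscore literal
            case '%' => "\\%"   // keep percent literal
            case c   => c.toString
          }
    
        // contains(email, "c_'%d") is then pushed down roughly as:
        //   default dialects: EMAIL LIKE '%c\_''\%d%' ESCAPE '\'
        //   MySQL:            EMAIL LIKE '%c\_''\%d%' ESCAPE '\\'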
    
    ### Does this PR introduce any user-facing change?
    Yes
    
    ### How was this patch tested?
    Tests for each existing dialect.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46588 from mihailom-db/SPARK-48172.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/jdbc/v2/DB2IntegrationSuite.scala    |   6 +
 .../sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala |  11 +
 .../sql/jdbc/v2/MsSqlServerIntegrationSuite.scala  |   6 +
 .../spark/sql/jdbc/v2/MySQLIntegrationSuite.scala  |   6 +
 .../spark/sql/jdbc/v2/OracleIntegrationSuite.scala |   6 +
 .../sql/jdbc/v2/PostgresIntegrationSuite.scala     |   6 +
 .../org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala  | 229 +++++++++++++++++++++
 .../sql/connector/util/V2ExpressionSQLBuilder.java |   1 -
 .../sql/connector/expressions/expressions.scala    |   4 +-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      |   7 -
 .../org/apache/spark/sql/jdbc/MySQLDialect.scala   |  15 ++
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |   6 +-
 12 files changed, 291 insertions(+), 12 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 3642094d11b2..57129e9d846f 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -62,6 +62,12 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
     connection.prepareStatement(
       "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)")
       .executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+                   """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
index 72edfc9f1bf1..5f4f0b7a3afb 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala
@@ -38,6 +38,17 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
       .executeUpdate()
     connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)")
       .executeUpdate()
+
+    connection.prepareStatement(
+      s"""
+         |INSERT INTO pattern_testing_table VALUES
+         |('special_character_quote''_present'),
+         |('special_character_quote_not_present'),
+         |('special_character_percent%_present'),
+         |('special_character_percent_not_present'),
+         |('special_character_underscore_present'),
+         |('special_character_underscorenot_present')
+             """.stripMargin).executeUpdate()
   }
 
   def tablePreparation(connection: Connection): Unit
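
Note on the seed rows above: standard SQL writes a single quote inside a string
literal by doubling it, so 'special_character_quote''_present' stores the value
special_character_quote'_present. A hypothetical round-trip check (plain JDBC,
not part of this patch):

    // The doubled quote in the INSERT above comes back as one character.
    val rs = connection.prepareStatement(
      "SELECT pattern_testing_col FROM pattern_testing_table " +
        "WHERE pattern_testing_col = 'special_character_quote''_present'")
      .executeQuery()
    assert(rs.next() && rs.getString(1) == "special_character_quote'_present")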
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index b1b8aec5ad33..9ddd79fb257d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -70,6 +70,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
     connection.prepareStatement(
       "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)")
       .executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+                   """.stripMargin
+    ).executeUpdate()
   }
 
   override def notSupportsTableComment: Boolean = true
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 22900c7bbcc8..d5478e664221 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -73,6 +73,12 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
     connection.prepareStatement(
       "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," +
         " bonus DOUBLE)").executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col LONGTEXT
+         |)
+                   """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index b35018ec16dc..1adfef95998e 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -93,6 +93,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
     connection.prepareStatement(
       "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
         " bonus BINARY_DOUBLE)").executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+                   """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 1f09c2fd3fc5..7fef3ccd6b3f 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -59,6 +59,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
     connection.prepareStatement(
      "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
        " bonus double precision)").executeUpdate()
+    connection.prepareStatement(
+      s"""CREATE TABLE pattern_testing_table (
+         |pattern_testing_col VARCHAR(50)
+         |)
+                   """.stripMargin
+    ).executeUpdate()
   }
 
   override def testUpdateColumnType(tbl: String): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index b60107f90283..45c4f41ffb77 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -359,6 +359,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
     assert(scan.schema.names.sameElements(Seq(col)))
   }
 
+  test("SPARK-48172: Test CONTAINS") {
+    val df1 = spark.sql(
+      s"""
+         |SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
+    df1.explain("formatted")
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE contains(pattern_testing_col, 'character')
+           |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test ENDSWITH") {
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE endswith(pattern_testing_col, 'present')
+           |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test STARTSWITH") {
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE startswith(pattern_testing_col, 'special_character')
+           |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+  }
+
+  test("SPARK-48172: Test LIKE") {
+    // this one should map to contains
+    val df1 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)
+    val rows1 = df1.collect()
+    assert(rows1.length === 1)
+    assert(rows1(0).getString(0) === "special_character_quote'_present")
+
+    val df2 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
+    val rows2 = df2.collect()
+    assert(rows2.length === 1)
+    assert(rows2(0).getString(0) === "special_character_percent%_present")
+
+    val df3 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
+    val rows3 = df3.collect()
+    assert(rows3.length === 1)
+    assert(rows3(0).getString(0) === "special_character_underscore_present")
+
+    val df4 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE '%character%'
+           |ORDER BY pattern_testing_col""".stripMargin)
+    val rows4 = df4.collect()
+    assert(rows4.length === 6)
+    assert(rows4(0).getString(0) === "special_character_percent%_present")
+    assert(rows4(1).getString(0) === "special_character_percent_not_present")
+    assert(rows4(2).getString(0) === "special_character_quote'_present")
+    assert(rows4(3).getString(0) === "special_character_quote_not_present")
+    assert(rows4(4).getString(0) === "special_character_underscore_present")
+    assert(rows4(5).getString(0) === "special_character_underscorenot_present")
+
+    // map to startsWith
+    // this one should map to startsWith
+    val df5 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
+    val rows5 = df5.collect()
+    assert(rows5.length === 1)
+    assert(rows5(0).getString(0) === "special_character_quote'_present")
+
+    val df6 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
+    val rows6 = df6.collect()
+    assert(rows6.length === 1)
+    assert(rows6(0).getString(0) === "special_character_percent%_present")
+
+    val df7 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
+    val rows7 = df7.collect()
+    assert(rows7.length === 1)
+    assert(rows7(0).getString(0) === "special_character_underscore_present")
+
+    val df8 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE 'special_character%'
+           |ORDER BY pattern_testing_col""".stripMargin)
+    val rows8 = df8.collect()
+    assert(rows8.length === 6)
+    assert(rows8(0).getString(0) === "special_character_percent%_present")
+    assert(rows8(1).getString(0) === "special_character_percent_not_present")
+    assert(rows8(2).getString(0) === "special_character_quote'_present")
+    assert(rows8(3).getString(0) === "special_character_quote_not_present")
+    assert(rows8(4).getString(0) === "special_character_underscore_present")
+    assert(rows8(5).getString(0) === "special_character_underscorenot_present")
+    // map to endsWith
+    // this one should map to endsWith
+    val df9 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
+    val rows9 = df9.collect()
+    assert(rows9.length === 1)
+    assert(rows9(0).getString(0) === "special_character_quote'_present")
+
+    val df10 = spark.sql(
+      s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+         |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
+    val rows10 = df10.collect()
+    assert(rows10.length === 1)
+    assert(rows10(0).getString(0) === "special_character_percent%_present")
+
+    val df11 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
+    val rows11 = df11.collect()
+    assert(rows11.length === 1)
+    assert(rows11(0).getString(0) === "special_character_underscore_present")
+
+    val df12 = spark.
+      sql(
+        s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
+           |WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin)
+    val rows12 = df12.collect()
+    assert(rows12.length === 6)
+    assert(rows12(0).getString(0) === "special_character_percent%_present")
+    assert(rows12(1).getString(0) === "special_character_percent_not_present")
+    assert(rows12(2).getString(0) === "special_character_quote'_present")
+    assert(rows12(3).getString(0) === "special_character_quote_not_present")
+    assert(rows12(4).getString(0) === "special_character_underscore_present")
+    assert(rows12(5).getString(0) === "special_character_underscorenot_present")
+  }
+
   test("SPARK-37038: Test TABLESAMPLE") {
     if (supportsTableSample) {
       withTable(s"$catalogName.new_table") {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index e42d9193ea39..11f4389245d9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -65,7 +65,6 @@ public class V2ExpressionSQLBuilder {
       switch (c) {
         case '_' -> builder.append("\\_");
         case '%' -> builder.append("\\%");
-        case '\'' -> builder.append("\\\'");
         default -> builder.append(c);
       }
     }
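
With the quote case removed, escapeSpecialCharsForLikePattern no longer
backslash-escapes quotes while building LIKE patterns. A hedged before/after
sketch (input value chosen for illustration):

    // before: escapeSpecialCharsForLikePattern("c_'%d")  ==>  c\_\'\%d
    // after:  escapeSpecialCharsForLikePattern("c_'%d")  ==>  c\_'\%d
    // The quote is instead doubled once when the string literal itself is
    // rendered (see the LiteralValue change below), so the pattern no longer
    // carries a redundant \' escape.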
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index fc41d5a98e4a..b43e627c0eec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.expressions
 
+import org.apache.commons.lang3.StringUtils
+
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -388,7 +390,7 @@ private[sql] object HoursTransform {
 private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
   override def toString: String = {
     if (dataType.isInstanceOf[StringType]) {
-      s"'$value'"
+      s"'${StringUtils.replace(s"$value", "'", "''")}'"
     } else {
       s"$value"
     }
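
A quick sketch of the new literal rendering (illustrative only; LiteralValue is
private[sql], so this compiles only inside Spark's own sources):

    import org.apache.spark.sql.connector.expressions.LiteralValue
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Quotes inside string values are doubled, per the SQL standard:
    LiteralValue("quote'_present", StringType).toString  // 'quote''_present'
    LiteralValue(42, IntegerType).toString               // 42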
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index ebfc6093dc16..949455b248ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -259,13 +259,6 @@ private[sql] case class H2Dialect() extends JdbcDialect {
   }
 
   class H2SQLBuilder extends JDBCSQLBuilder {
-    override def escapeSpecialCharsForLikePattern(str: String): String = {
-      str.map {
-        case '_' => "\\_"
-        case '%' => "\\%"
-        case c => c.toString
-      }.mkString
-    }
 
     override def visitAggregateFunction(
         funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index d98fcdfd0b23..50951042737a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -66,6 +66,21 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
       }
     }
 
+    override def visitStartsWith(l: String, r: String): String = {
+      val value = r.substring(1, r.length() - 1)
+      s"$l LIKE '${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
+    }
+
+    override def visitEndsWith(l: String, r: String): String = {
+      val value = r.substring(1, r.length() - 1)
+      s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}' ESCAPE '\\\\'"
+    }
+
+    override def visitContains(l: String, r: String): String = {
+      val value = r.substring(1, r.length() - 1)
+      s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
+    }
+
     override def visitAggregateFunction(
         funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
       if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
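
With these overrides MySQL gets the doubled-backslash ESCAPE clause it expects.
A standalone, hedged sketch of the emitted SQL (MySqlLikeSketch and its helpers
are illustrative names that mirror the overrides above; the real builder is a
private inner class):

    object MySqlLikeSketch {
      // Same wildcard escaping the shared builder performs.
      def escape(str: String): String = str.flatMap {
        case '_' => "\\_"
        case '%' => "\\%"
        case c   => c.toString
      }
      // r arrives as a quoted SQL literal, e.g. 'a_b'; strip the outer quotes.
      def startsWithSql(l: String, r: String): String = {
        val value = r.substring(1, r.length - 1)
        s"$l LIKE '${escape(value)}%' ESCAPE '\\\\'"
      }
      def main(args: Array[String]): Unit =
        // prints: pattern_testing_col LIKE 'a\_b%' ESCAPE '\\'
        println(startsWithSql("pattern_testing_col", "'a_b'"))
    }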
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 1b3672cdba5a..8e98181a9802 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -1305,7 +1305,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df5 = spark.table("h2.test.address").filter($"email".startsWith("abc_'%"))
     checkFiltersRemoved(df5)
     checkPushedInfo(df5,
-      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\'\%%' ESCAPE '\']")
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_''\%%' ESCAPE '\']")
     checkAnswer(df5, Seq(Row("abc_'%d...@gmail.com")))
 
     val df6 = spark.table("h2.test.address").filter($"email".endsWith("_...@gmail.com"))
@@ -1336,7 +1336,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df10 = spark.table("h2.test.address").filter($"email".endsWith("_'%d...@gmail.com"))
     checkFiltersRemoved(df10)
     checkPushedInfo(df10,
-      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\'\%d...@gmail.com' ESCAPE '\']")
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_''\%d...@gmail.com' ESCAPE '\']")
     checkAnswer(df10, Seq(Row("abc_'%d...@gmail.com")))
 
     val df11 = spark.table("h2.test.address").filter($"email".contains("c_d"))
@@ -1364,7 +1364,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df15 = spark.table("h2.test.address").filter($"email".contains("c_'%d"))
     checkFiltersRemoved(df15)
     checkPushedInfo(df15,
-      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\'\%d%' ESCAPE '\']")
+      raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_''\%d%' ESCAPE '\']")
     checkAnswer(df15, Seq(Row("abc_'%d...@gmail.com")))
   }
 

