This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0129f34f201 [SPARK-37259][SQL] Support CTE and temp table queries with MSSQL JDBC
0129f34f201 is described below

commit 0129f34f2016a4d9a0f0e862d21778a26259b4d0
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Fri May 6 14:23:06 2022 -0700

    [SPARK-37259][SQL] Support CTE and temp table queries with MSSQL JDBC
    
    ### What changes were proposed in this pull request?
    Currently, CTE queries from Spark are not supported with MSSQL Server via JDBC. This is because MSSQL Server doesn't support the nested CTE syntax (`SELECT * FROM (WITH t AS (...) SELECT ... FROM t) WHERE 1=0`) that Spark builds from the original query (`options.tableOrQuery`) in `JDBCRDD.resolveTable()` and in `JDBCRDD.compute()`.
    Unfortunately, it is non-trivial to split an arbitrary query into "with" and "regular" clauses in `MsSqlServerDialect`. So instead, I'm proposing a new general JDBC option, `prepareQuery`, that users can set for such complex queries:
    ```
    val df = spark.read.format("jdbc")
      .option("url", jdbcUrl)
      .option("prepareQuery", "WITH t AS (SELECT x, y FROM tbl)")
      .option("query", "SELECT * FROM t WHERE x > 10")
      .load()
    ```
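    With `prepareQuery` set, the prefix is simply prepended (with a trailing space) to the query Spark generates, so the statement sent to the server has the form `<prepareQuery> SELECT <columns> FROM (<query>) spark_gen_alias`. A minimal sketch of that composition (the column list and the subquery alias below are illustrative placeholders, not Spark's exact generated strings):
    ```
    // Illustrative only: mirrors how the new option is concatenated with the generated query.
    val prepareQuery = Option("WITH t AS (SELECT x, y FROM tbl)").map(_ + " ").getOrElse("")
    val tableOrQuery = "(SELECT * FROM t WHERE x > 10) spark_gen_alias" // the "query" option, wrapped
    val columnList   = "x,y" // assumed projection; Spark derives and quotes this per dialect
    val sqlText      = prepareQuery + s"SELECT $columnList FROM $tableOrQuery"
    // sqlText: WITH t AS (SELECT x, y FROM tbl) SELECT x,y FROM (SELECT * FROM t WHERE x > 10) spark_gen_alias
    ```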
    This change also works with MSSQL's temp table syntax:
    ```
    val df = spark.read.format("jdbc")
      .option("url", jdbcUrl)
      .option("prepareQuery", "(SELECT * INTO #TempTable FROM (SELECT * FROM 
tbl WHERE x > 10) t)")
      .option("query", "SELECT * FROM #TempTable")
      .load()
    ```
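    In the temp table case both parts run as one statement on the same connection, which is why the `#TempTable` populated by the prefix is visible to the `SELECT` that follows. Roughly, the composed statement looks like this (again only a sketch, with an assumed alias and projection):
    ```
    // Illustrative only: the single statement prepared for the temp table example above.
    // The outer SELECT and the alias are assumed placeholders for what Spark generates.
    val sqlText =
      "(SELECT * INTO #TempTable FROM (SELECT * FROM tbl WHERE x > 10) t) " +
        "SELECT * FROM (SELECT * FROM #TempTable) spark_gen_alias"
    ```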
    
    ### Why are the changes needed?
    To support CTE and temp table queries with MSSQL.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, CTE and temp table queries are now supported.
    
    ### How was this patch tested?
    Added new integration UTs.
    
    Closes #36440 from peter-toth/SPARK-37259-cte-mssql.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: huaxingao <huaxin_...@apple.com>
---
 .../sql/jdbc/MsSqlServerIntegrationSuite.scala     | 55 ++++++++++++++++++++++
 docs/sql-data-sources-jdbc.md                      | 31 ++++++++++++
 .../execution/datasources/jdbc/JDBCOptions.scala   |  5 ++
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  6 ++-
 .../execution/datasources/jdbc/JDBCRelation.scala  |  2 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala |  3 +-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      |  3 +-
 7 files changed, 100 insertions(+), 5 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index e293f9a8f7b..a4e2dba5343 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -21,6 +21,7 @@ import java.math.BigDecimal
 import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
@@ -374,4 +375,58 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val filtered = df.where(col("c") === 0).collect()
     assert(filtered.length == 0)
   }
+
+  test("SPARK-37259: prepareQuery and query JDBC options") {
+    val expectedResult = Set(
+      (42, "fred"),
+      (17, "dave")
+    ).map { case (x, y) =>
+      Row(Integer.valueOf(x), String.valueOf(y))
+    }
+
+    val prepareQuery = "WITH t AS (SELECT x, y FROM tbl)"
+    val query = "SELECT * FROM t WHERE x > 10"
+    val df = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("prepareQuery", prepareQuery)
+      .option("query", query)
+      .load()
+    assert(df.collect.toSet === expectedResult)
+  }
+
+  test("SPARK-37259: prepareQuery and dbtable JDBC options") {
+    val expectedResult = Set(
+      (42, "fred"),
+      (17, "dave")
+    ).map { case (x, y) =>
+      Row(Integer.valueOf(x), String.valueOf(y))
+    }
+
+    val prepareQuery = "WITH t AS (SELECT x, y FROM tbl WHERE x > 10)"
+    val dbtable = "t"
+    val df = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("prepareQuery", prepareQuery)
+      .option("dbtable", dbtable)
+      .load()
+    assert(df.collect.toSet === expectedResult)
+  }
+
+  test("SPARK-37259: temp table prepareQuery and query JDBC options") {
+    val expectedResult = Set(
+      (42, "fred"),
+      (17, "dave")
+    ).map { case (x, y) =>
+      Row(Integer.valueOf(x), String.valueOf(y))
+    }
+
+    val prepareQuery = "(SELECT * INTO #TempTable FROM (SELECT * FROM tbl) t)"
+    val query = "SELECT * FROM #TempTable"
+    val df = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("prepareQuery", prepareQuery)
+      .option("query", query)
+      .load()
+    assert(df.collect.toSet === expectedResult)
+  }
 }
diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index e17c8f686fc..3b83bf5bc14 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -98,6 +98,37 @@ logging into the data sources.
     </td>
     <td>read/write</td>
   </tr>
+  <tr>
+    <td><code>prepareQuery</code></td>
+    <td>(none)</td>
+    <td>
+      A prefix that will form the final query together with <code>query</code>.
+      As the specified <code>query</code> will be parenthesized as a subquery in the <code>FROM</code> clause and some databases do not
+      support all clauses in subqueries, the <code>prepareQuery</code> property offers a way to run such complex queries.
+      As an example, Spark will issue a query of the following form to the JDBC source.<br><br>
+      <code>&lt;prepareQuery&gt; SELECT &lt;columns&gt; FROM (&lt;user_specified_query&gt;) spark_gen_alias</code><br><br>
+      Below are a couple of examples.<br>
+      <ol>
+         <li> MSSQL Server does not accept <code>WITH</code> clauses in subqueries, but it is possible to split such a query into <code>prepareQuery</code> and <code>query</code>:<br>
+            <code>
+               spark.read.format("jdbc")<br>
+                 .option("url", jdbcUrl)<br>
+                 .option("prepareQuery", "WITH t AS (SELECT x, y FROM 
tbl)")<br>
+                 .option("query", "SELECT * FROM t WHERE x > 10")<br>
+                 .load()
+            </code></li>
+         <li> MSSQL Server does not accept temp table clauses in subqueries, but it is possible to split such a query into <code>prepareQuery</code> and <code>query</code>:<br>
+            <code>
+               spark.read.format("jdbc")<br>
+                 .option("url", jdbcUrl)<br>
+                 .option("prepareQuery", "(SELECT * INTO #TempTable FROM 
(SELECT * FROM tbl) t)")<br>
+                 .option("query", "SELECT * FROM #TempTable")<br>
+                 .load()
+            </code></li>
+      </ol>
+    </td>
+    <td>read/write</td>
+  </tr>
 
   <tr>
     <td><code>driver</code></td>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index ad44048ce9c..df21a9820f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -222,6 +222,10 @@ class JDBCOptions(
 
   // User specified JDBC connection provider name
   val connectionProviderName = parameters.get(JDBC_CONNECTION_PROVIDER)
+
+  // The prefix that is added to the query sent to the JDBC database.
+  // This is required to support some complex queries with some JDBC databases.
+  val prepareQuery = parameters.get(JDBC_PREPARE_QUERY).map(_ + " ").getOrElse("")
 }
 
 class JdbcOptionsInWrite(
@@ -282,4 +286,5 @@ object JDBCOptions {
   val JDBC_TABLE_COMMENT = newOption("tableComment")
   val JDBC_REFRESH_KRB5_CONFIG = newOption("refreshKrb5Config")
   val JDBC_CONNECTION_PROVIDER = newOption("connectionProvider")
+  val JDBC_PREPARE_QUERY = newOption("prepareQuery")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 13d6156aed1..27d2d9c84c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -52,9 +52,10 @@ object JDBCRDD extends Logging {
    */
   def resolveTable(options: JDBCOptions): StructType = {
     val url = options.url
+    val prepareQuery = options.prepareQuery
     val table = options.tableOrQuery
     val dialect = JdbcDialects.get(url)
-    getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+    getQueryOutputSchema(prepareQuery + dialect.getSchemaQuery(table), options, dialect)
   }
 
   def getQueryOutputSchema(
@@ -304,7 +305,8 @@ private[jdbc] class JDBCRDD(
 
     val myLimitClause: String = dialect.getLimitClause(limit)
 
-    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} 
$myTableSampleClause" +
+    val sqlText = options.prepareQuery +
+      s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
       s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index ea841027607..427a494eb67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -329,6 +329,6 @@ private[sql] case class JDBCRelation(
   override def toString: String = {
     val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
     // credentials should not be included in the plan output, table information is sufficient.
-    s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo
+    s"JDBCRelation(${jdbcOptions.prepareQuery}${jdbcOptions.tableOrQuery})$partitioningInfo"
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 6c67a22b8e3..1f17d4f0b14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -239,7 +239,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
     val dialect = JdbcDialects.get(options.url)
 
     try {
-      val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery))
+      val statement =
+        conn.prepareStatement(options.prepareQuery + dialect.getSchemaQuery(options.tableOrQuery))
       try {
         statement.setQueryTimeout(options.queryTimeout)
         Some(getSchema(statement.executeQuery(), dialect))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index a09444d2a3e..3681154a1bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -103,7 +103,8 @@ case class JDBCScanBuilder(
       "GROUP BY " + compiledGroupBys.mkString(",")
     }
 
-    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM 
${jdbcOptions.tableOrQuery} " +
+    val aggQuery = jdbcOptions.prepareQuery +
+      s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
       s"WHERE 1=0 $groupByClause"
     try {
       finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)

