Repository: spark
Updated Branches:
  refs/heads/master d5007734b -> e6dc5f280


[SPARK-22729][SQL] Add getTruncateQuery to JdbcDialect

In order to enable truncate for PostgreSQL databases in Spark JDBC, the query 
used for truncating a PostgreSQL table needs to change. By default, PostgreSQL 
automatically truncates any descendant tables when a TRUNCATE query is executed 
on a parent table. As this may result in unwanted side effects, PostgreSQL 
needs its own truncate query, one that truncates only the single named table.
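
For illustration only (not part of this change), a minimal sketch of the 
behaviour over plain JDBC, assuming a reachable local PostgreSQL instance, 
placeholder credentials, and hypothetical `parent`/`child` tables:

```scala
import java.sql.DriverManager

// Hypothetical demo: requires a local PostgreSQL instance and credentials.
object TruncateOnlyDemo {
  def main(args: Array[String]): Unit = {
    val conn = DriverManager.getConnection("jdbc:postgresql://localhost/db", "user", "pass")
    val stmt = conn.createStatement()
    try {
      stmt.executeUpdate("CREATE TABLE parent (id INT)")
      stmt.executeUpdate("CREATE TABLE child () INHERITS (parent)")
      stmt.executeUpdate("INSERT INTO child VALUES (1)")
      // "TRUNCATE TABLE parent" would also empty child;
      // adding ONLY restricts the truncate to the named table.
      stmt.executeUpdate("TRUNCATE TABLE ONLY parent")
    } finally {
      stmt.close()
      conn.close()
    }
  }
}
```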

## What changes were proposed in this pull request?

Add a `getTruncateQuery` function to `JdbcDialect.scala` with a default query. 
This function is overridden for PostgreSQL to truncate only a single table. 
`isCascadingTruncateTable` is also set to false for PostgreSQL, since the new 
query no longer cascades; this enables truncate-on-overwrite for PostgreSQL.
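
Third-party dialects can override the new hook in the same way. A sketch under 
stated assumptions (the dialect, URL prefix, and truncate syntax below are 
hypothetical; runnable e.g. from spark-shell):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect for a database whose JDBC URLs start with "jdbc:mydb".
object MyDatabaseDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Return whatever truncate syntax the target database expects;
  // the base-class default is s"TRUNCATE TABLE $table".
  override def getTruncateQuery(table: String): String =
    s"TRUNCATE TABLE $table"

  // Some(false) tells Spark the truncate does not cascade, which
  // allows the truncate path to be taken on overwrite.
  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}

JdbcDialects.registerDialect(MyDatabaseDialect)
```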

## How was this patch tested?

Existing tests all pass. Added a test for `getTruncateQuery`.
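
For reference, the code path is exercised from the writer API roughly as 
follows (a sketch assuming an existing DataFrame `df`; the URL and table name 
are placeholders):

```scala
import java.util.Properties
import org.apache.spark.sql.SaveMode

val props = new Properties()  // user/password etc. as needed

// With Overwrite mode, truncate=true, and a dialect whose
// isCascadingTruncateTable is Some(false), Spark truncates the table
// using the dialect's getTruncateQuery instead of dropping it.
df.write
  .mode(SaveMode.Overwrite)
  .option("truncate", "true")
  .jdbc("jdbc:postgresql://localhost/db", "weblogs", props)
```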

Author: Daniel van der Ende <daniel.vandere...@gmail.com>

Closes #19911 from danielvdende/SPARK-22717.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6dc5f28
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6dc5f28
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6dc5f28

Branch: refs/heads/master
Commit: e6dc5f280786bdb068abb7348b787324f76c0e4a
Parents: d500773
Author: Daniel van der Ende <daniel.vandere...@gmail.com>
Authored: Tue Dec 12 10:41:37 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 12 10:41:37 2017 -0800

----------------------------------------------------------------------
 .../datasources/jdbc/JdbcRelationProvider.scala     |  2 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala  |  7 ++++---
 .../apache/spark/sql/jdbc/AggregatedDialect.scala   |  6 +++++-
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala    | 12 ++++++++++++
 .../org/apache/spark/sql/jdbc/PostgresDialect.scala | 14 ++++++++++++--
 .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++
 6 files changed, 50 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 37e7bb0..cc506e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
           case SaveMode.Overwrite =>
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
               // In this case, we should truncate table and then load.
-              truncateTable(conn, options.table)
+              truncateTable(conn, options)
               val tableSchema = JdbcUtils.getSchemaOption(conn, options)
               saveTable(df, tableSchema, isCaseSensitive, options)
             } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index bbc95df..e6dc2fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -96,12 +96,13 @@ object JdbcUtils extends Logging {
   }
 
   /**
-   * Truncates a table from the JDBC database.
+   * Truncates a table from the JDBC database without side effects.
    */
-  def truncateTable(conn: Connection, table: String): Unit = {
+  def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
     val statement = conn.createStatement
     try {
-      statement.executeUpdate(s"TRUNCATE TABLE $table")
+      statement.executeUpdate(dialect.getTruncateQuery(options.table))
     } finally {
       statement.close()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index f3bfea5..8b92c8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder}
 /**
  * AggregatedDialect can unify multiple dialects into one virtual Dialect.
  * Dialects are tried in order, and the first dialect that does not return a
- * neutral element will will.
+ * neutral element will win.
  *
  * @param dialects List of dialects.
  */
@@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
       case _ => Some(false)
     }
   }
+
+  override def getTruncateQuery(table: String): String = {
+    dialects.head.getTruncateQuery(table)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 7c38ed6..83d87a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -117,6 +117,18 @@ abstract class JdbcDialect extends Serializable {
   }
 
   /**
+   * The SQL query that should be used to truncate a table. Dialects can override this method to
+   * return a query that is suitable for a particular database. For PostgreSQL, for instance,
+   * a different query is used to prevent "TRUNCATE" affecting other tables.
+   * @param table The name of the table.
+   * @return The SQL query to use for truncating a table
+   */
+  @Since("2.3.0")
+  def getTruncateQuery(table: String): String = {
+    s"TRUNCATE TABLE $table"
+  }
+
+  /**
    * Override connection specific properties to run before a select is made.  This is in place to
    * allow dialects that need special treatment to optimize behavior.
    * @param connection The connection object

http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4f61a32..13a2035 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
+  /**
+  * The SQL query used to truncate a table. For Postgres, the default behaviour is to
+  * also truncate any descendant tables. As this is a (possibly unwanted) side-effect,
+  * the Postgres dialect adds 'ONLY' to truncate only the table in question
+  * @param table The name of the table.
+  * @return The SQL query to use for truncating a table
+  */
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE TABLE ONLY $table"
+  }
+
   override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
     super.beforeFetch(connection, properties)
 
@@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect {
     if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
       connection.setAutoCommit(false)
     }
-
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0767ca1..cb2df0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -854,6 +854,22 @@ class JDBCSuite extends SparkFunSuite
     assert(derby.getTableExistsQuery(table) == defaultQuery)
   }
 
+  test("truncate table query by jdbc dialect") {
+    val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
+    val h2 = JdbcDialects.get(url)
+    val derby = JdbcDialects.get("jdbc:derby:db")
+    val table = "weblogs"
+    val defaultQuery = s"TRUNCATE TABLE $table"
+    val postgresQuery = s"TRUNCATE TABLE ONLY $table"
+    assert(MySQL.getTruncateQuery(table) == defaultQuery)
+    assert(Postgres.getTruncateQuery(table) == postgresQuery)
+    assert(db2.getTruncateQuery(table) == defaultQuery)
+    assert(h2.getTruncateQuery(table) == defaultQuery)
+    assert(derby.getTruncateQuery(table) == defaultQuery)
+  }
+
   test("Test DataFrame.where for Date and Timestamp") {
     // Regression test for bug SPARK-11788
     val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");

