Repository: spark Updated Branches: refs/heads/master d5007734b -> e6dc5f280
[SPARK-22729][SQL] Add getTruncateQuery to JdbcDialect In order to enable truncate for PostgreSQL databases in Spark JDBC, a change is needed to the query used for truncating a PostgreSQL table. By default, PostgreSQL will automatically truncate any descendant tables if a TRUNCATE query is executed. As this may result in (unwanted) side-effects, the query used for the truncate should be specified separately for PostgreSQL, so that only a single table is truncated. ## What changes were proposed in this pull request? Add a `getTruncateQuery` function to `JdbcDialect` (in `JdbcDialects.scala`), with a default query. Override this function for PostgreSQL to truncate only a single table. Also set `isCascadingTruncateTable` to false, as this will allow truncates for PostgreSQL. ## How was this patch tested? Existing tests all pass. Added a test for `getTruncateQuery`. Author: Daniel van der Ende <daniel.vandere...@gmail.com> Closes #19911 from danielvdende/SPARK-22717. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6dc5f28 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6dc5f28 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6dc5f28 Branch: refs/heads/master Commit: e6dc5f280786bdb068abb7348b787324f76c0e4a Parents: d500773 Author: Daniel van der Ende <daniel.vandere...@gmail.com> Authored: Tue Dec 12 10:41:37 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Tue Dec 12 10:41:37 2017 -0800 ---------------------------------------------------------------------- .../datasources/jdbc/JdbcRelationProvider.scala | 2 +- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 7 ++++--- .../apache/spark/sql/jdbc/AggregatedDialect.scala | 6 +++++- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 12 ++++++++++++ .../org/apache/spark/sql/jdbc/PostgresDialect.scala | 14 ++++++++++++-- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++ 6 files changed, 50 insertions(+), 7 
deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 37e7bb0..cc506e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.Overwrite => if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) { // In this case, we should truncate table and then load. - truncateTable(conn, options.table) + truncateTable(conn, options) val tableSchema = JdbcUtils.getSchemaOption(conn, options) saveTable(df, tableSchema, isCaseSensitive, options) } else { http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index bbc95df..e6dc2fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -96,12 +96,13 @@ object JdbcUtils extends Logging { } /** - * Truncates a table from the JDBC database. 
+ * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, table: String): Unit = { + def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { - statement.executeUpdate(s"TRUNCATE TABLE $table") + statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() } http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index f3bfea5..8b92c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder} /** * AggregatedDialect can unify multiple dialects into one virtual Dialect. * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. + * neutral element will win. * * @param dialects List of dialects. 
*/ @@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect case _ => Some(false) } } + + override def getTruncateQuery(table: String): String = { + dialects.head.getTruncateQuery(table) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7c38ed6..83d87a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -117,6 +117,18 @@ abstract class JdbcDialect extends Serializable { } /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + @Since("2.3.0") + def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE $table" + } + + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. 
* @param connection The connection object http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 4f61a32..13a2035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + /** + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE ONLY $table" + } + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { super.beforeFetch(connection, properties) @@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect { if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(true) + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } http://git-wip-us.apache.org/repos/asf/spark/blob/e6dc5f28/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0767ca1..cb2df0a 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -854,6 +854,22 @@ class JDBCSuite extends SparkFunSuite assert(derby.getTableExistsQuery(table) == defaultQuery) } + test("truncate table query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table" + assert(MySQL.getTruncateQuery(table) == defaultQuery) + assert(Postgres.getTruncateQuery(table) == postgresQuery) + assert(db2.getTruncateQuery(table) == defaultQuery) + assert(h2.getTruncateQuery(table) == defaultQuery) + assert(derby.getTruncateQuery(table) == defaultQuery) + } + test("Test DataFrame.where for Date and Timestamp") { // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org