diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 1b37905543b4e..f574224b8ba96 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,5 +1,5 @@
 org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
+org.apache.spark.sql.execution.datasources.jdbc.JdbcSourceProvider
 org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index fefff68c4ba8b..51a8397dc9ee5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcSourceProvider
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -567,7 +567,7 @@ object DataSource extends Logging {

   /** A map to maintain backward compatibility in case we move data sources around. */
   private val backwardCompatibilityMap: Map[String, String] = {
-    val jdbc = classOf[JdbcRelationProvider].getCanonicalName
+    val jdbc = classOf[JdbcSourceProvider].getCanonicalName
     val json = classOf[JsonFileFormat].getCanonicalName
     val parquet = classOf[ParquetFileFormat].getCanonicalName
     val csv = classOf[CSVFileFormat].getCanonicalName
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index b4469cb538fa6..eb2601cff7465 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -190,6 +190,12 @@ class JDBCOptions(

   // An option to allow/disallow pushing down predicate into JDBC data source
   val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean
+
+  // An option for setting the maximum number of retries for a failed JDBC write.
+  val maxRetryNumber = parameters.getOrElse(JDBC_MAX_RETRY_NUMBER, "4").toInt
+
+  // An option for setting the timeout (in seconds) used when checking connection validity.
+  val checkValidTimeoutSeconds = parameters.getOrElse(JDBC_CHECK_VALID_TIMEOUT_SECONDS, "10").toInt
 }

 class JdbcOptionsInWrite(
@@ -242,4 +248,6 @@ object JDBCOptions {
   val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
   val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
   val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
+  val JDBC_MAX_RETRY_NUMBER = newOption("maxRetryNumber")
+  val JDBC_CHECK_VALID_TIMEOUT_SECONDS = newOption("checkValidTimeoutSeconds")
 }
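As a usage sketch (not part of the patch), this is how the two new options would be passed to the streaming JDBC sink added below. The URL, table name, and values are illustrative only, and `df` is assumed to be a streaming DataFrame:

    val query = df.writeStream
      .format("jdbc")
      .option("url", "jdbc:mysql://db-host:3306/metrics")  // hypothetical endpoint
      .option("dbtable", "events")                         // hypothetical target table
      .option("maxRetryNumber", "8")            // new: retry a failed batch up to 8 times
      .option("checkValidTimeoutSeconds", "5")  // new: Connection.isValid timeout in seconds
      .option("checkpointLocation", "/tmp/jdbc-sink-checkpoint")
      .start()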
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcSourceProvider.scala
similarity index 82%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcSourceProvider.scala
index e7456f9c8ed0b..2692f05834901 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcSourceProvider.scala
@@ -17,12 +17,20 @@

 package org.apache.spark.sql.execution.datasources.jdbc

+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
 import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType

-class JdbcRelationProvider extends CreatableRelationProvider
-  with RelationProvider with DataSourceRegister {
+class JdbcSourceProvider extends CreatableRelationProvider
+  with RelationProvider
+  with DataSourceRegister
+  with StreamingWriteSupportProvider {

   override def shortName(): String = "jdbc"

@@ -87,4 +95,15 @@ class JdbcRelationProvider extends CreatableRelationProvider

     createRelation(sqlContext, parameters)
   }
+
+  override def createStreamingWriteSupport(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceOptions): StreamingWriteSupport = {
+    val optionMap = options.asMap().asScala.toMap
+    // Constructed only so invalid or missing parameters fail fast at query start.
+    new JDBCOptions(optionMap)
+    new JdbcStreamingWriteSupport(schema, optionMap)
+  }
 }
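The discarded `new JDBCOptions(optionMap)` above is a deliberate fail-fast check: the JDBCOptions constructor validates its parameters with `require(...)`, so a misconfigured query dies at start-up instead of inside an executor task. A minimal sketch of the effect (the message text matches the assertion in JDBCStreamWriteSuite below):

    // Sketch only: 'url' deliberately omitted, so the constructor throws
    // IllegalArgumentException("requirement failed: Option 'url' is required."),
    // which Structured Streaming surfaces as a StreamingQueryException.
    val badOptions = Map("dbtable" -> "events")
    new JDBCOptions(badOptions)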
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcStreamingWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcStreamingWriteSupport.scala
new file mode 100644
index 0000000000000..3c7288e31b494
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcStreamingWriteSupport.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, PreparedStatement, SQLException}
+import java.util.Locale
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.makeSetter
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Dummy commit message. The DataSourceV2 framework requires a commit message implementation
+ * but we don't need to really send one.
+ */
+case object JdbcWriterCommitMessage extends WriterCommitMessage
+
+/**
+ * A [[StreamingWriteSupport]] for JDBC writing.
+ * Responsible for generating the writer factory.
+ */
+class JdbcStreamingWriteSupport(
+    schema: StructType,
+    options: Map[String, String])
+  extends StreamingWriteSupport with Logging {
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    logInfo(s"Epoch $epochId of JdbcStreamingWriteSupport committed.")
+  }
+
+  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    logInfo(s"Epoch $epochId of JdbcStreamingWriteSupport aborted.")
+  }
+
+  override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+    new JdbcStreamWriterFactory(schema, options)
+  }
+}
+
+/**
+ * A [[StreamingDataWriterFactory]] for JDBC writing.
+ * Will be serialized and sent to executors to generate the per-task data writers.
+ */
+case class JdbcStreamWriterFactory(
+    schema: StructType,
+    options: Map[String, String])
+  extends StreamingDataWriterFactory {
+
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] = {
+    JdbcStreamDataWriter(schema, options)
+  }
+}
+
+/**
+ * A [[DataWriter]] for JDBC writing.
+ * One data writer will be created in each partition to process incoming rows.
+ */
+case class JdbcStreamDataWriter(
+    schema: StructType,
+    options: Map[String, String])
+  extends DataWriter[InternalRow] with Logging {
+
+  private val jdbcOptions = new JDBCOptions(options)
+
+  // Use a local buffer for batched writes to JDBC.
+  private val batchSize = jdbcOptions.batchSize
+  private val localBuffer = new ArrayBuffer[Row](batchSize)
+  private val maxRetryNum = jdbcOptions.maxRetryNumber
+
+  // The first element is the column name list, the second is the placeholder string.
+  private val sqlPart: (String, String) = {
+    val columnListBuilder = new StringBuilder()
+    val holderListBuilder = new StringBuilder()
+    schema.fields.foreach { field =>
+      columnListBuilder.append(",").append(field.name)
+      holderListBuilder.append(",?")
+    }
+    assert(columnListBuilder.nonEmpty, "Data schema must not be empty!")
+    (columnListBuilder.substring(1), holderListBuilder.substring(1))
+  }
+
+  private val sql: String = s"REPLACE INTO ${jdbcOptions.tableOrQuery}" +
+    s" ( ${sqlPart._1} ) values ( ${sqlPart._2} )"
+
+  logTrace(s"SQL statement for JDBC writing: $sql")
+
+  private val dialect = JdbcDialects.get(jdbcOptions.url)
+
+  // Used for batched writing.
+  private var conn: Connection = _
+  private var stmt: PreparedStatement = _
+
+  checkSchema()
+
+  private val setters = schema.fields.map { f =>
+    // Make sure the connection is initialized before building the setters.
+    resetConnectionAndStmt()
+    makeSetter(conn, dialect, f.dataType)
+  }
+  private val numFields = schema.fields.length
+  private val nullTypes = schema.fields.map(f =>
+    JdbcUtils.getJdbcType(f.dataType, dialect).jdbcNullType)
+
+  // For converting InternalRow to Row.
+  private val encoder = RowEncoder.apply(schema).resolveAndBind(schema.toAttributes)
+
+  /**
+   * Check the data schema against the table schema.
+   * The data schema must equal the table schema or be a subset of it, and a column
+   * with the same name must have the same type in both schemas.
+   */
+  private def checkSchema(): Unit = {
+    resetConnectionAndStmt()
+    val tableSchemaMap = JdbcUtils.getSchemaOption(conn, jdbcOptions) match {
+      case Some(tableSchema) =>
+        logInfo(s"Retrieved schema $tableSchema of table ${jdbcOptions.tableOrQuery}")
+        tableSchema.fields.map(field => field.name.toLowerCase(Locale.ROOT) -> field).toMap
+      case _ => throw new IllegalStateException(
+        s"Schema of table ${jdbcOptions.tableOrQuery} is not defined, make sure the table exists!")
+    }
+    schema.foreach { field =>
+      val tableColumn = tableSchemaMap.get(field.name.toLowerCase(Locale.ROOT))
+      assert(tableColumn.isDefined,
+        s"Data column ${field.name} cannot be found in table ${jdbcOptions.tableOrQuery}")
+      assert(field.dataType == tableColumn.get.dataType,
+        s"Type of data column ${field.name} does not match its type in table " +
+          s"${jdbcOptions.tableOrQuery}")
+    }
+  }
+
+  // Cache the connection locally to avoid creating a new one for every batch.
+  private def resetConnectionAndStmt(): Unit = {
+    if (conn == null || !conn.isValid(jdbcOptions.checkValidTimeoutSeconds)) {
+      conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
+      stmt = conn.prepareStatement(sql)
+      logDebug("Current connection was invalid, created a new one.")
+    } else {
+      logDebug("Current connection is valid, reuse it.")
+    }
+  }
+
+  override def write(record: InternalRow): Unit = {
+    localBuffer.append(encoder.fromRow(record))
+    if (localBuffer.size == batchSize) {
+      logDebug(s"Local buffer is full with size $batchSize, flush it to JDBC and reset it.")
+      doWriteAndResetBuffer()
+    }
+  }
+
+  // Batch write to JDBC, retrying on SQLException.
+  private def doWriteAndResetBuffer(): Unit = {
+    var tryNum = 0
+    val size = localBuffer.size
+    while (tryNum <= maxRetryNum) {
+      try {
+        val start = System.currentTimeMillis()
+        val iterator = localBuffer.iterator
+        while (iterator.hasNext) {
+          val row = iterator.next()
+          var i = 0
+          while (i < numFields) {
+            if (row.isNullAt(i)) {
+              stmt.setNull(i + 1, nullTypes(i))
+            } else {
+              setters(i).apply(stmt, row, i)
+            }
+            i += 1
+          }
+          stmt.addBatch()
+        }
+        stmt.executeBatch()
+        localBuffer.clear()
+        logDebug(s"Successfully wrote $size records " +
+          s"(retry count $tryNum) in ${System.currentTimeMillis() - start} ms")
+        // Exit the retry loop on success.
+        tryNum = maxRetryNum + 1
+      } catch {
+        case e: SQLException =>
+          if (tryNum < maxRetryNum) {
+            tryNum += 1
+            // Recreate the connection if it became invalid, and drop any half-added
+            // batch so the retry does not write duplicate rows.
+            resetConnectionAndStmt()
+            stmt.clearBatch()
+            logWarning(s"Failed to write $size records, retry number $tryNum!", e)
+          } else {
+            logError(s"Failed to write $size records, " +
+              s"reached max retry number $maxRetryNum, abort writing!")
+            throw e
+          }
+        case e: Throwable =>
+          logError(s"Failed to write $size records, not retryable, abort writing!", e)
+          throw e
+      }
+    }
+  }
+
+  private def doWriteAndClose(): Unit = {
+    if (localBuffer.nonEmpty) {
+      doWriteAndResetBuffer()
+    }
+    if (conn != null) {
+      try {
+        conn.close()
+      } catch {
+        case e: Throwable => logError("Failed to close connection.", e)
+      }
+    }
+  }
+
+  override def commit(): WriterCommitMessage = {
+    doWriteAndClose()
+    JdbcWriterCommitMessage
+  }
+
+  override def abort(): Unit = {
+    logInfo(s"Abort writing with ${localBuffer.size} records in local buffer.")
+  }
+}
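To make the generated statement concrete, here is a small self-contained sketch (not part of the patch) that mirrors the `sqlPart` construction for the two-column schema used by the test suite below. Note that `REPLACE INTO` is MySQL-style upsert syntax rather than standard SQL, which is why a primary key on the target table makes re-writing the same rows idempotent:

    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    object SqlPartDemo {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(
          StructField("name", StringType),
          StructField("value", LongType)))
        // Mirrors JdbcStreamDataWriter.sqlPart: comma-joined column names and "?" placeholders.
        val columns = schema.fields.map(_.name).mkString(",")
        val holders = schema.fields.map(_ => "?").mkString(",")
        val sql = s"REPLACE INTO stream_test_table ( $columns ) values ( $holders )"
        println(sql)  // REPLACE INTO stream_test_table ( name,value ) values ( ?,? )
      }
    }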
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index edea549748b47..5c2acf9bdf8ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -178,7 +178,7 @@ object JdbcUtils extends Logging {
     }
   }

-  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+  private[jdbc] def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
     dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
       throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}"))
   }
@@ -515,7 +515,7 @@ object JdbcUtils extends Logging {
   // in the SQL statement and also used for the value in `Row`.
   private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

-  private def makeSetter(
+  private[jdbc] def makeSetter(
       conn: Connection,
       dialect: JdbcDialect,
       dataType: DataType): JDBCValueSetter = dataType match {
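Widening `getJdbcType` and `makeSetter` to `private[jdbc]` lets the new streaming writer reuse the dialect-aware type mapping. A rough illustration (not part of the patch) of what the writer gets out of `getJdbcType`, callable only from inside the `jdbc` package; `BIGINT` is the common-case mapping for `LongType`:

    import org.apache.spark.sql.jdbc.JdbcDialects
    import org.apache.spark.sql.types.LongType

    // Sketch only, assuming we are inside org.apache.spark.sql.execution.datasources.jdbc:
    val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
    val jdbcType = JdbcUtils.getJdbcType(LongType, dialect)
    // jdbcType.databaseTypeDefinition is the DDL type name (e.g. "BIGINT");
    // jdbcType.jdbcNullType is what feeds PreparedStatement.setNull for null columns.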
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCStreamWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCStreamWriteSuite.scala
new file mode 100644
index 0000000000000..b42f47f306061
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCStreamWriteSuite.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.DriverManager
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
+import org.apache.spark.util.Utils
+
+private case class TestData(name: String, value: Long)
+
+class JDBCStreamWriteSuite extends StreamTest with BeforeAndAfter {
+  import testImplicits._
+
+  val url = "jdbc:h2:mem:testdb"
+  val jdbcTableName = "stream_test_table"
+  val driverClassName = "org.h2.Driver"
+  val createTableSql = s"""
+    |CREATE TABLE ${jdbcTableName}(
+    |  name VARCHAR(32),
+    |  value LONG,
+    |  PRIMARY KEY (name)
+    |)""".stripMargin
+
+  var conn: java.sql.Connection = null
+
+  val testH2Dialect = new JdbcDialect {
+    override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
+    override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+  }
+
+  before {
+    Utils.classForName(driverClassName)
+    conn = DriverManager.getConnection(url)
+    conn.prepareStatement(createTableSql).executeUpdate()
+  }
+
+  after {
+    conn.close()
+  }
+
+  test("Basic Write") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Long]
+      val query = input.toDF().map { row =>
+        val value = row.getLong(0)
+        TestData(s"name_$value", value)
+      }.writeStream
+        .format("jdbc")
+        .option(JDBCOptions.JDBC_URL, url)
+        .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+        .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .start()
+      try {
+        input.addData(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+        query.processAllAvailable()
+      } finally {
+        query.stop()
+      }
+    }
+    val result = conn
+      .prepareStatement(s"select count(*) as count from $jdbcTableName")
+      .executeQuery()
+    assert(result.next())
+    assert(result.getInt("count") == 10)
+  }
+
+  test("Write sub columns") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Long]
+      val query = input.toDF().map { row =>
+        val value = row.getLong(0)
+        TestData(s"name_$value", value)
+      }.select("name").writeStream  // write only the `name` column
+        .format("jdbc")
+        .option(JDBCOptions.JDBC_URL, url)
+        .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+        .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .start()
+      try {
+        input.addData(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+        query.processAllAvailable()
+      } finally {
+        query.stop()
+      }
+    }
+    val result = conn
+      .prepareStatement(s"select count(*) as count from $jdbcTableName")
+      .executeQuery()
+    assert(result.next())
+    assert(result.getInt("count") == 10)
+  }
+
+  test("Write same data") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Long]
+      val query = input.toDF().map { row =>
+        val value = row.getLong(0)
+        TestData(s"name_$value", value)
+      }.writeStream
+        .format("jdbc")
+        .option(JDBCOptions.JDBC_URL, url)
+        .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+        .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .start()
+      try {
+        input.addData(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+        query.processAllAvailable()
+      } finally {
+        query.stop()
+      }
+    }
+    val result = conn
+      .prepareStatement(s"select count(*) as count from $jdbcTableName")
+      .executeQuery()
+    assert(result.next())
+    assert(result.getInt("count") == 1)
+  }
+
+  test("Write without required parameter") {
+    // Without the jdbc url.
+    val thrown = intercept[StreamingQueryException] {
+      withTempDir { checkpointDir =>
+        val input = MemoryStream[Long]
+        val query = input.toDF().map { row =>
+          val value = row.getLong(0)
+          TestData(s"name_$value", value)
+        }.writeStream
+          .format("jdbc")
+          .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+          .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+          .option("checkpointLocation", checkpointDir.getCanonicalPath)
+          .start()
+        try {
+          input.addData(1, 2, 3)
+          query.processAllAvailable()
+        } finally {
+          query.stop()
+        }
+      }
+    }
+    assert(thrown.getMessage.contains("requirement failed: Option 'url' is required."))
+    // Without the table name.
+    val thrown2 = intercept[StreamingQueryException] {
+      withTempDir { checkpointDir =>
+        val input = MemoryStream[Long]
+        val query = input.toDF().map { row =>
+          val value = row.getLong(0)
+          TestData(s"name_$value", value)
+        }.writeStream
+          .format("jdbc")
+          .option(JDBCOptions.JDBC_URL, url)
+          .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+          .option("checkpointLocation", checkpointDir.getCanonicalPath)
+          .start()
+        try {
+          input.addData(1, 2, 3)
+          query.processAllAvailable()
+        } finally {
+          query.stop()
+        }
+      }
+    }
+    assert(thrown2.getMessage.contains("Option 'dbtable' or 'query' is required."))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 0aa67bf1b0d48..a894aa8ba2f17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -34,13 +34,13 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
   test("jdbc") {
     assert(
       getProvidingClass("jdbc") ===
-        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
+        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcSourceProvider])
     assert(
       getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") ===
-        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
+        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcSourceProvider])
     assert(
       getProvidingClass("org.apache.spark.sql.jdbc") ===
-        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
+        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcSourceProvider])
   }

   test("json") {