This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2b0e841a3534 [SPARK-47496][SQL] Java SPI Support for dynamic JDBC dialect registering
2b0e841a3534 is described below

commit 2b0e841a35343343c82e8ca15225014b64d8c59f
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu Mar 21 19:34:28 2024 +0800

    [SPARK-47496][SQL] Java SPI Support for dynamic JDBC dialect registering

    ### What changes were proposed in this pull request?

    This PR brings Java ServiceProvider Interface (SPI) support for dynamic
    JDBC dialect registration. A custom JDBC dialect can now be registered
    automatically instead of through a manual call to
    JdbcDialects.registerDialect.

    ### Why are the changes needed?

    For pure SQL and other non-Java API users, it is difficult to register
    a custom JDBC dialect. With this patch, registration happens
    automatically whenever the jar containing the dialect class is visible
    to the Spark classloader.

    ### Does this PR introduce _any_ user-facing change?

    Yes, but mostly for third-party developers.

    ### How was this patch tested?

    New tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #45626 from yaooqinn/SPARK-47496.

    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 project/MimaExcludes.scala                          | 10 +++-
 .../services/org.apache.spark.sql.jdbc.JdbcDialect  | 29 ++++++++++
 .../org/apache/spark/sql/jdbc/DB2Dialect.scala      |  2 +-
 .../apache/spark/sql/jdbc/DatabricksDialect.scala   |  2 +-
 .../org/apache/spark/sql/jdbc/DerbyDialect.scala    |  2 +-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala       |  2 +-
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala    | 20 +++----
 .../apache/spark/sql/jdbc/MsSqlServerDialect.scala  | 20 +++----
 .../org/apache/spark/sql/jdbc/MySQLDialect.scala    |  2 +-
 .../org/apache/spark/sql/jdbc/OracleDialect.scala   | 18 ++++---
 .../apache/spark/sql/jdbc/PostgresDialect.scala     |  2 +-
 .../apache/spark/sql/jdbc/SnowflakeDialect.scala    |  2 +-
 .../apache/spark/sql/jdbc/TeradataDialect.scala     |  2 +-
 .../services/org.apache.spark.sql.jdbc.JdbcDialect  | 20 +++++++
 .../spark/sql/jdbc/DummyDatabaseDialect.scala}      | 18 +------
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala       | 61 +++++++++++++---------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala     | 52 +++++++++---------
 .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala  |  4 +-
 18 files changed, 163 insertions(+), 105 deletions(-)
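For orientation before the diff: under this SPI mechanism, a third-party dialect jar needs only two pieces, a concrete JdbcDialect subclass with a public no-arg constructor and a META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect resource naming that class. The package, class name, and jdbc:mydb URL prefix below are hypothetical; this is a minimal sketch of the contract, not code from this commit.

    // com/example/jdbc/MyCustomDialect.scala, in a hypothetical third-party jar
    package com.example.jdbc

    import org.apache.spark.sql.jdbc.JdbcDialect

    // ServiceLoader instantiates providers reflectively, so the class must be
    // concrete with a public no-arg constructor; canHandle is the only
    // abstract member of JdbcDialect that has to be overridden.
    class MyCustomDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
    }

    # src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
    com.example.jdbc.MyCustomDialect

Once such a jar is visible to the Spark classloader (for example via --jars), JdbcDialects.get("jdbc:mydb://...") resolves to MyCustomDialect without any manual registerDialect call; the DummyDatabaseDialect test added below exercises exactly this path.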
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 225a13cd3537..630dd1d77cc7 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -82,7 +82,15 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.scoreLabelsWeight"),
     // SPARK-46938: Javax -> Jakarta namespace change.
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.ProxyRedirectHandler$ResponseWrapper"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ui.ProxyRedirectHandler#ResponseWrapper.this")
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ui.ProxyRedirectHandler#ResponseWrapper.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.DB2Dialect#DB2SQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.DB2Dialect#DB2SQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MsSqlServerDialect#MsSqlServerSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MsSqlServerDialect#MsSqlServerSQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MySQLDialect#MySQLSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MySQLDialect#MySQLSQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.OracleDialect#OracleSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.OracleDialect#OracleSQLQueryBuilder.this")
   )
 
   // Default exclude rules
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
new file mode 100644
index 000000000000..0b9dda2d14f2
--- /dev/null
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+org.apache.spark.sql.jdbc.MySQLDialect
+org.apache.spark.sql.jdbc.PostgresDialect
+org.apache.spark.sql.jdbc.DB2Dialect
+org.apache.spark.sql.jdbc.MsSqlServerDialect
+org.apache.spark.sql.jdbc.DerbyDialect
+org.apache.spark.sql.jdbc.OracleDialect
+org.apache.spark.sql.jdbc.TeradataDialect
+org.apache.spark.sql.jdbc.H2Dialect
+org.apache.spark.sql.jdbc.SnowflakeDialect
+org.apache.spark.sql.jdbc.DatabricksDialect
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 62c31b1c4c5d..31a7c783ba60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._
 
-private object DB2Dialect extends JdbcDialect {
+private case class DB2Dialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
index c905374c1678..54b8c2622827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
 
-private case object DatabricksDialect extends JdbcDialect {
+private case class DatabricksDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = {
     url.startsWith("jdbc:databricks")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 545cbf265bb0..36af0e6aeaf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors
 import org.apache.spark.sql.types._
 
 
-private object DerbyDialect extends JdbcDialect {
+private case class DerbyDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index f4a1650b3e8c..ebfc6093dc16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, N
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType, TimestampType}
 
-private[sql] object H2Dialect extends JdbcDialect {
+private[sql] case class H2Dialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
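A side note on the object-to-case-class conversions above: java.util.ServiceLoader can only instantiate providers through a public no-argument constructor, which a Scala object (a singleton with a private constructor) does not expose. Converting each dialect to a no-arg case class keeps comparisons working, because case class equality is structural rather than referential; that is what lets the updated tests below assert against freshly constructed instances. A minimal sketch of the property, with DemoDialect as a hypothetical stand-in:

    case class DemoDialect()  // stand-in for DB2Dialect(), H2Dialect(), etc.

    // Two separately constructed instances compare equal, so assertions of
    // the form JdbcDialects.get(url) === DB2Dialect() continue to hold.
    assert(DemoDialect() == DemoDialect())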
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 7d2812d48cae..845161c81ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.sql.{Connection, Date, Driver, Statement, Timestamp}
 import java.time.{Instant, LocalDate, LocalDateTime}
 import java.util
+import java.util.ServiceLoader
 
 import scala.collection.mutable.ArrayBuilder
 import scala.util.control.NonFatal
@@ -46,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProv
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -825,16 +827,14 @@ object JdbcDialects {
 
   private[this] var dialects = List[JdbcDialect]()
 
-  registerDialect(MySQLDialect)
-  registerDialect(PostgresDialect)
-  registerDialect(DB2Dialect)
-  registerDialect(MsSqlServerDialect)
-  registerDialect(DerbyDialect)
-  registerDialect(OracleDialect)
-  registerDialect(TeradataDialect)
-  registerDialect(H2Dialect)
-  registerDialect(SnowflakeDialect)
-  registerDialect(DatabricksDialect)
+  private def registerDialects(): Unit = {
+    val loader = ServiceLoader.load(classOf[JdbcDialect], Utils.getContextOrSparkClassLoader)
+    val iter = loader.iterator()
+    while (iter.hasNext) {
+      registerDialect(iter.next())
+    }
+  }
+  registerDialects()
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
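The registerDialects() loop above is plain ServiceLoader iteration: load scans every META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect resource visible to the given classloader and lazily instantiates each listed class as the iterator advances. A roughly equivalent sketch in collection style, assuming the same service type and classloader helper:

    import java.util.ServiceLoader
    import scala.jdk.CollectionConverters._

    // Discover and register every JdbcDialect implementation on the classpath;
    // the built-in dialects are found through the services file added above.
    ServiceLoader
      .load(classOf[JdbcDialect], Utils.getContextOrSparkClassLoader)
      .asScala
      .foreach(JdbcDialects.registerDialect)

Passing Utils.getContextOrSparkClassLoader rather than relying on the defining classloader is what allows dialects shipped in user jars to be discovered, not just the ones bundled with Spark.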
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index aaee6be24e61..1b6dc1af9ec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -29,18 +29,11 @@ import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, Sor
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY}
 import org.apache.spark.sql.types._
 
-private object MsSqlServerDialect extends JdbcDialect {
-
-  // Special JDBC types in Microsoft SQL Server.
-  // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java
-  private object SpecificTypes {
-    val GEOMETRY = -157
-    val GEOGRAPHY = -158
-  }
-
+private case class MsSqlServerDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")
@@ -113,7 +106,7 @@ private object MsSqlServerDialect extends JdbcDialect {
       // Reference doc: https://learn.microsoft.com/en-us/sql/t-sql/data-types
       case java.sql.Types.SMALLINT | java.sql.Types.TINYINT => Some(ShortType)
       case java.sql.Types.REAL => Some(FloatType)
-      case SpecificTypes.GEOMETRY | SpecificTypes.GEOGRAPHY => Some(BinaryType)
+      case GEOMETRY | GEOGRAPHY => Some(BinaryType)
       case _ => None
     }
   }
@@ -226,3 +219,10 @@ private object MsSqlServerDialect extends JdbcDialect {
 
   override def supportsLimit: Boolean = true
 }
+
+private object MsSqlServerDialect {
+  // Special JDBC types in Microsoft SQL Server.
+  // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java
+  final val GEOMETRY = -157
+  final val GEOGRAPHY = -158
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index a245458a5cb4..292e3ca2d5e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types._
 
-private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
+private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
 
   override def canHandle(url : String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 544c0197dec9..a9c246c93879 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -25,17 +25,11 @@ import scala.util.control.NonFatal
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.jdbc.OracleDialect._
 import org.apache.spark.sql.types._
 
-private case object OracleDialect extends JdbcDialect {
-  private[jdbc] val BINARY_FLOAT = 100
-  private[jdbc] val BINARY_DOUBLE = 101
-  private[jdbc] val TIMESTAMP_TZ = -101
-  // oracle.jdbc.OracleType.TIMESTAMP_WITH_LOCAL_TIME_ZONE
-  private[jdbc] val TIMESTAMP_LTZ = -102
-
-
+private case class OracleDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
@@ -230,3 +224,11 @@ private case object OracleDialect extends JdbcDialect {
 
   override def supportsOffset: Boolean = true
 }
+
+private[jdbc] object OracleDialect {
+  final val BINARY_FLOAT = 100
+  final val BINARY_DOUBLE = 101
+  final val TIMESTAMP_TZ = -101
+  // oracle.jdbc.OracleType.TIMESTAMP_WITH_LOCAL_TIME_ZONE
+  final val TIMESTAMP_LTZ = -102
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index c9737867d3e0..5c949b28ba7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
 
 
-private object PostgresDialect extends JdbcDialect with SQLConfHelper {
+private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
index d8a8fe6ba4a9..276364d5d89e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types.{BooleanType, DataType}
 
-private case object SnowflakeDialect extends JdbcDialect {
+private case class SnowflakeDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:snowflake")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 0f0812bdaeb9..7acd22a3f10b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.types._
 
 
-private case object TeradataDialect extends JdbcDialect {
+private case class TeradataDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata")
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
new file mode 100644
index 000000000000..ce96a578e50c
--- /dev/null
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+org.apache.spark.sql.jdbc.DummyDatabaseDialect
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
similarity index 56%
copy from sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
index d8a8fe6ba4a9..a8bca85dcb65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
@@ -17,20 +17,6 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.util.Locale
-
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.types.{BooleanType, DataType}
-
-private case object SnowflakeDialect extends JdbcDialect {
-  override def canHandle(url: String): Boolean =
-    url.toLowerCase(Locale.ROOT).startsWith("jdbc:snowflake")
-
-  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
-    case BooleanType =>
-      // By default, BOOLEAN is mapped to BIT(1).
-      // but Snowflake does not have a BIT type. It uses BOOLEAN instead.
-      Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
-    case _ => JdbcUtils.getCommonJDBCType(dt)
-  }
+class DummyDatabaseDialect extends JdbcDialect {
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:dummy")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index a2dac5a9e1e9..e2bdd8aee97d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -786,12 +786,12 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Default jdbc dialect registration") {
-    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
-    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
-    assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect)
-    assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:derby:db") == DerbyDialect)
-    assert(JdbcDialects.get("test.invalid") == NoopDialect)
+    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") === PostgresDialect())
+    assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") === MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:derby:db") === DerbyDialect())
+    assert(JdbcDialects.get("test.invalid") === NoopDialect)
   }
 
   test("quote column names by jdbc dialect") {
@@ -846,13 +846,13 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Dialect unregister") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(H2Dialect())
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       JdbcDialects.unregisterDialect(testH2Dialect)
       assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
     } finally {
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(H2Dialect())
     }
   }
 
@@ -997,7 +997,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:h2://host:port", "dbtable" -> "test"))
     assert(
-      OracleDialect
+      OracleDialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1053,7 +1053,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
    // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:h2://host:port", "dbtable" -> "test"))
     assert(
-      MsSqlServerDialect
+      MsSqlServerDialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1066,7 +1066,7 @@
     // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:db2://host:port", "dbtable" -> "test"))
     assert(
-      DB2Dialect
+      DB2Dialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1938,20 +1938,20 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-28552: Case-insensitive database URLs in JdbcDialect") {
-    assert(JdbcDialects.get("jdbc:mysql://localhost/db") === MySQLDialect)
-    assert(JdbcDialects.get("jdbc:MySQL://localhost/db") === MySQLDialect)
-    assert(JdbcDialects.get("jdbc:postgresql://localhost/db") === PostgresDialect)
-    assert(JdbcDialects.get("jdbc:postGresql://localhost/db") === PostgresDialect)
-    assert(JdbcDialects.get("jdbc:db2://localhost/db") === DB2Dialect)
-    assert(JdbcDialects.get("jdbc:DB2://localhost/db") === DB2Dialect)
-    assert(JdbcDialects.get("jdbc:sqlserver://localhost/db") === MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:sqlServer://localhost/db") === MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:derby://localhost/db") === DerbyDialect)
-    assert(JdbcDialects.get("jdbc:derBy://localhost/db") === DerbyDialect)
-    assert(JdbcDialects.get("jdbc:oracle://localhost/db") === OracleDialect)
-    assert(JdbcDialects.get("jdbc:Oracle://localhost/db") === OracleDialect)
-    assert(JdbcDialects.get("jdbc:teradata://localhost/db") === TeradataDialect)
-    assert(JdbcDialects.get("jdbc:Teradata://localhost/db") === TeradataDialect)
+    assert(JdbcDialects.get("jdbc:mysql://localhost/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:MySQL://localhost/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:postgresql://localhost/db") === PostgresDialect())
+    assert(JdbcDialects.get("jdbc:postGresql://localhost/db") === PostgresDialect())
+    assert(JdbcDialects.get("jdbc:db2://localhost/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:DB2://localhost/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:sqlserver://localhost/db") === MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:sqlServer://localhost/db") === MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:derby://localhost/db") === DerbyDialect())
+    assert(JdbcDialects.get("jdbc:derBy://localhost/db") === DerbyDialect())
+    assert(JdbcDialects.get("jdbc:oracle://localhost/db") === OracleDialect())
+    assert(JdbcDialects.get("jdbc:Oracle://localhost/db") === OracleDialect())
+    assert(JdbcDialects.get("jdbc:teradata://localhost/db") === TeradataDialect())
+    assert(JdbcDialects.get("jdbc:Teradata://localhost/db") === TeradataDialect())
   }
 
   test("SQLContext.jdbc (deprecated)") {
@@ -2099,7 +2099,8 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-45139: DatabricksDialect url handling") {
-    assert(JdbcDialects.get("jdbc:databricks://account.cloud.databricks.com") == DatabricksDialect)
+    assert(JdbcDialects.get("jdbc:databricks://account.cloud.databricks.com") ===
+      DatabricksDialect())
   }
 
   test("SPARK-45139: DatabricksDialect catalyst type mapping") {
@@ -2154,4 +2155,12 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     val expected = Map("percentile_approx_val" -> 49)
     assert(namedObservation.get === expected)
   }
+
+  test("SPARK-47496: ServiceLoader support for JDBC dialects") {
+    var dialect = JdbcDialects.get("jdbc:dummy:dummy_host:dummy_port/dummy_db")
+    assert(dialect.isInstanceOf[DummyDatabaseDialect])
+    JdbcDialects.unregisterDialect(dialect)
+    dialect = JdbcDialects.get("jdbc:dummy:dummy_host:dummy_port/dummy_db")
+    assert(dialect === NoopDialect)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 7bae2d77a161..1b3672cdba5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -52,8 +52,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++
     Array.fill(15)(0.toByte)
 
+  private val h2Dialect = JdbcDialects.get(url).asInstanceOf[H2Dialect]
+
   val testH2Dialect = new JdbcDialect {
-    override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
+    val h2 = JdbcDialects.get(url).asInstanceOf[H2Dialect]
+
+    override def canHandle(url: String): Boolean = h2.canHandle(url)
 
     override def supportsLimit: Boolean = false
 
@@ -102,7 +106,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       }
     }
 
-    override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions
+    override def functions: Seq[(String, UnboundFunction)] = h2.functions
   }
 
   case object CharLength extends ScalarFunction[Int] {
@@ -225,15 +229,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       stmt.setBytes(2, testBytes)
       stmt.executeUpdate()
     }
-    H2Dialect.registerFunction("my_avg", IntegralAverage)
-    H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
-    H2Dialect.registerFunction("my_strlen_magic", StrLen(CharLengthWithMagicMethod))
-    H2Dialect.registerFunction(
+    h2Dialect.registerFunction("my_avg", IntegralAverage)
+    h2Dialect.registerFunction("my_strlen", StrLen(CharLength))
+    h2Dialect.registerFunction("my_strlen_magic", StrLen(CharLengthWithMagicMethod))
+    h2Dialect.registerFunction(
       "my_strlen_static_magic", StrLen(new JavaStrLenStaticMagic()))
   }
 
   override def afterAll(): Unit = {
-    H2Dialect.clearFunctions()
+    h2Dialect.clearFunctions()
     Utils.deleteRecursively(tempDir)
     super.afterAll()
   }
@@ -340,7 +344,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkPushedInfo(df5, "PushedFilters: []")
     checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
 
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df6 = spark.read.table("h2.test.employee")
@@ -350,7 +354,7 @@
      checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -437,7 +441,7 @@
     checkPushedInfo(df6, "PushedFilters: []")
     checkAnswer(df6, Seq(Row(10000.00, 1300.0, "dav"), Row(9000.00, 1200.0, "cat")))
 
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df7 = spark.read
@@ -450,7 +454,7 @@
       checkAnswer(df7, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -1590,7 +1594,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
with UDF") { - JdbcDialects.unregisterDialect(H2Dialect) + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 2") @@ -1610,12 +1614,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } test("scan with filter push-down with UDF that has magic method") { - JdbcDialects.unregisterDialect(H2Dialect) + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_magic(name) > 2") @@ -1636,12 +1640,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } test("scan with filter push-down with UDF that has static magic method") { - JdbcDialects.unregisterDialect(H2Dialect) + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_static_magic(name) > 2") @@ -1662,7 +1666,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } @@ -2872,8 +2876,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("register dialect specific functions") { - JdbcDialects.unregisterDialect(H2Dialect) + test("register h2Dialect specific functions") { + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) val df = sql("SELECT h2.my_avg(id) FROM h2.test.people") @@ -2905,12 +2909,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel stop = 20)) } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } test("scan with aggregate push-down: complete push-down UDAF") { - JdbcDialects.unregisterDialect(H2Dialect) + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people") @@ -2959,7 +2963,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } @@ -3006,7 +3010,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("IDENTIFIER_TOO_MANY_NAME_PARTS: " + "jdbc function doesn't support identifiers consisting of more than 2 parts") { - JdbcDialects.unregisterDialect(H2Dialect) + JdbcDialects.unregisterDialect(h2Dialect) try { JdbcDialects.registerDialect(testH2Dialect) checkError( @@ -3019,7 +3023,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel ) } finally { JdbcDialects.unregisterDialect(testH2Dialect) - JdbcDialects.registerDialect(H2Dialect) + JdbcDialects.registerDialect(h2Dialect) } } diff --git 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index f904d0e3d3c8..0d9dc2f76faf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -206,7 +206,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
   }
 
   test("Truncate") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(H2Dialect())
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -231,7 +231,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
         "Some(StructType(StructField(name,StringType,true),StructField(id,IntegerType,true)))"))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(H2Dialect())
     }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org