Repository: spark
Updated Branches:
  refs/heads/master 8c7e19a37 -> 17edfec59


[SPARK-20427][SQL] Read JDBC table use custom schema

## What changes were proposed in this pull request?

The schema auto-generated for an Oracle table is sometimes not what we expect:

- `number(1)` is auto-mapped to BooleanType, which is sometimes not what we expect, per 
[SPARK-20921](https://issues.apache.org/jira/browse/SPARK-20921).
- `number` is auto-mapped to Decimal(38,10), which cannot represent larger values and fails at read time, per 
[SPARK-20427](https://issues.apache.org/jira/browse/SPARK-20427).

This PR fixes the issue by letting users specify a custom schema, as follows:
```scala
val props = new Properties()
props.put("customSchema", "ID decimal(38, 0), N1 int, N2 boolean")
val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
dfRead.show()
```
or
```sql
CREATE TEMPORARY VIEW tableWithCustomSchema
USING org.apache.spark.sql.jdbc
OPTIONS (url '$jdbcUrl', dbTable 'tableWithCustomSchema',
  customSchema 'ID decimal(38, 0), N1 int, N2 boolean')
```
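
The same option can also be passed through the generic DataFrameReader API; a minimal sketch (assuming the same `jdbcUrl` and table name as above):
```scala
// Read the same table, overriding the default type mapping via the customSchema option.
val dfRead2 = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("dbtable", "tableWithCustomSchema")
  .option("customSchema", "ID decimal(38, 0), N1 int, N2 boolean")
  .load()
```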

## How was this patch tested?

unit tests

Author: Yuming Wang <wgy...@gmail.com>

Closes #18266 from wangyum/SPARK-20427.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/17edfec5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/17edfec5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/17edfec5

Branch: refs/heads/master
Commit: 17edfec59de8d8680f7450b4d07c147c086c105a
Parents: 8c7e19a
Author: Yuming Wang <wgy...@gmail.com>
Authored: Wed Sep 13 16:34:17 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Wed Sep 13 16:34:17 2017 -0700

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   |  9 +-
 examples/src/main/python/sql/datasource.py      | 10 +++
 .../examples/sql/SQLDataSourceExample.scala     |  4 +
 .../spark/sql/jdbc/OracleIntegrationSuite.scala | 47 +++++++++--
 .../datasources/jdbc/JDBCOptions.scala          |  4 +
 .../execution/datasources/jdbc/JDBCRDD.scala    |  2 +-
 .../datasources/jdbc/JDBCRelation.scala         |  9 +-
 .../execution/datasources/jdbc/JdbcUtils.scala  | 30 ++++++-
 .../datasources/jdbc/JdbcUtilsSuite.scala       | 87 ++++++++++++++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 30 +++++++
 10 files changed, 222 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 0a8acbb..95d7040 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1328,7 +1328,14 @@ the following case-insensitive options:
     <td>
      The database column data types to use instead of the defaults, when 
creating the table. Data type information should be specified in the same 
format as CREATE TABLE columns syntax (e.g: <code>"name CHAR(64), comments 
VARCHAR(1024)")</code>. The specified types should be valid spark sql data 
types. This option applies only to writing.
     </td>
-  </tr>  
+  </tr>
+
+  <tr>
+    <td><code>customSchema</code></td>
+    <td>
+     The custom schema to use for reading data from JDBC connectors. For example, <code>"id DECIMAL(38, 0), name STRING"</code>. The column names should be identical to the corresponding column names of the JDBC table. Users can specify the corresponding Spark SQL data types instead of using the defaults. This option applies only to reading.
+    </td>
+  </tr>
 </table>
 
 <div class="codetabs">

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/examples/src/main/python/sql/datasource.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/sql/datasource.py 
b/examples/src/main/python/sql/datasource.py
index 8777cca..f86012e 100644
--- a/examples/src/main/python/sql/datasource.py
+++ b/examples/src/main/python/sql/datasource.py
@@ -177,6 +177,16 @@ def jdbc_dataset_example(spark):
         .jdbc("jdbc:postgresql:dbserver", "schema.tablename",
               properties={"user": "username", "password": "password"})
 
+    # Specifying dataframe column data types on read
+    jdbcDF3 = spark.read \
+        .format("jdbc") \
+        .option("url", "jdbc:postgresql:dbserver") \
+        .option("dbtable", "schema.tablename") \
+        .option("user", "username") \
+        .option("password", "password") \
+        .option("customSchema", "id DECIMAL(38, 0), name STRING") \
+        .load()
+
     # Saving data to a JDBC source
     jdbcDF.write \
         .format("jdbc") \

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
----------------------------------------------------------------------
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
 
b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index 6ff03bd..86b3dc4 100644
--- 
a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ 
b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -185,6 +185,10 @@ object SQLDataSourceExample {
     connectionProperties.put("password", "password")
     val jdbcDF2 = spark.read
       .jdbc("jdbc:postgresql:dbserver", "schema.tablename", 
connectionProperties)
+    // Specifying the custom data types of the read schema
+    connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
+    val jdbcDF3 = spark.read
+      .jdbc("jdbc:postgresql:dbserver", "schema.tablename", 
connectionProperties)
 
     // Saving data to a JDBC source
     jdbcDF.write

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
----------------------------------------------------------------------
diff --git 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
 
b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index 1b2c1b9..7680ae3 100644
--- 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ 
b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 import java.math.BigDecimal
 
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.{WholeStageCodegenExec, 
RowDataSourceScanExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -72,10 +72,17 @@ class OracleIntegrationSuite extends 
DockerJDBCIntegrationSuite with SharedSQLCo
       """.stripMargin.replaceAll("\n", " ")).executeUpdate()
     conn.commit()
 
-    conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t 
TIMESTAMP WITH TIME ZONE)")
-        .executeUpdate()
-    conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, 
to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))")
-        .executeUpdate()
+    conn.prepareStatement(
+      "CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME 
ZONE)").executeUpdate()
+    conn.prepareStatement(
+      "INSERT INTO ts_with_timezone VALUES " +
+        "(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS 
TZR'))").executeUpdate()
+    conn.commit()
+
+    conn.prepareStatement(
+      "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 
NUMBER(1))").executeUpdate()
+    conn.prepareStatement(
+      "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 
1, 0)").executeUpdate()
     conn.commit()
 
     sql(
@@ -104,7 +111,7 @@ class OracleIntegrationSuite extends 
DockerJDBCIntegrationSuite with SharedSQLCo
   }
 
 
-  test("SPARK-16625 : Importing Oracle numeric types") { 
+  test("SPARK-16625 : Importing Oracle numeric types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
     val rows = df.collect()
     assert(rows.size == 1)
@@ -272,4 +279,32 @@ class OracleIntegrationSuite extends 
DockerJDBCIntegrationSuite with SharedSQLCo
     assert(row.getDate(0).equals(dateVal))
     assert(row.getTimestamp(1).equals(timestampVal))
   }
+
+  test("SPARK-20427/SPARK-20921: read table use custom schema by jdbc api") {
+    // with the default mapping, reading fails: the IllegalArgumentException is wrapped in a SparkException
+    val e = intercept[org.apache.spark.SparkException] {
+      spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new 
Properties()).collect()
+    }
+    assert(e.getMessage.contains(
+      "requirement failed: Decimal precision 39 exceeds max precision 38"))
+
+    // with a custom schema, the data can be read
+    val props = new Properties()
+    props.put("customSchema",
+      s"ID DECIMAL(${DecimalType.MAX_PRECISION}, 0), N1 INT, N2 BOOLEAN")
+    val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
+
+    val rows = dfRead.collect()
+    // verify the data type
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types(0).equals("class java.math.BigDecimal"))
+    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Boolean"))
+
+    // verify the value
+    val values = rows(0)
+    assert(values.getDecimal(0).equals(new 
java.math.BigDecimal("12312321321321312312312312123")))
+    assert(values.getInt(1).equals(1))
+    assert(values.getBoolean(2).equals(false))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 05b0005..b4e5d16 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
 import java.util.{Locale, Properties}
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.types.StructType
 
 /**
  * Options for the JDBC data source.
@@ -123,6 +124,8 @@ class JDBCOptions(
   // TODO: to reuse the existing partition parameters for those partition 
specific options
   val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
   val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
+  val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES)
+
   val batchSize = {
     val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
     require(size >= 1,
@@ -161,6 +164,7 @@ object JDBCOptions {
   val JDBC_TRUNCATE = newOption("truncate")
   val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
   val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
+  val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
   val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
   val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
   val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 3274be9..0532621 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -80,7 +80,7 @@ object JDBCRDD extends Logging {
    * @return A Catalyst schema corresponding to columns in the given order.
    */
   private def pruneSchema(schema: StructType, columns: Array[String]): 
StructType = {
-    val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> 
x): _*)
+    val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
     new StructType(columns.map(name => fieldMap(name)))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 17405f5..b23e5a7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -111,7 +111,14 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
+  override val schema: StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
+      case None => tableSchema
+    }
+  }
 
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index bbe9024f..75327f0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -301,7 +302,6 @@ object JdbcUtils extends Logging {
         rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
       }
       val metadata = new MetadataBuilder()
-        .putString("name", columnName)
         .putLong("scale", fieldScale)
       val columnType =
         dialect.getCatalystType(dataType, typeName, fieldSize, 
metadata).getOrElse(
@@ -768,6 +768,34 @@ object JdbcUtils extends Logging {
   }
 
   /**
+   * Parses the user-specified customSchema option value into a DataFrame schema,
+   * and returns it if its column names all match those of the resolved table schema.
+   */
+  def getCustomSchema(
+      tableSchema: StructType,
+      customSchema: String,
+      nameEquality: Resolver): StructType = {
+    val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+    SchemaUtils.checkColumnNameDuplication(
+      userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+    val colNames = tableSchema.fieldNames.mkString(",")
+    val errorMsg = s"Please provide all the columns, all columns are: 
$colNames"
+    if (userSchema.size != tableSchema.size) {
+      throw new AnalysisException(errorMsg)
+    }
+
+    // Resolution is by name, so only the column names need to be checked.
+    userSchema.fieldNames.foreach { col =>
+      tableSchema.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new AnalysisException(errorMsg)
+      }
+    }
+    userSchema
+  }
+
+  /**
    * Saves the RDD to the database in a single transaction.
    */
   def saveTable(

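For reference, a minimal sketch of how the new `JdbcUtils.getCustomSchema` helper resolves a user-supplied schema string against the JDBC-resolved table schema (mirroring the new `JdbcUtilsSuite` below; the sample table schema here is illustrative):
```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

// Schema as resolved from the database, e.g. Oracle NUMBER columns.
val tableSchema = StructType(Seq(
  StructField("ID", DecimalType(38, 10)),
  StructField("N1", DecimalType(1, 0))))

// All column names must match (per the resolver); only the Spark SQL types are overridden.
val resolved = JdbcUtils.getCustomSchema(
  tableSchema, "id DECIMAL(38, 0), n1 BOOLEAN", caseInsensitiveResolution)
// resolved === StructType(Seq(StructField("id", DecimalType(38, 0)), StructField("n1", BooleanType)))
```
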
http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
new file mode 100644
index 0000000..1255f26
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.types._
+
+class JdbcUtilsSuite extends SparkFunSuite {
+
+  val tableSchema = StructType(Seq(
+    StructField("C1", StringType, false), StructField("C2", IntegerType, 
false)))
+  val caseSensitive = 
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+  val caseInsensitive = 
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+
+  test("Parse user specified column types") {
+    assert(
+      JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", 
caseInsensitive) ===
+      StructType(Seq(StructField("C1", DateType, true), StructField("C2", 
StringType, true))))
+    assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", 
caseSensitive) ===
+      StructType(Seq(StructField("C1", DateType, true), StructField("C2", 
StringType, true))))
+    assert(
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", 
caseInsensitive) ===
+        StructType(Seq(StructField("c1", DateType, true), StructField("C2", 
StringType, true))))
+    assert(JdbcUtils.getCustomSchema(
+      tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) ===
+      StructType(Seq(StructField("c1", DecimalType(38, 0), true),
+        StructField("C2", StringType, true))))
+
+    // Throw AnalysisException
+    val duplicate = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", 
caseInsensitive) ===
+        StructType(Seq(StructField("c1", DateType, true), StructField("c1", 
StringType, true)))
+    }
+    assert(duplicate.getMessage.contains(
+      "Found duplicate column(s) in the customSchema option value"))
+
+    val allColumns = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) ===
+        StructType(Seq(StructField("C1", DateType, true)))
+    }
+    assert(allColumns.getMessage.contains("Please provide all the columns,"))
+
+    val caseSensitiveColumnNotFound = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", 
caseSensitive) ===
+        StructType(Seq(StructField("c1", DateType, true), StructField("C2", 
StringType, true)))
+    }
+    assert(caseSensitiveColumnNotFound.getMessage.contains(
+      "Please provide all the columns, all columns are: C1,C2;"))
+
+    val caseInsensitiveColumnNotFound = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", 
caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", 
StringType, true)))
+    }
+    assert(caseInsensitiveColumnNotFound.getMessage.contains(
+      "Please provide all the columns, all columns are: C1,C2;"))
+
+    // Throw ParseException
+    val dataTypeNotSupported = intercept[ParseException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", 
caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", 
StringType, true)))
+    }
+    assert(dataTypeNotSupported.getMessage.contains("DataType datee is not 
supported"))
+
+    val mismatchedInput = intercept[ParseException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", 
caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", 
StringType, true)))
+    }
+    assert(mismatchedInput.getMessage.contains("mismatched input '.' 
expecting"))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/17edfec5/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index f951b46..4017926 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -968,6 +968,36 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
 
+  test("jdbc API support custom schema") {
+    val parts = Array[String]("THEID < 2", "THEID >= 2")
+    val props = new Properties()
+    props.put("customSchema", "NAME STRING, THEID BIGINT")
+    val schema = StructType(Seq(
+      StructField("NAME", StringType, true), StructField("THEID", LongType, 
true)))
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props)
+    assert(df.schema.size === 2)
+    assert(df.schema === schema)
+    assert(df.count() === 3)
+  }
+
+  test("jdbc API custom schema DDL-like strings.") {
+    withTempView("people_view") {
+      sql(
+        s"""
+           |CREATE TEMPORARY VIEW people_view
+           |USING org.apache.spark.sql.jdbc
+           |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', 
PassWord 'testPass',
+           |customSchema 'NAME STRING, THEID INT')
+        """.stripMargin.replaceAll("\n", " "))
+      val schema = StructType(
+        Seq(StructField("NAME", StringType, true), StructField("THEID", 
IntegerType, true)))
+      val df = sql("select * from people_view")
+      assert(df.schema.size === 2)
+      assert(df.schema === schema)
+      assert(df.count() === 3)
+    }
+  }
+
   test("SPARK-15648: teradataDialect StringType data mapping") {
     val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
     assert(teradataDialect.getJDBCType(StringType).

