This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 702180614900 [SPARK-47754][SQL] Postgres: Support reading multidimensional arrays
702180614900 is described below

commit 702180614900bdaf245a194da0043b8b51de3b4b
Author: Kent Yao <y...@apache.org>
AuthorDate: Mon Apr 8 22:45:21 2024 -0700

    [SPARK-47754][SQL] Postgres: Support reading multidimensional arrays
    
    ### What changes were proposed in this pull request?
    
    Because the ResultSetMetaData cannot distinguish a single-dimensional array
    from a multidimensional one, we always read multidimensional arrays as
    single-dimensional ones. For example, `text[][]` is mapped to
    `ArrayType(StringType)` and `int[][][]` to `ArrayType(IntegerType)`. This
    results in errors when converting a ResultSet that contains multidimensional
    arrays to InternalRows.
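
    A hedged sketch of the round trip this change enables, mirroring the new
    integration test below (`spark`, `jdbcUrl`, and the table name are
    assumptions for illustration):

    ```scala
    // Write a two-dimensional array to Postgres, then read it back.
    val props = new java.util.Properties
    spark.sql("SELECT array(array(1, 2), array(3, 4)) AS col0")
      .write.jdbc(jdbcUrl, "double_dim_array", props)

    // Before this change the read failed while converting the ResultSet; now the
    // column is inferred as ArrayType(ArrayType(IntegerType)) and collects as
    // Row(Seq(Seq(1, 2), Seq(3, 4))).
    val rows = spark.read.jdbc(jdbcUrl, "double_dim_array", props).collect()
    ```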
    
    This PR adds support for reading multidimensional arrays from PostgreSQL
    data sources. The simplest way to achieve this is to add a new developer
    API that retrieves the array dimension from the Postgres system catalogs:
    
    
    https://www.postgresql.org/docs/16/catalog-pg-attribute.html#CATALOG-PG-ATTRIBUTE
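
    As a rough sketch only, the declared dimensionality can be read from
    `pg_attribute.attndims` over plain JDBC; this is a simplified version of the
    query added to `PostgresDialect` in this patch, and the connection, table,
    and column names are assumptions:

    ```scala
    import java.sql.Connection
    import scala.util.Using

    // Returns the declared number of array dimensions for a column, or None if unknown.
    def arrayDimension(conn: Connection, table: String, column: String): Option[Int] = {
      val query =
        s"""
           |SELECT pg_attribute.attndims
           |FROM pg_attribute
           |  JOIN pg_class ON pg_attribute.attrelid = pg_class.oid
           |WHERE pg_class.relname = '$table' AND pg_attribute.attname = '$column'
           |""".stripMargin
      Using.resource(conn.createStatement()) { stmt =>
        Using.resource(stmt.executeQuery(query)) { rs =>
          if (rs.next()) Some(rs.getInt(1)) else None
        }
      }
    }
    ```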
    
    It is possible to use functions such as `array_dims` to retrieve the
    dimension of an array column, but that query is hard to inject into the
    existing schema lookup without breaking changes, and it determines the
    dimension from the actual data rather than from the declared type.
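
    For comparison, a hypothetical data-driven check with `array_dims` might
    look like the sketch below (connection and table name are assumptions);
    because it inspects a stored value, it yields nothing for an empty table:

    ```scala
    import java.sql.Connection
    import scala.util.Using

    // Hypothetical helper: derive the dimension from the data, e.g. "[1:2][1:2]" -> 2.
    def dimensionFromData(conn: Connection, table: String, column: String): Option[Int] =
      Using.resource(conn.createStatement()) { stmt =>
        Using.resource(stmt.executeQuery(s"SELECT array_dims($column) FROM $table LIMIT 1")) { rs =>
          if (rs.next()) Option(rs.getString(1)).map(_.count(_ == '[')) else None
        }
      }
    ```
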
    ### Why are the changes needed?
    
    We already support writing multidimensional arrays to Postgres, so we
    should improve the Postgres reading capabilities as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #45917 from yaooqinn/SPARK-47754.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/jdbc/PostgresIntegrationSuite.scala  | 21 ++++++-------
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  2 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala | 18 ++++++++---
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 16 ++++++++++
 .../apache/spark/sql/jdbc/PostgresDialect.scala    | 36 +++++++++++++++++++++-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala      |  3 +-
 6 files changed, 76 insertions(+), 20 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 69573e9bddb1..1cd8a77e8442 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -23,7 +23,6 @@ import java.text.SimpleDateFormat
 import java.time.LocalDateTime
 import java.util.Properties
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.types._
@@ -514,19 +513,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
 
     sql("select array(array(1, 2), array(3, 4)) as col0").write
       .jdbc(jdbcUrl, "double_dim_array", new Properties)
+
+    checkAnswer(
+      spark.read.jdbc(jdbcUrl, "double_dim_array", new Properties),
+      Row(Seq(Seq(1, 2), Seq(3, 4))))
+
     sql("select array(array(array(1, 2), array(3, 4)), array(array(5, 6), 
array(7, 8))) as col0")
       .write.jdbc(jdbcUrl, "triple_dim_array", new Properties)
-    // Reading multi-dimensional array is not supported yet.
-    checkError(
-      exception = intercept[SparkException] {
-        spark.read.jdbc(jdbcUrl, "double_dim_array", new Properties).collect()
-      },
-      errorClass = null)
-    checkError(
-      exception = intercept[SparkException] {
-        spark.read.jdbc(jdbcUrl, "triple_dim_array", new Properties).collect()
-      },
-      errorClass = null)
+
+    checkAnswer(
+      spark.read.jdbc(jdbcUrl, "triple_dim_array", new Properties),
+      Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6), Seq(7, 8)))))
   }
 
   test("SPARK-47701: Reading complex type") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 7eff4bd376bc..8c430e231e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -67,7 +67,7 @@ object JDBCRDD extends Logging {
       Using.resource(conn.prepareStatement(query)) { statement =>
         statement.setQueryTimeout(options.queryTimeout)
         Using.resource(statement.executeQuery()) { rs =>
-          JdbcUtils.getSchema(rs, dialect, alwaysNullable = true,
+          JdbcUtils.getSchema(conn, rs, dialect, alwaysNullable = true,
             isTimestampNTZ = options.preferTimestampNTZ)
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 2d75f9a75a2c..08313f26a877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -246,7 +246,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
         conn.prepareStatement(options.prepareQuery + dialect.getSchemaQuery(options.tableOrQuery))
       try {
         statement.setQueryTimeout(options.queryTimeout)
-        Some(getSchema(statement.executeQuery(), dialect,
+        Some(getSchema(conn, statement.executeQuery(), dialect,
           isTimestampNTZ = options.preferTimestampNTZ))
       } catch {
         case _: SQLException => None
@@ -267,6 +267,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * @throws SQLException if the schema contains an unsupported type.
    */
   def getSchema(
+      conn: Connection,
       resultSet: ResultSet,
       dialect: JdbcDialect,
       alwaysNullable: Boolean = false,
@@ -306,6 +307,11 @@ object JdbcUtils extends Logging with SQLConfHelper {
           metadata.putBoolean("logical_time_type", true)
         case java.sql.Types.ROWID =>
           metadata.putBoolean("rowid", true)
+        case java.sql.Types.ARRAY =>
+          val tableName = rsmd.getTableName(i + 1)
+          dialect.getArrayDimension(conn, tableName, columnName).foreach { dimension =>
+            metadata.putLong("arrayDimension", dimension)
+          }
         case _ =>
       }
       metadata.putBoolean("isSigned", isSigned)
@@ -542,7 +548,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
         }
 
     case ArrayType(et, _) =>
-      val elementConversion = et match {
+      def elementConversion(et: DataType): AnyRef => Any = et match {
         case TimestampType => arrayConverter[Timestamp] {
          (t: Timestamp) => fromJavaTimestamp(dialect.convertJavaTimestampToTimestamp(t))
         }
@@ -565,8 +571,10 @@ object JdbcUtils extends Logging with SQLConfHelper {
         case LongType if metadata.contains("binarylong") =>
          throw QueryExecutionErrors.unsupportedArrayElementTypeBasedOnBinaryError(dt)
 
-        case ArrayType(_, _) =>
-          throw QueryExecutionErrors.nestedArraysUnsupportedError()
+        case ArrayType(et0, _) =>
+          arrayConverter[Array[Any]] {
+            arr => new GenericArrayData(elementConversion(et0)(arr))
+          }
 
         case _ => (array: Object) => array.asInstanceOf[Array[Any]]
       }
@@ -574,7 +582,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         val array = nullSafeConvert[java.sql.Array](
           input = rs.getArray(pos + 1),
-          array => new GenericArrayData(elementConversion(array.getArray)))
+          array => new GenericArrayData(elementConversion(et)(array.getArray)))
         row.update(pos, array)
 
     case NullType =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 4367ed2a79d4..d800cc6a8617 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -798,6 +798,22 @@ abstract class JdbcDialect extends Serializable with Logging {
   protected final def getTimestampType(md: Metadata): DataType = {
     JdbcUtils.getTimestampType(md.getBoolean("isTimestampNTZ"))
   }
+
+  /**
+   * Returns the array dimension of the column. The array dimension will be carried in the
+   * metadata of the column and used by `getCatalystType` to determine the dimension of the
+   * ArrayType.
+   *
+   * @param conn The connection currently being used.
+   * @param tableName The name of the table which the column belongs to.
+   * @param columnName The name of the column.
+   * @return An Option[Int] containing the number of array dimensions.
+   *         If Some(n), the column is an array with n dimensions.
+   *         If the method is unimplemented or an error is encountered, return None.
+   *         Then, `getCatalystType` will use 1 dimension as the default for arrays.
+   */
+  @Since("4.0.0")
+  def getArrayDimension(conn: Connection, tableName: String, columnName: String): Option[Int] = None
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4b6b79efcc03..b9c39b467e8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -22,6 +22,8 @@ import java.time.{LocalDateTime, ZoneOffset}
 import java.util
 import java.util.Locale
 
+import scala.util.Using
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
@@ -74,7 +76,16 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
       case _ if "text".equalsIgnoreCase(typeName) => Some(StringType) // 
sqlType is Types.VARCHAR
       case Types.ARRAY =>
         // postgres array type names start with underscore
-        toCatalystType(typeName.drop(1), size, md).map(ArrayType(_))
+        val elementType = toCatalystType(typeName.drop(1), size, md)
+        elementType.map { et =>
+          val metadata = md.build()
+          val dim = if (metadata.contains("arrayDimension")) {
+            metadata.getLong("arrayDimension").toInt
+          } else {
+            1
+          }
+          (0 until dim).foldLeft(et)((acc, _) => ArrayType(acc))
+        }
       case _ => None
     }
   }
@@ -331,4 +342,27 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
       case _ => d
     }
   }
+
+  override def getArrayDimension(
+      conn: Connection,
+      tableName: String,
+      columnName: String): Option[Int] = {
+    val query =
+      s"""
+         |SELECT pg_attribute.attndims
+         |FROM pg_attribute
+         |  JOIN pg_class ON pg_attribute.attrelid = pg_class.oid
+         |  JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid
+         |WHERE pg_class.relname = '$tableName' and pg_attribute.attname = '$columnName'
+         |""".stripMargin
+    try {
+      Using.resource(conn.createStatement()) { stmt =>
+        Using.resource(stmt.executeQuery(query)) { rs =>
+          if (rs.next()) Some(rs.getInt(1)) else None
+        }
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 56c704d8adb6..5e387a3f0791 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -2012,6 +2012,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     when(mockRsmd.isSigned(anyInt())).thenReturn(false)
     when(mockRsmd.isNullable(anyInt())).thenReturn(java.sql.ResultSetMetaData.columnNoNulls)
 
+    val mockConn = mock(classOf[java.sql.Connection])
     val mockRs = mock(classOf[java.sql.ResultSet])
     when(mockRs.getMetaData).thenReturn(mockRsmd)
 
@@ -2019,7 +2020,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     when(mockDialect.getCatalystType(anyInt(), anyString(), anyInt(), any[MetadataBuilder]))
       .thenReturn(None)
 
-    val schema = JdbcUtils.getSchema(mockRs, mockDialect)
+    val schema = JdbcUtils.getSchema(mockConn, mockRs, mockDialect)
     val fields = schema.fields
     assert(fields.length === 1)
     assert(fields(0).dataType === StringType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
