This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 702180614900  [SPARK-47754][SQL] Postgres: Support reading multidimensional arrays
702180614900 is described below

commit 702180614900bdaf245a194da0043b8b51de3b4b
Author: Kent Yao <y...@apache.org>
AuthorDate: Mon Apr 8 22:45:21 2024 -0700

    [SPARK-47754][SQL] Postgres: Support reading multidimensional arrays

    ### What changes were proposed in this pull request?

    The ResultSetMetadata cannot distinguish a single-dimensional array from a
    multidimensional one, so we always read multidimensional arrays as
    single-dimensional ones. For example, `text[][]` maps to `ArrayType(StringType)`
    and `int[][][]` to `ArrayType(IntegerType)`, which results in errors when
    converting a ResultSet with multidimensional arrays to InternalRows.

    This PR supports reading multidimensional arrays from PostgreSQL data sources.
    To achieve this, the simplest way is to add a new developer API that retrieves
    the array dimension from the Postgres catalog.

    https://www.postgresql.org/docs/16/catalog-pg-attribute.html#CATALOG-PG-ATTRIBUTE

    It is possible to use functions like `array_dims` to retrieve the dimension of
    an array column, but such a call is not easy to inject without causing breaking
    changes, and the dimension would then be determined by the actual data.

    ### Why are the changes needed?

    We already support writing multidimensional arrays to Postgres, so we should
    improve the reading side too.

    ### Does this PR introduce _any_ user-facing change?

    no

    ### How was this patch tested?

    new tests

    ### Was this patch authored or co-authored using generative AI tooling?

    no

    Closes #45917 from yaooqinn/SPARK-47754.

    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
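For context, a short sketch of what this change enables on the reading side (the JDBC URL
and table name below are placeholders, not part of the patch): a Postgres column created
as int[][] now resolves to a nested ArrayType instead of failing during row conversion.

    // Hypothetical usage; assumes a reachable Postgres instance with a
    // table "double_dim_array" whose column col0 was declared as int[][].
    val df = spark.read.jdbc(
      "jdbc:postgresql://localhost:5432/postgres", "double_dim_array", new java.util.Properties)
    df.printSchema()
    // root
    //  |-- col0: array (nullable = true)
    //  |    |-- element: array (containsNull = true)
    //  |    |    |-- element: integer (containsNull = true)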
---
 .../spark/sql/jdbc/PostgresIntegrationSuite.scala  | 21 ++++++-------
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  2 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala | 18 ++++++++---
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 16 ++++++++++
 .../apache/spark/sql/jdbc/PostgresDialect.scala    | 36 +++++++++++++++++++++-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala      |  3 +-
 6 files changed, 76 insertions(+), 20 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 69573e9bddb1..1cd8a77e8442 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -23,7 +23,6 @@ import java.text.SimpleDateFormat
 import java.time.LocalDateTime
 import java.util.Properties
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.types._
@@ -514,19 +513,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     sql("select array(array(1, 2), array(3, 4)) as col0").write
       .jdbc(jdbcUrl, "double_dim_array", new Properties)
+
+    checkAnswer(
+      spark.read.jdbc(jdbcUrl, "double_dim_array", new Properties),
+      Row(Seq(Seq(1, 2), Seq(3, 4))))
+
     sql("select array(array(array(1, 2), array(3, 4)), array(array(5, 6), array(7, 8))) as col0")
       .write.jdbc(jdbcUrl, "triple_dim_array", new Properties)
-    // Reading multi-dimensional array is not supported yet.
-    checkError(
-      exception = intercept[SparkException] {
-        spark.read.jdbc(jdbcUrl, "double_dim_array", new Properties).collect()
-      },
-      errorClass = null)
-    checkError(
-      exception = intercept[SparkException] {
-        spark.read.jdbc(jdbcUrl, "triple_dim_array", new Properties).collect()
-      },
-      errorClass = null)
+
+    checkAnswer(
+      spark.read.jdbc(jdbcUrl, "triple_dim_array", new Properties),
+      Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6), Seq(7, 8)))))
   }
 
   test("SPARK-47701: Reading complex type") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 7eff4bd376bc..8c430e231e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -67,7 +67,7 @@ object JDBCRDD extends Logging {
     Using.resource(conn.prepareStatement(query)) { statement =>
       statement.setQueryTimeout(options.queryTimeout)
       Using.resource(statement.executeQuery()) { rs =>
-        JdbcUtils.getSchema(rs, dialect, alwaysNullable = true,
+        JdbcUtils.getSchema(conn, rs, dialect, alwaysNullable = true,
           isTimestampNTZ = options.preferTimestampNTZ)
       }
     }
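A note on the JdbcUtils change below: `elementConversion` becomes a recursive function so
that each nesting level of a JDBC array payload is converted independently. A minimal
standalone sketch of the same idea (the name `convert` and the simplified element handling
are illustrative, not the actual Spark internals):

    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types._

    // Sketch: recursively convert a nested array payload. Each ArrayType level
    // maps its elements with the converter for the inner type and wraps the
    // result in GenericArrayData; leaf values pass through unchanged.
    def convert(dt: DataType): AnyRef => Any = dt match {
      case ArrayType(inner, _) =>
        (arr: AnyRef) =>
          new GenericArrayData(arr.asInstanceOf[Array[AnyRef]].map(convert(inner)))
      case _ => (v: AnyRef) => v
    }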
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 2d75f9a75a2c..08313f26a877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -246,7 +246,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
       conn.prepareStatement(options.prepareQuery + dialect.getSchemaQuery(options.tableOrQuery))
     try {
       statement.setQueryTimeout(options.queryTimeout)
-      Some(getSchema(statement.executeQuery(), dialect,
+      Some(getSchema(conn, statement.executeQuery(), dialect,
         isTimestampNTZ = options.preferTimestampNTZ))
     } catch {
       case _: SQLException => None
@@ -267,6 +267,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * @throws SQLException if the schema contains an unsupported type.
    */
   def getSchema(
+      conn: Connection,
       resultSet: ResultSet,
       dialect: JdbcDialect,
       alwaysNullable: Boolean = false,
@@ -306,6 +307,11 @@ object JdbcUtils extends Logging with SQLConfHelper {
           metadata.putBoolean("logical_time_type", true)
         case java.sql.Types.ROWID =>
           metadata.putBoolean("rowid", true)
+        case java.sql.Types.ARRAY =>
+          val tableName = rsmd.getTableName(i + 1)
+          dialect.getArrayDimension(conn, tableName, columnName).foreach { dimension =>
+            metadata.putLong("arrayDimension", dimension)
+          }
         case _ =>
       }
       metadata.putBoolean("isSigned", isSigned)
@@ -542,7 +548,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
       }
 
     case ArrayType(et, _) =>
-      val elementConversion = et match {
+      def elementConversion(et: DataType): AnyRef => Any = et match {
         case TimestampType =>
           arrayConverter[Timestamp] {
             (t: Timestamp) => fromJavaTimestamp(dialect.convertJavaTimestampToTimestamp(t))
@@ -565,8 +571,10 @@ object JdbcUtils extends Logging with SQLConfHelper {
         case LongType if metadata.contains("binarylong") =>
           throw QueryExecutionErrors.unsupportedArrayElementTypeBasedOnBinaryError(dt)
 
-        case ArrayType(_, _) =>
-          throw QueryExecutionErrors.nestedArraysUnsupportedError()
+        case ArrayType(et0, _) =>
+          arrayConverter[Array[Any]] {
+            arr => new GenericArrayData(elementConversion(et0)(arr))
+          }
 
         case _ => (array: Object) => array.asInstanceOf[Array[Any]]
       }
@@ -574,7 +582,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         val array = nullSafeConvert[java.sql.Array](
           input = rs.getArray(pos + 1),
-          array => new GenericArrayData(elementConversion(array.getArray)))
+          array => new GenericArrayData(elementConversion(et)(array.getArray)))
         row.update(pos, array)
 
     case NullType =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 4367ed2a79d4..d800cc6a8617 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -798,6 +798,22 @@ abstract class JdbcDialect extends Serializable with Logging {
   protected final def getTimestampType(md: Metadata): DataType = {
     JdbcUtils.getTimestampType(md.getBoolean("isTimestampNTZ"))
   }
+
+  /**
+   * Return the array dimension of the column. The array dimension will be carried in the
+   * metadata of the column and used by `getCatalystType` to determine the dimension of the
+   * ArrayType.
+   *
+   * @param conn The connection currently being used.
+   * @param tableName The name of the table to which the column belongs.
+   * @param columnName The name of the column.
+   * @return An Option[Int] which contains the number of array dimensions.
+   *         If Some(n), the column is an array with n dimensions.
+   *         If the method is unimplemented, or an error is encountered, return None.
+   *         Then, `getCatalystType` will use 1 dimension as the default for arrays.
+   */
+  @Since("4.0.0")
+  def getArrayDimension(conn: Connection, tableName: String, columnName: String): Option[Int] = None
 }
 
 /**
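The PostgresDialect change below consumes that metadata in `getCatalystType`: expanding the
element type into an n-dimensional ArrayType is just a fold over the dimension count. A
standalone sketch (the name `wrapArray` is illustrative):

    import org.apache.spark.sql.types._

    // Sketch: wrap an element type in `dim` levels of ArrayType, so
    // wrapArray(IntegerType, 2) yields ArrayType(ArrayType(IntegerType)).
    def wrapArray(et: DataType, dim: Int): DataType =
      (0 until dim).foldLeft(et)((acc, _) => ArrayType(acc))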
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4b6b79efcc03..b9c39b467e8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -22,6 +22,8 @@ import java.time.{LocalDateTime, ZoneOffset}
 import java.util
 import java.util.Locale
 
+import scala.util.Using
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
@@ -74,7 +76,16 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
       case _ if "text".equalsIgnoreCase(typeName) => Some(StringType) // sqlType is Types.VARCHAR
       case Types.ARRAY =>
         // postgres array type names start with underscore
-        toCatalystType(typeName.drop(1), size, md).map(ArrayType(_))
+        val elementType = toCatalystType(typeName.drop(1), size, md)
+        elementType.map { et =>
+          val metadata = md.build()
+          val dim = if (metadata.contains("arrayDimension")) {
+            metadata.getLong("arrayDimension").toInt
+          } else {
+            1
+          }
+          (0 until dim).foldLeft(et)((acc, _) => ArrayType(acc))
+        }
       case _ => None
     }
   }
@@ -331,4 +342,27 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
       case _ => d
     }
   }
+
+  override def getArrayDimension(
+      conn: Connection,
+      tableName: String,
+      columnName: String): Option[Int] = {
+    val query =
+      s"""
+         |SELECT pg_attribute.attndims
+         |FROM pg_attribute
+         |  JOIN pg_class ON pg_attribute.attrelid = pg_class.oid
+         |  JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid
+         |WHERE pg_class.relname = '$tableName' and pg_attribute.attname = '$columnName'
+         |""".stripMargin
+    try {
+      Using.resource(conn.createStatement()) { stmt =>
+        Using.resource(stmt.executeQuery(query)) { rs =>
+          if (rs.next()) Some(rs.getInt(1)) else None
+        }
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 56c704d8adb6..5e387a3f0791 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -2012,6 +2012,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     when(mockRsmd.isSigned(anyInt())).thenReturn(false)
     when(mockRsmd.isNullable(anyInt())).thenReturn(java.sql.ResultSetMetaData.columnNoNulls)
 
+    val mockConn = mock(classOf[java.sql.Connection])
     val mockRs = mock(classOf[java.sql.ResultSet])
     when(mockRs.getMetaData).thenReturn(mockRsmd)
 
@@ -2019,7 +2020,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     when(mockDialect.getCatalystType(anyInt(), anyString(), anyInt(), any[MetadataBuilder]))
       .thenReturn(None)
 
-    val schema = JdbcUtils.getSchema(mockRs, mockDialect)
+    val schema = JdbcUtils.getSchema(mockConn, mockRs, mockDialect)
     val fields = schema.fields
     assert(fields.length === 1)
     assert(fields(0).dataType === StringType)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org