This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 1099fd3  [SPARK-37115][SQL] HiveClientImpl should use shim to wrap all hive client calls
1099fd3 is described below

commit 1099fd342075b53ad9ddb2787911f2dabb340a3d
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Wed Oct 27 15:32:47 2021 +0800

    [SPARK-37115][SQL] HiveClientImpl should use shim to wrap all hive client calls

    ### What changes were proposed in this pull request?
    In this PR we use `shim` to wrap all Hive client API calls.

    ### Why are the changes needed?
    Routing every Hive client call through `shim` keeps version-specific handling in one place and makes the calls easier to maintain.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing UTs

    Closes #34388 from AngersZhuuuu/SPARK-37115.

    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/hive/client/HiveClientImpl.scala     |  64 ++++----
 .../apache/spark/sql/hive/client/HiveShim.scala    | 176 ++++++++++++++++++++-
 2 files changed, 205 insertions(+), 35 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index e295e0f..25be8b5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -343,14 +343,14 @@ private[hive] class HiveClientImpl(
       database: CatalogDatabase,
       ignoreIfExists: Boolean): Unit = withHiveState {
     val hiveDb = toHiveDatabase(database, Some(userName))
-    client.createDatabase(hiveDb, ignoreIfExists)
+    shim.createDatabase(client, hiveDb, ignoreIfExists)
   }

   override def dropDatabase(
       name: String,
       ignoreIfNotExists: Boolean,
       cascade: Boolean): Unit = withHiveState {
-    client.dropDatabase(name, true, ignoreIfNotExists, cascade)
+    shim.dropDatabase(client, name, true, ignoreIfNotExists, cascade)
   }

   override def alterDatabase(database: CatalogDatabase): Unit = withHiveState {
@@ -361,7 +361,7 @@ private[hive] class HiveClientImpl(
       }
     }
     val hiveDb = toHiveDatabase(database)
-    client.alterDatabase(database.name, hiveDb)
+    shim.alterDatabase(client, database.name, hiveDb)
   }

   private def toHiveDatabase(
@@ -379,7 +379,7 @@ private[hive] class HiveClientImpl(
   }

   override def getDatabase(dbName: String): CatalogDatabase = withHiveState {
-    Option(client.getDatabase(dbName)).map { d =>
+    Option(shim.getDatabase(client, dbName)).map { d =>
       val params = Option(d.getParameters).map(_.asScala.toMap).getOrElse(Map()) ++
         Map(PROP_OWNER -> shim.getDatabaseOwnerName(d))

@@ -392,15 +392,15 @@ private[hive] class HiveClientImpl(
   }

   override def databaseExists(dbName: String): Boolean = withHiveState {
-    client.databaseExists(dbName)
+    shim.databaseExists(client, dbName)
   }

   override def listDatabases(pattern: String): Seq[String] = withHiveState {
-    client.getDatabasesByPattern(pattern).asScala.toSeq
+    shim.getDatabasesByPattern(client, pattern)
   }

   private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = {
-    Option(client.getTable(dbName, tableName, false /* do not throw exception */))
+    Option(shim.getTable(client, dbName, tableName, false /* do not throw exception */))
   }

   private def getRawTablesByName(dbName: String, tableNames: Seq[String]): Seq[HiveTable] = {
@@ -551,7 +551,7 @@ private[hive] class HiveClientImpl(

   override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState {
     verifyColumnDataType(table.dataSchema)
-    client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists)
+    shim.createTable(client, toHiveTable(table, Some(userName)), ignoreIfExists)
   }

   override def dropTable(
@@ -583,7 +583,7 @@ private[hive] class HiveClientImpl(
       tableName: String,
       newDataSchema: StructType,
       schemaProps: Map[String, String]): Unit = withHiveState {
-    val oldTable = client.getTable(dbName, tableName)
+    val oldTable = shim.getTable(client, dbName, tableName)
     verifyColumnDataType(newDataSchema)
     val hiveCols = newDataSchema.map(toHiveColumn)
     oldTable.setFields(hiveCols.asJava)
@@ -630,7 +630,7 @@ private[hive] class HiveClientImpl(
       purge: Boolean,
       retainData: Boolean): Unit = withHiveState {
     // TODO: figure out how to drop multiple partitions in one call
-    val hiveTable = client.getTable(db, table, true /* throw exception */)
+    val hiveTable = shim.getTable(client, db, table, true /* throw exception */)

     // do the check at first and collect all the matching partitions
     val matchingParts = specs.flatMap { s =>
@@ -638,7 +638,7 @@ private[hive] class HiveClientImpl(
       // The provided spec here can be a partial spec, i.e. it will match all partitions
       // whose specs are supersets of this partial spec. E.g. If a table has partitions
       // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both.
-      val parts = client.getPartitions(hiveTable, s.asJava).asScala
+      val parts = shim.getPartitions(client, hiveTable, s.asJava)
       if (parts.isEmpty && !ignoreIfNotExists) {
         throw new NoSuchPartitionsException(db, table, Seq(s))
       }
@@ -677,13 +677,13 @@ private[hive] class HiveClientImpl(
     val catalogTable = getTable(db, table)
     val hiveTable = toHiveTable(catalogTable, Some(userName))
     specs.zip(newSpecs).foreach { case (oldSpec, newSpec) =>
-      if (client.getPartition(hiveTable, newSpec.asJava, false) != null) {
+      if (shim.getPartition(client, hiveTable, newSpec.asJava, false) != null) {
        throw new PartitionAlreadyExistsException(db, table, newSpec)
       }
       val hivePart = getPartitionOption(catalogTable, oldSpec)
         .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) }
         .getOrElse { throw new NoSuchPartitionException(db, table, oldSpec) }
-      client.renamePartition(hiveTable, oldSpec.asJava, hivePart)
+      shim.renamePartition(client, hiveTable, oldSpec.asJava, hivePart)
     }
   }

@@ -718,19 +718,19 @@ private[hive] class HiveClientImpl(
       partialSpec match {
         case None =>
           // -1 for result limit means "no limit/return all"
-          client.getPartitionNames(table.database, table.identifier.table, -1)
+          shim.getPartitionNames(client, table.database, table.identifier.table, -1)
         case Some(s) =>
           assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid")
-          client.getPartitionNames(table.database, table.identifier.table, s.asJava, -1)
+          shim.getPartitionNames(client, table.database, table.identifier.table, s.asJava, -1)
       }
-    hivePartitionNames.asScala.sorted.toSeq
+    hivePartitionNames.sorted.toSeq
   }

   override def getPartitionOption(
       table: CatalogTable,
       spec: TablePartitionSpec): Option[CatalogTablePartition] = withHiveState {
     val hiveTable = toHiveTable(table, Some(userName))
-    val hivePartition = client.getPartition(hiveTable, spec.asJava, false)
+    val hivePartition = shim.getPartition(client, hiveTable, spec.asJava, false)
     Option(hivePartition).map(fromHivePartition)
   }

@@ -753,7 +753,7 @@ private[hive] class HiveClientImpl(
         assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid")
         s
       }
-    val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition)
+    val parts = shim.getPartitions(client, hiveTable, partSpec.asJava).map(fromHivePartition)
     HiveCatalogMetrics.incrementFetchedPartitions(parts.length)
     parts.toSeq
   }
@@ -769,11 +769,11 @@ private[hive] class HiveClientImpl(
   }

   override def listTables(dbName: String): Seq[String] = withHiveState {
-    client.getAllTables(dbName).asScala.toSeq
+    shim.getAllTables(client, dbName)
   }

   override def listTables(dbName: String, pattern: String): Seq[String] = withHiveState {
-    client.getTablesByPattern(dbName, pattern).asScala.toSeq
+    shim.getTablesByPattern(client, dbName, pattern)
   }

   override def listTablesByType(
@@ -787,8 +787,8 @@ private[hive] class HiveClientImpl(
     } catch {
       case _: UnsupportedOperationException =>
         // Fallback to filter logic if getTablesByType not supported.
-        val tableNames = client.getTablesByPattern(dbName, pattern).asScala
-        getRawTablesByName(dbName, tableNames.toSeq)
+        val tableNames = shim.getTablesByPattern(client, dbName, pattern)
+        getRawTablesByName(dbName, tableNames)
           .filter(_.getTableType == hiveTableType)
           .map(_.getTableName)
     }
@@ -887,7 +887,7 @@ private[hive] class HiveClientImpl(
       replace: Boolean,
       inheritTableSpecs: Boolean,
       isSrcLocal: Boolean): Unit = withHiveState {
-    val hiveTable = client.getTable(dbName, tableName, true /* throw exception */)
+    val hiveTable = shim.getTable(client, dbName, tableName, true /* throw exception */)
     shim.loadPartition(
       client,
       new Path(loadPath), // TODO: Use URI
@@ -919,7 +919,7 @@ private[hive] class HiveClientImpl(
       partSpec: java.util.LinkedHashMap[String, String],
       replace: Boolean,
       numDP: Int): Unit = withHiveState {
-    val hiveTable = client.getTable(dbName, tableName, true /* throw exception */)
+    val hiveTable = shim.getTable(client, dbName, tableName, true /* throw exception */)
     shim.loadDynamicPartitions(
       client,
       new Path(loadPath),
@@ -965,36 +965,36 @@ private[hive] class HiveClientImpl(
   }

   def reset(): Unit = withHiveState {
-    val allTables = client.getAllTables("default")
-    val (mvs, others) = allTables.asScala.map(t => client.getTable("default", t))
+    val allTables = shim.getAllTables(client, "default")
+    val (mvs, others) = allTables.map(t => shim.getTable(client, "default", t))
       .partition(_.getTableType.toString.equals("MATERIALIZED_VIEW"))

     // Remove materialized view first, otherwise caused a violation of foreign key constraint.
     mvs.foreach { table =>
       val t = table.getTableName
       logDebug(s"Deleting materialized view $t")
-      client.dropTable("default", t)
+      shim.dropTable(client, "default", t)
     }

     others.foreach { table =>
       val t = table.getTableName
       logDebug(s"Deleting table $t")
       try {
-        client.getIndexes("default", t, 255).asScala.foreach { index =>
+        shim.getIndexes(client, "default", t, 255).foreach { index =>
           shim.dropIndex(client, "default", t, index.getIndexName)
         }
         if (!table.isIndexTable) {
-          client.dropTable("default", t)
+          shim.dropTable(client, "default", t)
         }
       } catch {
         case _: NoSuchMethodError =>
           // HIVE-18448 Hive 3.0 remove index APIs
-          client.dropTable("default", t)
+          shim.dropTable(client, "default", t)
       }
     }
-    client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db =>
+    shim.getAllDatabases(client).filterNot(_ == "default").foreach { db =>
       logDebug(s"Dropping Database: $db")
-      client.dropDatabase(db, true, false, true)
+      shim.dropDatabase(client, db, true, false, true)
     }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 3e2742b..d76190a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.metastore.IMetaStoreClient
 import org.apache.hadoop.hive.metastore.TableType
-import org.apache.hadoop.hive.metastore.api.{Database, EnvironmentContext, Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri}
+import org.apache.hadoop.hive.metastore.api.{Database, EnvironmentContext, Function => HiveFunction, FunctionType, Index, MetaException, PrincipalType, ResourceType, ResourceUri}
 import org.apache.hadoop.hive.ql.Driver
 import org.apache.hadoop.hive.ql.io.AcidUtils
 import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
@@ -75,6 +75,25 @@ private[client] sealed abstract class Shim {
    */
   def getDataLocation(table: Table): Option[String]

+  def createDatabase(hive: Hive, db: Database, ignoreIfExists: Boolean): Unit
+
+  def dropDatabase(
+      hive: Hive,
+      dbName: String,
+      deleteData: Boolean,
+      ignoreUnknownDb: Boolean,
+      cascade: Boolean): Unit
+
+  def alterDatabase(hive: Hive, dbName: String, d: Database): Unit
+
+  def getDatabase(hive: Hive, dbName: String): Database
+
+  def getAllDatabases(hive: Hive): Seq[String]
+
+  def getDatabasesByPattern(hive: Hive, pattern: String): Seq[String]
+
+  def databaseExists(hive: Hive, dbName: String): Boolean
+
   def setDataLocation(table: Table, loc: String): Unit

   def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
@@ -94,16 +113,54 @@ private[client] sealed abstract class Shim {

   def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit

+  def createTable(hive: Hive, table: Table, ifNotExists: Boolean): Unit
+
+  def getTable(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      throwException: Boolean = true): Table
+
   def getTablesByType(
       hive: Hive,
       dbName: String,
       pattern: String,
       tableType: TableType): Seq[String]

+  def getTablesByPattern(hive: Hive, dbName: String, pattern: String): Seq[String]
+
+  def getAllTables(hive: Hive, dbName: String): Seq[String]
+
+  def dropTable(hive: Hive, dbName: String, tableName: String): Unit
+
+  def getPartition(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String],
+      forceCreate: Boolean): Partition
+
+  def getPartitions(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String]): Seq[Partition]
+
+  def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[String]
+
+  def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      partSpec: JMap[String, String],
+      max: Short): Seq[String]
+
   def createPartitions(
       hive: Hive,
-      db: String,
-      table: String,
+      dbName: String,
+      tableName: String,
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit

@@ -117,6 +174,12 @@ private[client] sealed abstract class Shim {
       isSkewedStoreAsSubdir: Boolean,
       isSrcLocal: Boolean): Unit

+  def renamePartition(
+      hive: Hive,
+      table: Table,
+      oldPartSpec: JMap[String, String],
+      newPart: Partition): Unit
+
   def loadTable(
       hive: Hive,
       loadPath: Path,
@@ -176,6 +239,8 @@ private[client] sealed abstract class Shim {

   def getMSC(hive: Hive): IMetaStoreClient

+  def getIndexes(hive: Hive, dbName: String, tableName: String, max: Short): Seq[Index]
+
   protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = {
     klass.getMethod(name, args: _*)
   }
@@ -480,6 +545,111 @@ private[client] class Shim_v0_12 extends Shim with Logging {
   override def getDatabaseOwnerName(db: Database): String = ""

   override def setDatabaseOwnerName(db: Database, owner: String): Unit = {}
+
+  override def createDatabase(hive: Hive, db: Database, ignoreIfExists: Boolean): Unit = {
+    hive.createDatabase(db, ignoreIfExists)
+  }
+
+  override def dropDatabase(
+      hive: Hive,
+      dbName: String,
+      deleteData: Boolean,
+      ignoreUnknownDb: Boolean,
+      cascade: Boolean): Unit = {
+    hive.dropDatabase(dbName, deleteData, ignoreUnknownDb, cascade)
+  }
+
+  override def alterDatabase(hive: Hive, dbName: String, d: Database): Unit = {
+    hive.alterDatabase(dbName, d)
+  }
+
+  override def getDatabase(hive: Hive, dbName: String): Database = {
+    hive.getDatabase(dbName)
+  }
+
+  override def getAllDatabases(hive: Hive): Seq[String] = {
+    hive.getAllDatabases.asScala.toSeq
+  }
+
+  override def getDatabasesByPattern(hive: Hive, pattern: String): Seq[String] = {
+    hive.getDatabasesByPattern(pattern).asScala.toSeq
+  }
+
+  override def databaseExists(hive: Hive, dbName: String): Boolean = {
+    hive.databaseExists(dbName)
+  }
+
+  override def createTable(hive: Hive, table: Table, ifNotExists: Boolean): Unit = {
+    hive.createTable(table, ifNotExists)
+  }
+
+  override def getTable(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      throwException: Boolean): Table = {
+    hive.getTable(dbName, tableName, throwException)
+  }
+
+  override def getTablesByPattern(hive: Hive, dbName: String, pattern: String): Seq[String] = {
+    hive.getTablesByPattern(dbName, pattern).asScala.toSeq
+  }
+
+  override def getAllTables(hive: Hive, dbName: String): Seq[String] = {
+    hive.getAllTables(dbName).asScala.toSeq
+  }
+
+  override def dropTable(hive: Hive, dbName: String, tableName: String): Unit = {
+    hive.dropTable(dbName, tableName)
+  }
+
+  override def getPartition(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String],
+      forceCreate: Boolean): Partition = {
+    hive.getPartition(table, partSpec, forceCreate)
+  }
+
+  override def getPartitions(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String]): Seq[Partition] = {
+    hive.getPartitions(table, partSpec).asScala.toSeq
+  }
+
+  override def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[String] = {
+    hive.getPartitionNames(dbName, tableName, max).asScala.toSeq
+  }
+
+  override def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      partSpec: JMap[String, String],
+      max: Short): Seq[String] = {
+    hive.getPartitionNames(dbName, tableName, partSpec, max).asScala.toSeq
+  }
+
+  override def renamePartition(
+      hive: Hive,
+      table: Table,
+      oldPartSpec: JMap[String, String],
+      newPart: Partition): Unit = {
+    hive.renamePartition(table, oldPartSpec, newPart)
+  }
+
+  override def getIndexes(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[Index] = {
+    hive.getIndexes(dbName, tableName, max).asScala.toSeq
+  }
 }

 private[client] class Shim_v0_13 extends Shim_v0_12 {
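The refactoring above follows one mechanical pattern: every direct call on the Hive `client` object in HiveClientImpl becomes a call on the version-specific `shim`, which takes the client as its first argument, and Shim_v0_12 provides the baseline implementation that simply delegates back to the client. The following is a minimal, self-contained Scala sketch of that shape, not part of the patch; FakeHive, ShimSketch and the version classes are simplified stand-ins rather than the real Spark or Hive types.

object ShimPatternSketch {

  // Stand-in for the raw Hive client (org.apache.hadoop.hive.ql.metadata.Hive in Spark).
  trait FakeHive {
    def getTable(db: String, table: String, throwException: Boolean): String
  }

  // One shim per supported Hive version; caller code only ever talks to the shim,
  // never to the client directly.
  abstract class ShimSketch {
    def getTable(hive: FakeHive, db: String, table: String, throwException: Boolean = true): String
  }

  // Baseline behaviour: simply delegate to the underlying client.
  class ShimSketch_v0_12 extends ShimSketch {
    override def getTable(hive: FakeHive, db: String, table: String, throwException: Boolean): String =
      hive.getTable(db, table, throwException)
  }

  // A newer shim can override only the calls whose Hive API changed,
  // without touching any of the call sites.
  class ShimSketch_v2_3 extends ShimSketch_v0_12

  def main(args: Array[String]): Unit = {
    val client: FakeHive = new FakeHive {
      override def getTable(db: String, table: String, throwException: Boolean): String = s"$db.$table"
    }
    val shim: ShimSketch = new ShimSketch_v2_3
    // Call-site shape after this commit: shim.getTable(client, ...) instead of client.getTable(...).
    println(shim.getTable(client, "default", "src")) // prints "default.src"
  }
}

In HiveClientImpl the concrete shim is selected once, based on the Hive version in use, so the call sites stay identical across all supported versions.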