This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1099fd3  [SPARK-37115][SQL] HiveClientImpl should use shim to wrap all hive client calls
1099fd3 is described below

commit 1099fd342075b53ad9ddb2787911f2dabb340a3d
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Wed Oct 27 15:32:47 2021 +0800

    [SPARK-37115][SQL] HiveClientImpl should use shim to wrap all hive client calls
    
    ### What changes were proposed in this pull request?
    In this PR, `HiveClientImpl` routes every Hive client API call through the `shim` instead of calling the Hive client directly.
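    
    The sketch below shows the delegation pattern in isolation. `FakeHive`, `FakeShim`, and `FakeShimV012` are simplified stand-ins invented for this illustration (not the real `Hive`/`Shim` classes); the point is only the call shape: the shim takes the client as its first argument and forwards the call, mirroring `Shim_v0_12.getDatabase` in the diff below.
    
    ```scala
    // Illustrative stand-ins only; these names are hypothetical, not Spark/Hive classes.
    class FakeHive {
      def getDatabase(name: String): String = s"db:$name"
    }
    
    trait FakeShim {
      // Every shim method takes the client as its first argument.
      def getDatabase(hive: FakeHive, dbName: String): String
    }
    
    class FakeShimV012 extends FakeShim {
      // A version-specific shim simply forwards the call to the underlying client.
      override def getDatabase(hive: FakeHive, dbName: String): String =
        hive.getDatabase(dbName)
    }
    
    object ShimPatternSketch {
      def main(args: Array[String]): Unit = {
        val client = new FakeHive
        val shim: FakeShim = new FakeShimV012
        // Before this change: client.getDatabase("default")
        // After this change:  shim.getDatabase(client, "default")
        println(shim.getDatabase(client, "default"))
      }
    }
    ```
    
    Keeping every call behind the shim means the version-specific `Shim_v0_xx` classes remain the single place that absorbs differences between Hive releases.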
    
    ### Why are the changes needed?
    Wrapping every Hive client API call in the `shim` keeps version-specific handling in one place, instead of spreading direct client calls across `HiveClientImpl`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs
    
    Closes #34388 from AngersZhuuuu/SPARK-37115.
    
    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/hive/client/HiveClientImpl.scala     |  64 ++++----
 .../apache/spark/sql/hive/client/HiveShim.scala    | 176 ++++++++++++++++++++-
 2 files changed, 205 insertions(+), 35 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index e295e0f..25be8b5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -343,14 +343,14 @@ private[hive] class HiveClientImpl(
       database: CatalogDatabase,
       ignoreIfExists: Boolean): Unit = withHiveState {
     val hiveDb = toHiveDatabase(database, Some(userName))
-    client.createDatabase(hiveDb, ignoreIfExists)
+    shim.createDatabase(client, hiveDb, ignoreIfExists)
   }
 
   override def dropDatabase(
       name: String,
       ignoreIfNotExists: Boolean,
       cascade: Boolean): Unit = withHiveState {
-    client.dropDatabase(name, true, ignoreIfNotExists, cascade)
+    shim.dropDatabase(client, name, true, ignoreIfNotExists, cascade)
   }
 
   override def alterDatabase(database: CatalogDatabase): Unit = withHiveState {
@@ -361,7 +361,7 @@ private[hive] class HiveClientImpl(
       }
     }
     val hiveDb = toHiveDatabase(database)
-    client.alterDatabase(database.name, hiveDb)
+    shim.alterDatabase(client, database.name, hiveDb)
   }
 
   private def toHiveDatabase(
@@ -379,7 +379,7 @@ private[hive] class HiveClientImpl(
   }
 
   override def getDatabase(dbName: String): CatalogDatabase = withHiveState {
-    Option(client.getDatabase(dbName)).map { d =>
+    Option(shim.getDatabase(client, dbName)).map { d =>
       val params = Option(d.getParameters).map(_.asScala.toMap).getOrElse(Map()) ++
         Map(PROP_OWNER -> shim.getDatabaseOwnerName(d))
 
@@ -392,15 +392,15 @@ private[hive] class HiveClientImpl(
   }
 
   override def databaseExists(dbName: String): Boolean = withHiveState {
-    client.databaseExists(dbName)
+    shim.databaseExists(client, dbName)
   }
 
   override def listDatabases(pattern: String): Seq[String] = withHiveState {
-    client.getDatabasesByPattern(pattern).asScala.toSeq
+    shim.getDatabasesByPattern(client, pattern)
   }
 
   private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = {
-    Option(client.getTable(dbName, tableName, false /* do not throw exception */))
+    Option(shim.getTable(client, dbName, tableName, false /* do not throw exception */))
   }
 
   private def getRawTablesByName(dbName: String, tableNames: Seq[String]): Seq[HiveTable] = {
@@ -551,7 +551,7 @@ private[hive] class HiveClientImpl(
 
   override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState {
     verifyColumnDataType(table.dataSchema)
-    client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists)
+    shim.createTable(client, toHiveTable(table, Some(userName)), ignoreIfExists)
   }
 
   override def dropTable(
@@ -583,7 +583,7 @@ private[hive] class HiveClientImpl(
       tableName: String,
       newDataSchema: StructType,
       schemaProps: Map[String, String]): Unit = withHiveState {
-    val oldTable = client.getTable(dbName, tableName)
+    val oldTable = shim.getTable(client, dbName, tableName)
     verifyColumnDataType(newDataSchema)
     val hiveCols = newDataSchema.map(toHiveColumn)
     oldTable.setFields(hiveCols.asJava)
@@ -630,7 +630,7 @@ private[hive] class HiveClientImpl(
       purge: Boolean,
       retainData: Boolean): Unit = withHiveState {
     // TODO: figure out how to drop multiple partitions in one call
-    val hiveTable = client.getTable(db, table, true /* throw exception */)
+    val hiveTable = shim.getTable(client, db, table, true /* throw exception */)
     // do the check at first and collect all the matching partitions
     val matchingParts =
       specs.flatMap { s =>
@@ -638,7 +638,7 @@ private[hive] class HiveClientImpl(
         // The provided spec here can be a partial spec, i.e. it will match all partitions
         // whose specs are supersets of this partial spec. E.g. If a table has partitions
         // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both.
-        val parts = client.getPartitions(hiveTable, s.asJava).asScala
+        val parts = shim.getPartitions(client, hiveTable, s.asJava)
         if (parts.isEmpty && !ignoreIfNotExists) {
           throw new NoSuchPartitionsException(db, table, Seq(s))
         }
@@ -677,13 +677,13 @@ private[hive] class HiveClientImpl(
     val catalogTable = getTable(db, table)
     val hiveTable = toHiveTable(catalogTable, Some(userName))
     specs.zip(newSpecs).foreach { case (oldSpec, newSpec) =>
-      if (client.getPartition(hiveTable, newSpec.asJava, false) != null) {
+      if (shim.getPartition(client, hiveTable, newSpec.asJava, false) != null) {
         throw new PartitionAlreadyExistsException(db, table, newSpec)
       }
       val hivePart = getPartitionOption(catalogTable, oldSpec)
         .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) }
         .getOrElse { throw new NoSuchPartitionException(db, table, oldSpec) }
-      client.renamePartition(hiveTable, oldSpec.asJava, hivePart)
+      shim.renamePartition(client, hiveTable, oldSpec.asJava, hivePart)
     }
   }
 
@@ -718,19 +718,19 @@ private[hive] class HiveClientImpl(
       partialSpec match {
         case None =>
           // -1 for result limit means "no limit/return all"
-          client.getPartitionNames(table.database, table.identifier.table, -1)
+          shim.getPartitionNames(client, table.database, table.identifier.table, -1)
         case Some(s) =>
           assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid")
-          client.getPartitionNames(table.database, table.identifier.table, s.asJava, -1)
+          shim.getPartitionNames(client, table.database, table.identifier.table, s.asJava, -1)
       }
-    hivePartitionNames.asScala.sorted.toSeq
+    hivePartitionNames.sorted.toSeq
   }
 
   override def getPartitionOption(
       table: CatalogTable,
       spec: TablePartitionSpec): Option[CatalogTablePartition] = withHiveState {
     val hiveTable = toHiveTable(table, Some(userName))
-    val hivePartition = client.getPartition(hiveTable, spec.asJava, false)
+    val hivePartition = shim.getPartition(client, hiveTable, spec.asJava, false)
     Option(hivePartition).map(fromHivePartition)
   }
 
@@ -753,7 +753,7 @@ private[hive] class HiveClientImpl(
         assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid")
         s
     }
-    val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition)
+    val parts = shim.getPartitions(client, hiveTable, partSpec.asJava).map(fromHivePartition)
     HiveCatalogMetrics.incrementFetchedPartitions(parts.length)
     parts.toSeq
   }
@@ -769,11 +769,11 @@ private[hive] class HiveClientImpl(
   }
 
   override def listTables(dbName: String): Seq[String] = withHiveState {
-    client.getAllTables(dbName).asScala.toSeq
+    shim.getAllTables(client, dbName)
   }
 
   override def listTables(dbName: String, pattern: String): Seq[String] = withHiveState {
-    client.getTablesByPattern(dbName, pattern).asScala.toSeq
+    shim.getTablesByPattern(client, dbName, pattern)
   }
 
   override def listTablesByType(
@@ -787,8 +787,8 @@ private[hive] class HiveClientImpl(
     } catch {
       case _: UnsupportedOperationException =>
         // Fallback to filter logic if getTablesByType not supported.
-        val tableNames = client.getTablesByPattern(dbName, pattern).asScala
-        getRawTablesByName(dbName, tableNames.toSeq)
+        val tableNames = shim.getTablesByPattern(client, dbName, pattern)
+        getRawTablesByName(dbName, tableNames)
           .filter(_.getTableType == hiveTableType)
           .map(_.getTableName)
     }
@@ -887,7 +887,7 @@ private[hive] class HiveClientImpl(
       replace: Boolean,
       inheritTableSpecs: Boolean,
       isSrcLocal: Boolean): Unit = withHiveState {
-    val hiveTable = client.getTable(dbName, tableName, true /* throw exception */)
+    val hiveTable = shim.getTable(client, dbName, tableName, true /* throw exception */)
     shim.loadPartition(
       client,
       new Path(loadPath), // TODO: Use URI
@@ -919,7 +919,7 @@ private[hive] class HiveClientImpl(
       partSpec: java.util.LinkedHashMap[String, String],
       replace: Boolean,
       numDP: Int): Unit = withHiveState {
-    val hiveTable = client.getTable(dbName, tableName, true /* throw exception */)
+    val hiveTable = shim.getTable(client, dbName, tableName, true /* throw exception */)
     shim.loadDynamicPartitions(
       client,
       new Path(loadPath),
@@ -965,36 +965,36 @@ private[hive] class HiveClientImpl(
   }
 
   def reset(): Unit = withHiveState {
-    val allTables = client.getAllTables("default")
-    val (mvs, others) = allTables.asScala.map(t => client.getTable("default", t))
+    val allTables = shim.getAllTables(client, "default")
+    val (mvs, others) = allTables.map(t => shim.getTable(client, "default", t))
       .partition(_.getTableType.toString.equals("MATERIALIZED_VIEW"))
 
     // Remove materialized view first, otherwise caused a violation of foreign key constraint.
     mvs.foreach { table =>
       val t = table.getTableName
       logDebug(s"Deleting materialized view $t")
-      client.dropTable("default", t)
+      shim.dropTable(client, "default", t)
     }
 
     others.foreach { table =>
       val t = table.getTableName
       logDebug(s"Deleting table $t")
       try {
-        client.getIndexes("default", t, 255).asScala.foreach { index =>
+        shim.getIndexes(client, "default", t, 255).foreach { index =>
           shim.dropIndex(client, "default", t, index.getIndexName)
         }
         if (!table.isIndexTable) {
-          client.dropTable("default", t)
+          shim.dropTable(client, "default", t)
         }
       } catch {
         case _: NoSuchMethodError =>
           // HIVE-18448 Hive 3.0 remove index APIs
-          client.dropTable("default", t)
+          shim.dropTable(client, "default", t)
       }
     }
-    client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db =>
+    shim.getAllDatabases(client).filterNot(_ == "default").foreach { db =>
       logDebug(s"Dropping Database: $db")
-      client.dropDatabase(db, true, false, true)
+      shim.dropDatabase(client, db, true, false, true)
     }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 3e2742b..d76190a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.metastore.IMetaStoreClient
 import org.apache.hadoop.hive.metastore.TableType
-import org.apache.hadoop.hive.metastore.api.{Database, EnvironmentContext, Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri}
+import org.apache.hadoop.hive.metastore.api.{Database, EnvironmentContext, Function => HiveFunction, FunctionType, Index, MetaException, PrincipalType, ResourceType, ResourceUri}
 import org.apache.hadoop.hive.ql.Driver
 import org.apache.hadoop.hive.ql.io.AcidUtils
 import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
@@ -75,6 +75,25 @@ private[client] sealed abstract class Shim {
    */
   def getDataLocation(table: Table): Option[String]
 
+  def createDatabase(hive: Hive, db: Database, ignoreIfExists: Boolean): Unit
+
+  def dropDatabase(
+      hive: Hive,
+      dbName: String,
+      deleteData: Boolean,
+      ignoreUnknownDb: Boolean,
+      cascade: Boolean): Unit
+
+  def alterDatabase(hive: Hive, dbName: String, d: Database): Unit
+
+  def getDatabase(hive: Hive, dbName: String): Database
+
+  def getAllDatabases(hive: Hive): Seq[String]
+
+  def getDatabasesByPattern(hive: Hive, pattern: String): Seq[String]
+
+  def databaseExists(hive: Hive, dbName: String): Boolean
+
   def setDataLocation(table: Table, loc: String): Unit
 
   def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
@@ -94,16 +113,54 @@ private[client] sealed abstract class Shim {
 
   def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit
 
+  def createTable(hive: Hive, table: Table, ifNotExists: Boolean): Unit
+
+  def getTable(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      throwException: Boolean = true): Table
+
   def getTablesByType(
       hive: Hive,
       dbName: String,
       pattern: String,
       tableType: TableType): Seq[String]
 
+  def getTablesByPattern(hive: Hive, dbName: String, pattern: String): Seq[String]
+
+  def getAllTables(hive: Hive, dbName: String): Seq[String]
+
+  def dropTable(hive: Hive, dbName: String, tableName: String): Unit
+
+  def getPartition(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String],
+      forceCreate: Boolean): Partition
+
+  def getPartitions(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String]): Seq[Partition]
+
+  def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[String]
+
+  def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      partSpec: JMap[String, String],
+      max: Short): Seq[String]
+
   def createPartitions(
       hive: Hive,
-      db: String,
-      table: String,
+      dbName: String,
+      tableName: String,
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit
 
@@ -117,6 +174,12 @@ private[client] sealed abstract class Shim {
       isSkewedStoreAsSubdir: Boolean,
       isSrcLocal: Boolean): Unit
 
+  def renamePartition(
+      hive: Hive,
+      table: Table,
+      oldPartSpec: JMap[String, String],
+      newPart: Partition): Unit
+
   def loadTable(
       hive: Hive,
       loadPath: Path,
@@ -176,6 +239,8 @@ private[client] sealed abstract class Shim {
 
   def getMSC(hive: Hive): IMetaStoreClient
 
+  def getIndexes(hive: Hive, dbName: String, tableName: String, max: Short): Seq[Index]
+
   protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = {
     klass.getMethod(name, args: _*)
   }
@@ -480,6 +545,111 @@ private[client] class Shim_v0_12 extends Shim with Logging {
   override def getDatabaseOwnerName(db: Database): String = ""
 
   override def setDatabaseOwnerName(db: Database, owner: String): Unit = {}
+
+  override def createDatabase(hive: Hive, db: Database, ignoreIfExists: Boolean): Unit = {
+    hive.createDatabase(db, ignoreIfExists)
+  }
+
+  override def dropDatabase(
+      hive: Hive,
+      dbName: String,
+      deleteData: Boolean,
+      ignoreUnknownDb: Boolean,
+      cascade: Boolean): Unit = {
+    hive.dropDatabase(dbName, deleteData, ignoreUnknownDb, cascade)
+  }
+
+  override def alterDatabase(hive: Hive, dbName: String, d: Database): Unit = {
+    hive.alterDatabase(dbName, d)
+  }
+
+  override def getDatabase(hive: Hive, dbName: String): Database = {
+    hive.getDatabase(dbName)
+  }
+
+  override def getAllDatabases(hive: Hive): Seq[String] = {
+    hive.getAllDatabases.asScala.toSeq
+  }
+
+  override def getDatabasesByPattern(hive: Hive, pattern: String): Seq[String] = {
+    hive.getDatabasesByPattern(pattern).asScala.toSeq
+  }
+
+  override def databaseExists(hive: Hive, dbName: String): Boolean = {
+    hive.databaseExists(dbName)
+  }
+
+  override def createTable(hive: Hive, table: Table, ifNotExists: Boolean): Unit = {
+    hive.createTable(table, ifNotExists)
+  }
+
+  override def getTable(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      throwException: Boolean): Table = {
+    hive.getTable(dbName, tableName, throwException)
+  }
+
+  override def getTablesByPattern(hive: Hive, dbName: String, pattern: String): Seq[String] = {
+    hive.getTablesByPattern(dbName, pattern).asScala.toSeq
+  }
+
+  override def getAllTables(hive: Hive, dbName: String): Seq[String] = {
+    hive.getAllTables(dbName).asScala.toSeq
+  }
+
+  override def dropTable(hive: Hive, dbName: String, tableName: String): Unit = {
+    hive.dropTable(dbName, tableName)
+  }
+
+  override def getPartition(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String],
+      forceCreate: Boolean): Partition = {
+    hive.getPartition(table, partSpec, forceCreate)
+  }
+
+  override def getPartitions(
+      hive: Hive,
+      table: Table,
+      partSpec: JMap[String, String]): Seq[Partition] = {
+    hive.getPartitions(table, partSpec).asScala.toSeq
+  }
+
+  override def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[String] = {
+    hive.getPartitionNames(dbName, tableName, max).asScala.toSeq
+  }
+
+  override def getPartitionNames(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      partSpec: JMap[String, String],
+      max: Short): Seq[String] = {
+    hive.getPartitionNames(dbName, tableName, partSpec, max).asScala.toSeq
+  }
+
+  override def renamePartition(
+      hive: Hive,
+      table: Table,
+      oldPartSpec: JMap[String, String],
+      newPart: Partition): Unit = {
+    hive.renamePartition(table, oldPartSpec, newPart)
+  }
+
+  override def getIndexes(
+      hive: Hive,
+      dbName: String,
+      tableName: String,
+      max: Short): Seq[Index] = {
+    hive.getIndexes(dbName, tableName, max).asScala.toSeq
+  }
 }
 
 private[client] class Shim_v0_13 extends Shim_v0_12 {
