Repository: spark
Updated Branches:
  refs/heads/branch-1.5 330961bbf -> b767ceeb2


[SPARK-11191][SPARK-11311][SQL] Backports #9664 and #9277 to branch-1.5

The main purpose of this PR is to backport #9664, which depends on #9277. Together, the two patches run HiveQL parsing and Hive function-registry lookups under executionHive.withHiveState, so that temporary functions (e.g. a UDTF registered after an ADD JAR) are resolved against the correct Hive session.

Author: Cheng Lian <l...@databricks.com>

Closes #9671 from liancheng/spark-11191.fix-temp-function.branch-1_5.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b767ceeb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b767ceeb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b767ceeb

Branch: refs/heads/branch-1.5
Commit: b767ceeb25e9b4948c93475c6382b16f26dbfc6e
Parents: 330961b
Author: Cheng Lian <l...@databricks.com>
Authored: Sun Nov 15 13:16:49 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Sun Nov 15 13:16:49 2015 -0800

----------------------------------------------------------------------
 .../thriftserver/HiveThriftServer2Suites.scala  | 53 ++++++++++++++++++--
 .../org/apache/spark/sql/hive/HiveContext.scala | 18 ++++---
 .../spark/sql/hive/client/ClientWrapper.scala   |  2 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    | 20 ++++++--
 4 files changed, 77 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b767ceeb/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 19b2f24..597bf60 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -256,9 +256,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
       { statement =>
 
         val queries = Seq(
-            s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291",
-            "SET hive.cli.print.header=true"
-            )
+          s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291",
+          "SET hive.cli.print.header=true"
+        )
 
         queries.map(statement.execute)
         val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}")
@@ -458,6 +458,53 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
       assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
     }
   }
+
+  test("SPARK-11595 ADD JAR with input path having URL scheme") {
+    withJdbcStatement { statement =>
+      statement.executeQuery("SET spark.sql.hive.thriftServer.async=true")
+
+      val jarPath = "../hive/src/test/resources/TestUDTF.jar"
+      val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+
+      Seq(
+        s"ADD JAR $jarURL",
+        s"""CREATE TEMPORARY FUNCTION udtf_count2
+            |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+         """.stripMargin
+      ).foreach(statement.execute)
+
+      val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Function: udtf_count2")
+
+      assert(rs1.next())
+      assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+        rs1.getString(1)
+      }
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Usage: To be added.")
+
+      val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
+
+      Seq(
+        s"CREATE TABLE test_udtf(key INT, value STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
+      ).foreach(statement.execute)
+
+      val rs2 = statement.executeQuery(
+        "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS 
cc")
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+    }
+  }
 }
 
 class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
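
To reproduce the SPARK-11595 scenario above outside the test harness, a plain JDBC client is enough. The following is a minimal sketch; the server URL, the jar location under /tmp, and the AddJarViaUrlExample name are assumptions for illustration, while the UDTF class name is the one the test uses.

    import java.sql.DriverManager

    object AddJarViaUrlExample {
      def main(args: Array[String]): Unit = {
        // Assumes a HiveThriftServer2 instance listening on localhost:10000.
        Class.forName("org.apache.hive.jdbc.HiveDriver")
        val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000")
        try {
          val stmt = conn.createStatement()
          // ADD JAR now accepts an input path carrying a URL scheme.
          stmt.execute("ADD JAR file:///tmp/TestUDTF.jar")
          stmt.execute(
            "CREATE TEMPORARY FUNCTION udtf_count2 " +
              "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")
          // Should print the three rows asserted in the test above.
          val rs = stmt.executeQuery("DESCRIBE FUNCTION udtf_count2")
          while (rs.next()) println(rs.getString(1))
        } finally {
          conn.close()
        }
      }
    }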

http://git-wip-us.apache.org/repos/asf/spark/blob/b767ceeb/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 5dc3d81..b3ba444 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -57,9 +57,11 @@ import org.apache.spark.util.Utils
 /**
  * This is the HiveQL dialect; it is strongly bound to HiveContext
  */
-private[hive] class HiveQLDialect extends ParserDialect {
+private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
   override def parse(sqlText: String): LogicalPlan = {
-    HiveQl.parseSql(sqlText)
+    sqlContext.executionHive.withHiveState {
+      HiveQl.parseSql(sqlText)
+    }
   }
 }
 
@@ -410,7 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
   override protected[sql] lazy val functionRegistry: FunctionRegistry =
-    new HiveFunctionRegistry(FunctionRegistry.builtin)
+    new HiveFunctionRegistry(FunctionRegistry.builtin, this)
 
   /* An analyzer that uses the Hive metastore. */
   @transient
@@ -517,10 +519,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
     }
   }
 
-  override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
-    classOf[HiveQLDialect].getCanonicalName
-  } else {
-    super.dialectClassName
+  protected[sql] override def getSQLDialect(): ParserDialect = {
+    if (conf.dialect == "hiveql") {
+      new HiveQLDialect(this)
+    } else {
+      super.getSQLDialect()
+    }
   }
 
   @transient
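
The HiveQLDialect change above is the crux of this file: Hive keeps its SessionState, and with it any functions and resources registered in the current session, in a thread local, and HiveQl.parseSql may consult that state during parsing. Wrapping the parse in executionHive.withHiveState installs the right session on the calling thread first. A minimal sketch of the save-install-restore shape, using an illustrative SessionScope object with a String-typed session (not Spark's API):

    object SessionScope {
      // Illustrative stand-in for Hive's thread-local SessionState.
      private val current = new ThreadLocal[String]

      // Installs `session` in the thread local for the duration of `f`,
      // restoring whatever was there before; the same shape as
      // ClientWrapper.withHiveState below.
      def withSession[A](session: String)(f: => A): A = {
        val previous = current.get()
        current.set(session)
        try f finally current.set(previous)
      }
    }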

http://git-wip-us.apache.org/repos/asf/spark/blob/b767ceeb/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index f49c97d..f45747a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -245,7 +245,7 @@ private[hive] class ClientWrapper(
   /**
    * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
    */
-  private def withHiveState[A](f: => A): A = retryLocked {
+  def withHiveState[A](f: => A): A = retryLocked {
     val original = Thread.currentThread().getContextClassLoader
     // Set the thread local metastore client to the client associated with this ClientWrapper.
     Hive.set(client)
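
Dropping private here is what allows HiveQLDialect and HiveFunctionRegistry to call withHiveState at all. Besides pointing the thread-local metastore client at this ClientWrapper, the method also swaps the session's class loader onto the calling thread around f, which is how classes registered via ADD JAR become visible; the "original" line above is the saved loader. A simplified, illustrative reduction of just that class-loader part (the real method also goes through retryLocked and restores the metastore state):

    object ClassLoaderScope {
      // Simplified sketch, not the full ClientWrapper implementation.
      def withSessionClassLoader[A](sessionLoader: ClassLoader)(f: => A): A = {
        val original = Thread.currentThread().getContextClassLoader
        // Classes from ADD JAR are visible only through the session's
        // class loader, so install it before running `f`.
        Thread.currentThread().setContextClassLoader(sessionLoader)
        try f finally Thread.currentThread().setContextClassLoader(original)
      }
    }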

http://git-wip-us.apache.org/repos/asf/spark/blob/b767ceeb/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 7182246..a510df6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -44,17 +44,23 @@ import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._
 
 
-private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
+private[hive] class HiveFunctionRegistry(
+    underlying: analysis.FunctionRegistry,
+    hiveContext: HiveContext)
   extends analysis.FunctionRegistry with HiveInspectors {
 
-  def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
+  def getFunctionInfo(name: String): FunctionInfo = {
+    hiveContext.executionHive.withHiveState {
+      FunctionRegistry.getFunctionInfo(name)
+    }
+  }
 
   override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     Try(underlying.lookupFunction(name, children)).getOrElse {
       // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
       // not always serializable.
       val functionInfo: FunctionInfo =
-        Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+        Option(getFunctionInfo(name.toLowerCase)).getOrElse(
           throw new AnalysisException(s"undefined function $name"))
 
       val functionClassName = functionInfo.getFunctionClass.getName
@@ -89,7 +95,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
     underlying.lookupFunction(name).orElse(
     Try {
-      val info = FunctionRegistry.getFunctionInfo(name)
+      val info = getFunctionInfo(name)
       val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
       if (annotation != null) {
         Some(new ExpressionInfo(
@@ -98,7 +104,11 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
           annotation.value(),
           annotation.extended()))
       } else {
-        None
+        Some(new ExpressionInfo(
+          info.getFunctionClass.getCanonicalName,
+          name,
+          null,
+          null))
       }
     }.getOrElse(None))
   }
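
The final hunk changes the fallback in lookupFunction(name): a Hive function whose class has no Description annotation used to yield None, so DESCRIBE FUNCTION reported it as undefined; it now yields an ExpressionInfo with null usage fields, which shows up as the "Usage: To be added." row asserted in the test above. A REPL-style sketch of the value now produced, using the same four-argument ExpressionInfo constructor as the hunk:

    import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

    // What lookupFunction("udtf_count2") now yields for an annotation-less
    // function class; DESCRIBE FUNCTION renders the null usage field as
    // "Usage: To be added." (see the test assertions above).
    val info = new ExpressionInfo(
      "org.apache.spark.sql.hive.execution.GenericUDTFCount2", // class name
      "udtf_count2",                                           // function name
      null,                                                    // usage
      null)                                                    // extended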

