Repository: spark
Updated Branches:
  refs/heads/master 0c20ce69f -> ca7a6cdff


[SPARK-5550] [SQL] Support case-insensitive UDF names

SQL in HiveContext should be case insensitive; however, the following query currently fails:

```scala
udf.register("random0", ()  => { Math.random()})
assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
```

Author: Cheng Hao <hao.ch...@intel.com>

Closes #4326 from chenghao-intel/udf_case_sensitive and squashes the following commits:

485cf66 [Cheng Hao] Support the case insensitive for UDF


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca7a6cdf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca7a6cdf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca7a6cdf

Branch: refs/heads/master
Commit: ca7a6cdff004eb4605fd223e127b4a46a0a214e7
Parents: 0c20ce6
Author: Cheng Hao <hao.ch...@intel.com>
Authored: Tue Feb 3 12:12:26 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Feb 3 12:12:26 2015 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    | 36 +++++++++++++++++---
 .../scala/org/apache/spark/sql/SQLContext.scala |  2 +-
 .../org/apache/spark/sql/hive/HiveContext.scala |  4 ++-
 .../org/apache/spark/sql/hive/UDFSuite.scala    | 36 ++++++++++++++++++++
 4 files changed, 72 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca7a6cdf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 760c49f..9f334f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -27,23 +27,25 @@ trait FunctionRegistry {
   def registerFunction(name: String, builder: FunctionBuilder): Unit
 
   def lookupFunction(name: String, children: Seq[Expression]): Expression
+
+  def caseSensitive: Boolean  // whether function names are matched case-sensitively
 }
 
 trait OverrideFunctionRegistry extends FunctionRegistry {
 
-  val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+  val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
 
   def registerFunction(name: String, builder: FunctionBuilder) = {
     functionBuilders.put(name, builder)
   }
 
   abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children))
+    functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
   }
 }
 
-class SimpleFunctionRegistry extends FunctionRegistry {
-  val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry {
+  val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
 
   def registerFunction(name: String, builder: FunctionBuilder) = {
     functionBuilders.put(name, builder)
@@ -64,4 +66,30 @@ object EmptyFunctionRegistry extends FunctionRegistry {
   def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     throw new UnsupportedOperationException
   }
+
+  def caseSensitive: Boolean = true  // lookupFunction always throws, so the mode is moot
+}
+
+/**
+ * A hash map keyed by String that matches keys either case sensitively or, when
+ * case insensitive, by lower-casing them with a fixed (root) locale.
+ * TODO move this into util folder?
+ */
+object StringKeyHashMap {
+  // Fixed locale: under e.g. a Turkish default locale, "I".toLowerCase is a dotless i.
+  def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] =
+    if (caseSensitive) new StringKeyHashMap[T](identity)
+    else new StringKeyHashMap[T](_.toLowerCase(java.util.Locale.ROOT))
+}
+
+class StringKeyHashMap[T](normalizer: (String) => String) {
+  private val base = new collection.mutable.HashMap[String, T]()
+
+  def apply(key: String): T = base(normalizer(key))
+
+  def get(key: String): Option[T] = base.get(normalizer(key))
+  def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
+  def remove(key: String): Option[T] = base.remove(normalizer(key))
+  def iterator: Iterator[(String, T)] = base.toIterator
 }
+

http://git-wip-us.apache.org/repos/asf/spark/blob/ca7a6cdf/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a741d00..2697e78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -87,7 +87,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
 
   @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry
+  protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
 
   @transient
   protected[sql] lazy val analyzer: Analyzer =

http://git-wip-us.apache.org/repos/asf/spark/blob/ca7a6cdf/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index f6d9027..50f266a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -311,7 +311,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
   override protected[sql] lazy val functionRegistry =
-    new HiveFunctionRegistry with OverrideFunctionRegistry
+    new HiveFunctionRegistry with OverrideFunctionRegistry {
+      override def caseSensitive: Boolean = false  // def (not val): read during trait init
+    }
 
   /* An analyzer that uses the Hive metastore. */
   @transient

http://git-wip-us.apache.org/repos/asf/spark/blob/ca7a6cdf/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
new file mode 100644
index 0000000..85b6bc9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+
+/* Implicits */
+import org.apache.spark.sql.hive.test.TestHive._
+
+case class FunctionResult(f1: String, f2: String)
+
+class UDFSuite extends QueryTest {
+  test("UDF case insensitive") {
+    udf.register("random0", () => Math.random())
+    udf.register("RANDOM1", () => Math.random())
+    udf.register("strlenScala", (_: String).length + (_: Int))
+    assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+    assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+    assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to