Repository: spark
Updated Branches:
  refs/heads/master 142df4834 -> f5fef6914


[SPARK-16281][SQL] Implement parse_url SQL function

## What changes were proposed in this pull request?

This PR adds the parse_url SQL function in order to remove the Hive fallback for it.

A new implementation of #13999
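
For illustration only (not part of this patch): a minimal sketch of calling the new function, assuming a SparkSession named `spark` with this change applied. The expected values follow the examples documented in the expression's usage string below.

    // Hypothetical usage sketch; `spark` is an assumed SparkSession.
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST')")
      .head.getString(0)   // "spark.apache.org"
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query')")
      .head.getString(0)   // "1"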

## How was this patch tested?

Passed the existing tests, including the new test cases.

Author: wujian <jan.chou...@gmail.com>

Closes #14008 from janplus/SPARK-16281.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f5fef691
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f5fef691
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f5fef691

Branch: refs/heads/master
Commit: f5fef69143b2a83bb8b168b7417e92659af0c72c
Parents: 142df48
Author: wujian <jan.chou...@gmail.com>
Authored: Fri Jul 8 14:38:05 2016 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Jul 8 14:38:05 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/stringExpressions.scala         | 150 +++++++++++++++++++
 .../expressions/StringExpressionsSuite.scala    |  51 +++++++
 .../apache/spark/sql/StringFunctionsSuite.scala |  15 ++
 .../spark/sql/hive/HiveSessionCatalog.scala     |   2 +-
 5 files changed, 218 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f5fef691/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 842c9c6..c8bbbf8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -288,6 +288,7 @@ object FunctionRegistry {
     expression[StringLPad]("lpad"),
     expression[StringTrimLeft]("ltrim"),
     expression[JsonTuple]("json_tuple"),
+    expression[ParseUrl]("parse_url"),
     expression[FormatString]("printf"),
     expression[RegExpExtract]("regexp_extract"),
     expression[RegExpReplace]("regexp_replace"),

http://git-wip-us.apache.org/repos/asf/spark/blob/f5fef691/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 894e12d..61549c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.net.{MalformedURLException, URL}
 import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
 import java.util.{HashMap, Locale, Map => JMap}
+import java.util.regex.Pattern
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -654,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
   override def prettyName: String = "rpad"
 }
 
+object ParseUrl {
+  private val HOST = UTF8String.fromString("HOST")
+  private val PATH = UTF8String.fromString("PATH")
+  private val QUERY = UTF8String.fromString("QUERY")
+  private val REF = UTF8String.fromString("REF")
+  private val PROTOCOL = UTF8String.fromString("PROTOCOL")
+  private val FILE = UTF8String.fromString("FILE")
+  private val AUTHORITY = UTF8String.fromString("AUTHORITY")
+  private val USERINFO = UTF8String.fromString("USERINFO")
+  private val REGEXPREFIX = "(&|^)"
+  private val REGEXSUBFIX = "=([^&]*)"
+}
+
+/**
+ * Extracts a part from a URL
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
+  extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO.
+    Key specifies which query to extract.
+    Examples:
+      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST')
+      'spark.apache.org'
+      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY')
+      'query=1'
+      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query')
+      '1'""")
+case class ParseUrl(children: Seq[Expression])
+  extends Expression with ExpectsInputTypes with CodegenFallback {
+
+  override def nullable: Boolean = true
+  override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+  override def dataType: DataType = StringType
+  override def prettyName: String = "parse_url"
+
+  // If the url is a constant, cache the URL object so that we don't need to convert url
+  // from UTF8String to String to URL for every row.
+  @transient private lazy val cachedUrl = children(0) match {
+    case Literal(url: UTF8String, _) if url ne null => getUrl(url)
+    case _ => null
+  }
+
+  // If the key is a constant, cache the Pattern object so that we don't need to convert key
+  // from UTF8String to String to StringBuilder to String to Pattern for every row.
+  @transient private lazy val cachedPattern = children(2) match {
+    case Literal(key: UTF8String, _) if key ne null => getPattern(key)
+    case _ => null
+  }
+
+  // If the partToExtract is a constant, cache the Extract part function so that we don't need
+  // to check the partToExtract for every row.
+  @transient private lazy val cachedExtractPartFunc = children(1) match {
+    case Literal(part: UTF8String, _) => getExtractPartFunc(part)
+    case _ => null
+  }
+
+  import ParseUrl._
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size > 3 || children.size < 2) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
+    } else {
+      super[ExpectsInputTypes].checkInputDataTypes()
+    }
+  }
+
+  private def getPattern(key: UTF8String): Pattern = {
+    Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
+  }
+
+  private def getUrl(url: UTF8String): URL = {
+    try {
+      new URL(url.toString)
+    } catch {
+      case e: MalformedURLException => null
+    }
+  }
+
+  private def getExtractPartFunc(partToExtract: UTF8String): URL => String = {
+    partToExtract match {
+      case HOST => _.getHost
+      case PATH => _.getPath
+      case QUERY => _.getQuery
+      case REF => _.getRef
+      case PROTOCOL => _.getProtocol
+      case FILE => _.getFile
+      case AUTHORITY => _.getAuthority
+      case USERINFO => _.getUserInfo
+      case _ => (url: URL) => null
+    }
+  }
+
+  private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
+    val m = pattern.matcher(query.toString)
+    if (m.find()) {
+      UTF8String.fromString(m.group(2))
+    } else {
+      null
+    }
+  }
+
+  private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = {
+    if (cachedExtractPartFunc ne null) {
+      UTF8String.fromString(cachedExtractPartFunc.apply(url))
+    } else {
+      UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
+    }
+  }
+
+  private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
+    if (cachedUrl ne null) {
+      extractFromUrl(cachedUrl, partToExtract)
+    } else {
+      val currentUrl = getUrl(url)
+      if (currentUrl ne null) {
+        extractFromUrl(currentUrl, partToExtract)
+      } else {
+        null
+      }
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
+    if (evaluated.contains(null)) return null
+    if (evaluated.size == 2) {
+      parseUrlWithoutKey(evaluated(0), evaluated(1))
+    } else {
+      // 3-arg, i.e. QUERY with key
+      assert(evaluated.size == 3)
+      if (evaluated(1) != QUERY) {
+        return null
+      }
+
+      val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
+      if (query eq null) {
+        return null
+      }
+
+      if (cachedPattern ne null) {
+        extractValueFromQuery(query, cachedPattern)
+      } else {
+        extractValueFromQuery(query, getPattern(evaluated(2)))
+      }
+    }
+  }
+}
+
 /**
  * Returns the input formatted according do printf-style format strings
  */
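
As a standalone sketch (not from this commit) of the part extraction above: getExtractPartFunc simply dispatches to the corresponding java.net.URL getter, so the mapping can be reproduced directly:

    import java.net.URL

    // Sketch only: the java.net.URL getters that getExtractPartFunc dispatches to.
    val url = new URL("http://userinfo@spark.apache.org:8080/path?query=1#Ref")
    url.getProtocol   // "http"                            -> PROTOCOL
    url.getUserInfo   // "userinfo"                        -> USERINFO
    url.getHost       // "spark.apache.org"                -> HOST
    url.getAuthority  // "userinfo@spark.apache.org:8080"  -> AUTHORITY
    url.getPath       // "/path"                           -> PATH
    url.getQuery      // "query=1"                         -> QUERY
    url.getFile       // "/path?query=1"                   -> FILE
    url.getRef        // "Ref"                             -> REF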

http://git-wip-us.apache.org/repos/asf/spark/blob/f5fef691/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 256ce85..8f7b104 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -726,6 +726,57 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
   }
 
+  test("ParseUrl") {
+    def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
+      checkEvaluation(
+        ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
+    }
+    def checkParseUrlWithKey(
+        expected: String,
+        urlStr: String,
+        partToExtract: String,
+        key: String): Unit = {
+      checkEvaluation(
+        ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
+    }
+
+    checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1";, 
"HOST")
+    checkParseUrl("/path", "http://spark.apache.org/path?query=1";, "PATH")
+    checkParseUrl("query=1", "http://spark.apache.org/path?query=1";, "QUERY")
+    checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref";, "REF")
+    checkParseUrl("http", "http://spark.apache.org/path?query=1";, "PROTOCOL")
+    checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1";, 
"FILE")
+    checkParseUrl("spark.apache.org:8080", 
"http://spark.apache.org:8080/path?query=1";, "AUTHORITY")
+    checkParseUrl("userinfo", "http://useri...@spark.apache.org/path?query=1";, 
"USERINFO")
+    checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1";, "QUERY", 
"query")
+
+    // Null checking
+    checkParseUrl(null, null, "HOST")
+    checkParseUrl(null, "http://spark.apache.org/path?query=1";, null)
+    checkParseUrl(null, null, null)
+    checkParseUrl(null, "test", "HOST")
+    checkParseUrl(null, "http://spark.apache.org/path?query=1";, "NO")
+    checkParseUrl(null, "http://spark.apache.org/path?query=1";, "USERINFO")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1";, "HOST", 
"query")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1";, 
"QUERY", "quer")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1";, 
"QUERY", null)
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1";, 
"QUERY", "")
+
+    // exceptional cases
+    intercept[java.util.regex.PatternSyntaxException] {
+      evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
+        Literal("QUERY"), Literal("???"))))
+    }
+
+    // arguments checking
+    assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
+    assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), 
Literal("4")))
+      .checkInputDataTypes().isFailure)
+    assert(ParseUrl(Seq(Literal("1"), 
Literal(2))).checkInputDataTypes().isFailure)
+    assert(ParseUrl(Seq(Literal(1), 
Literal("2"))).checkInputDataTypes().isFailure)
+    assert(ParseUrl(Seq(Literal("1"), Literal("2"), 
Literal(3))).checkInputDataTypes().isFailure)
+  }
+
   test("Sentences") {
     val nullString = Literal.create(null, StringType)
     checkEvaluation(Sentences(nullString, nullString, nullString), null)
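
For reference (not part of the commit), a standalone sketch of the query-key lookup exercised above. getPattern compiles "(&|^)" + key + "=([^&]*)" with the key taken verbatim, so group(2) captures the value, and a key containing regex metacharacters such as "???" fails to compile — which is what the PatternSyntaxException test checks:

    import java.util.regex.Pattern

    // Sketch of the lookup done by getPattern/extractValueFromQuery above.
    val m = Pattern.compile("(&|^)" + "query" + "=([^&]*)").matcher("a=1&query=2")
    if (m.find()) println(m.group(2))   // prints "2"

    // The key is not quoted, so this throws PatternSyntaxException:
    // Pattern.compile("(&|^)" + "???" + "=([^&]*)")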

http://git-wip-us.apache.org/repos/asf/spark/blob/f5fef691/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 044ac22..f509551 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -228,6 +228,21 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("???hi", "hi???", "h", "h"))
   }
 
+  test("string parse_url function") {
+    val df = Seq[String](("http://userinfo@spark.apache.org/path?query=1#Ref"))
+      .toDF("url")
+
+    checkAnswer(
+      df.selectExpr(
+        "parse_url(url, 'HOST')", "parse_url(url, 'PATH')",
+        "parse_url(url, 'QUERY')", "parse_url(url, 'REF')",
+        "parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')",
+        "parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')",
+        "parse_url(url, 'QUERY', 'query')"),
+      Row("spark.apache.org", "/path", "query=1", "Ref",
+        "http", "/path?query=1", "useri...@spark.apache.org", "userinfo", "1"))
+  }
+
   test("string repeat function") {
     val df = Seq(("hi", 2)).toDF("a", "b")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f5fef691/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 6f05f0f..9c7f461 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
+    "percentile", "percentile_approx", "reflect", "str_to_map",
     "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
     "xpath_number", "xpath_short", "xpath_string"
   )

