Repository: spark
Updated Branches:
  refs/heads/branch-2.0 176af17a7 -> ea684b69c


[SPARK-17069] Expose spark.range() as table-valued function in SQL

This adds analyzer rules for resolving table-valued functions, and adds one 
builtin implementation for range(). The arguments for range() are the same as 
those of `spark.range()`.

Tested with unit tests.

cc hvanhovell

Author: Eric Liang <e...@databricks.com>

Closes #14656 from ericl/sc-4309.

(cherry picked from commit 412dba63b511474a6db3c43c8618d803e604bc6b)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ea684b69
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ea684b69
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ea684b69

Branch: refs/heads/branch-2.0
Commit: ea684b69cd6934bc093f4a5a8b0d8470e92157cd
Parents: 176af17
Author: Eric Liang <e...@databricks.com>
Authored: Thu Aug 18 13:33:55 2016 +0200
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Aug 18 18:36:50 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/parser/SqlBase.g4 |   1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   1 +
 .../analysis/ResolveTableValuedFunctions.scala  | 132 +++++++++++++++++++
 .../sql/catalyst/analysis/unresolved.scala      |  11 ++
 .../spark/sql/catalyst/parser/AstBuilder.scala  |   8 ++
 .../sql/catalyst/parser/PlanParserSuite.scala   |   8 +-
 .../sql-tests/inputs/table-valued-functions.sql |  20 +++
 .../results/table-valued-functions.sql.out      |  87 ++++++++++++
 8 files changed, 267 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index aca7282..51f3804 100644
--- 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -426,6 +426,7 @@ relationPrimary
     | '(' queryNoWith ')' sample? (AS? strictIdentifier)?           
#aliasedQuery
     | '(' relation ')' sample? (AS? strictIdentifier)?              
#aliasedRelation
     | inlineTable                                                   
#inlineTableDefault2
+    | identifier '(' (expression (',' expression)*)? ')'            
#tableValuedFunction
     ;
 
 inlineTable

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 57c3d9a..e0b8166 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,6 +86,7 @@ class Analyzer(
       WindowsSubstitution,
       EliminateUnions),
     Batch("Resolution", fixedPoint,
+      ResolveTableValuedFunctions ::
       ResolveRelations ::
       ResolveReferences ::
       ResolveDeserializer ::

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
new file mode 100644
index 0000000..7fdf7fa
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
+
+/**
+ * Rule that resolves table-valued function references.
+ */
+object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
+  private lazy val defaultParallelism =
+    SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
+
+  /**
+   * List of argument names and their types, used to declare a function.
+   */
+  private case class ArgumentList(args: (String, DataType)*) {
+    /**
+     * Try to cast the expressions to satisfy the expected types of this 
argument list. If there
+     * are any types that cannot be casted, then None is returned.
+     */
+    def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+      if (args.length == values.length) {
+        val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+          TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
+        }
+        if (casted.forall(_.isDefined)) {
+          return Some(casted.map(_.get))
+        }
+      }
+      None
+    }
+
+    override def toString: String = {
+      args.map { a =>
+        s"${a._1}: ${a._2.typeName}"
+      }.mkString(", ")
+    }
+  }
+
+  /**
+   * A TVF maps argument lists to resolver functions that accept those 
arguments. Using a map
+   * here allows for function overloading.
+   */
+  private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+  /**
+   * TVF builder.
+   */
+  private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], 
LogicalPlan])
+      : (ArgumentList, Seq[Any] => LogicalPlan) = {
+    (ArgumentList(args: _*),
+     pf orElse {
+       case args =>
+         throw new IllegalArgumentException(
+           "Invalid arguments for resolved function: " + args.mkString(", "))
+     })
+  }
+
+  /**
+   * Internal registry of table-valued functions.
+   */
+  private val builtinFunctions: Map[String, TVF] = Map(
+    "range" -> Map(
+      /* range(end) */
+      tvf("end" -> LongType) { case Seq(end: Long) =>
+        Range(0, end, 1, defaultParallelism)
+      },
+
+      /* range(start, end) */
+      tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: 
Long) =>
+        Range(start, end, 1, defaultParallelism)
+      },
+
+      /* range(start, end, step) */
+      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
+        case Seq(start: Long, end: Long, step: Long) =>
+          Range(start, end, step, defaultParallelism)
+      },
+
+      /* range(start, end, step, numPartitions) */
+      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
+          "numPartitions" -> IntegerType) {
+        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
+          Range(start, end, step, numPartitions)
+      })
+  )
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) 
=>
+      builtinFunctions.get(u.functionName) match {
+        case Some(tvf) =>
+          val resolved = tvf.flatMap { case (argList, resolver) =>
+            argList.implicitCast(u.functionArgs) match {
+              case Some(casted) =>
+                Some(resolver(casted.map(_.eval())))
+              case _ =>
+                None
+            }
+          }
+          resolved.headOption.getOrElse {
+            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", 
")
+            u.failAnalysis(
+              s"""error: table-valued function ${u.functionName} with 
alternatives:
+                |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" 
($x)").mkString("\n")}
+                |cannot be applied to: (${argTypes})""".stripMargin)
+          }
+        case _ =>
+          u.failAnalysis(s"could not resolve `${u.functionName}` to a 
table-valued function")
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 609089a..a66a551 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,6 +50,17 @@ case class UnresolvedRelation(
 }
 
 /**
+ * Holds a table-valued function call that has yet to be resolved.
+ */
+case class UnresolvedTableValuedFunction(
+    functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+
+  override def output: Seq[Attribute] = Nil
+
+  override lazy val resolved = false
+}
+
+/**
  * Holds the name of an attribute that has yet to be resolved.
  */
 case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with 
Unevaluable {

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index aee8eb1..1f2cd4c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -653,6 +653,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with 
Logging {
   }
 
   /**
+   * Create a table-valued function call with arguments, e.g. range(1000)
+   */
+  override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+      : LogicalPlan = withOrigin(ctx) {
+    UnresolvedTableValuedFunction(ctx.identifier.getText, 
ctx.expression.asScala.map(expression))
+  }
+
+  /**
    * Create an inline table (a virtual table in Hive parlance).
    */
   override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = 
withOrigin(ctx) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 456948d..05bd0cd 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, 
UnresolvedTableValuedFunction}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -423,6 +423,12 @@ class PlanParserSuite extends PlanTest {
     assertEqual("table d.t", table("d", "t"))
   }
 
+  test("table valued function") {
+    assertEqual(
+      "select * from range(2)",
+      UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+  }
+
   test("inline table") {
     assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
       Seq('col1.int),

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
new file mode 100644
index 0000000..2e6dcd5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -0,0 +1,20 @@
+-- unresolved function
+select * from dummy(3);
+
+-- range call with end
+select * from range(6 + cos(3));
+
+-- range call with start and end
+select * from range(5, 10);
+
+-- range call with step
+select * from range(0, 10, 2);
+
+-- range call with numPartitions
+select * from range(0, 10, 1, 200);
+
+-- range call error
+select * from range(1, 1, 1, 1, 1);
+
+-- range call with null
+select * from range(1, null);

http://git-wip-us.apache.org/repos/asf/spark/blob/ea684b69/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
new file mode 100644
index 0000000..d769bce
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -0,0 +1,87 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+select * from dummy(3)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+org.apache.spark.sql.AnalysisException
+could not resolve `dummy` to a table-valued function; line 1 pos 14
+
+
+-- !query 1
+select * from range(6 + cos(3))
+-- !query 1 schema
+struct<id:bigint>
+-- !query 1 output
+0
+1
+2
+3
+4
+
+
+-- !query 2
+select * from range(5, 10)
+-- !query 2 schema
+struct<id:bigint>
+-- !query 2 output
+5
+6
+7
+8
+9
+
+
+-- !query 3
+select * from range(0, 10, 2)
+-- !query 3 schema
+struct<id:bigint>
+-- !query 3 output
+0
+2
+4
+6
+8
+
+
+-- !query 4
+select * from range(0, 10, 1, 200)
+-- !query 4 schema
+struct<id:bigint>
+-- !query 4 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 5
+select * from range(1, 1, 1, 1, 1)
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+error: table-valued function range with alternatives:
+ (end: long)
+ (start: long, end: long)
+ (start: long, end: long, step: long)
+ (start: long, end: long, step: long, numPartitions: integer)
+cannot be applied to: (integer, integer, integer, integer, integer); line 1 
pos 14
+
+
+-- !query 6
+select * from range(1, null)
+-- !query 6 schema
+struct<>
+-- !query 6 output
+java.lang.IllegalArgumentException
+Invalid arguments for resolved function: 1, null


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to