Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e32c29d86 -> 565e18cf7


[SPARK-16286][SQL] Implement stack table generating function

This PR implements the `stack` table-generating function.

Passes the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongj...@apache.org>

Closes #14033 from dongjoon-hyun/SPARK-16286.

(cherry picked from commit d0d28507cacfca5919dbfb4269892d58b62e8662)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/565e18cf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/565e18cf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/565e18cf

Branch: refs/heads/branch-2.0
Commit: 565e18cf7670231b1fa9db84f907654da79e6cef
Parents: e32c29d
Author: Dongjoon Hyun <dongj...@apache.org>
Authored: Wed Jul 6 10:54:43 2016 +0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Jul 7 21:09:09 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../sql/catalyst/expressions/generators.scala   | 53 ++++++++++++++++++++
 .../expressions/GeneratorExpressionSuite.scala  | 18 +++++++
 .../spark/sql/GeneratorFunctionSuite.scala      | 53 ++++++++++++++++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |  2 +-
 5 files changed, 126 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/565e18cf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 021bec7..f6ebcae 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,7 @@ object FunctionRegistry {
     expression[PosExplode]("posexplode"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
+    expression[Stack]("stack"),
     expression[CreateStruct]("struct"),
     expression[CaseWhen]("when"),
 

http://git-wip-us.apache.org/repos/asf/spark/blob/565e18cf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 99b97c8..9d5c856 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,6 +94,59 @@ case class UserDefinedGenerator(
 }
 
 /**
+ * Separate v1, ..., vk into n rows. Each row will have k/n columns. n must be 
constant.
+ * {{{
+ *   SELECT stack(2, 1, 2, 3) ->
+ *   1      2
+ *   3      NULL
+ * }}}
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.",
+  extended = "> SELECT _FUNC_(2, 1, 2, 3);\n  [1,2]\n  [3,null]")
+case class Stack(children: Seq[Expression])
+    extends Expression with Generator with CodegenFallback {
+
+  private lazy val numRows = children.head.eval().asInstanceOf[Int]
+  private lazy val numFields = Math.ceil((children.length - 1.0) / 
numRows).toInt
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.length <= 1) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 
arguments.")
+    } else if (children.head.dataType != IntegerType || 
!children.head.foldable || numRows < 1) {
+      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive 
constant integer.")
+    } else {
+      for (i <- 1 until children.length) {
+        val j = (i - 1) % numFields
+        if (children(i).dataType != elementSchema.fields(j).dataType) {
+          return TypeCheckResult.TypeCheckFailure(
+            s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
+              s"Argument $i (${children(i).dataType})")
+        }
+      }
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def elementSchema: StructType =
+    StructType(children.tail.take(numFields).zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val values = children.tail.map(_.eval(input)).toArray
+    for (row <- 0 until numRows) yield {
+      val fields = new Array[Any](numFields)
+      for (col <- 0 until numFields) {
+        val index = row * numFields + col
+        fields.update(col, if (index < values.length) values(index) else null)
+      }
+      InternalRow(fields: _*)
+    }
+  }
+}
+
+/**
  * A base class for Explode and PosExplode
  */
 abstract class ExplodeBase(child: Expression, position: Boolean)

http://git-wip-us.apache.org/repos/asf/spark/blob/565e18cf/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
index e79f89b..e29dfa4 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -63,4 +63,22 @@ class GeneratorExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       ))),
       correct_answer)
   }
+
+  test("stack") {
+    checkTuple(Stack(Seq(1, 1).map(Literal(_))), Seq(create_row(1)))
+    checkTuple(Stack(Seq(1, 1, 2).map(Literal(_))), Seq(create_row(1, 2)))
+    checkTuple(Stack(Seq(2, 1, 2).map(Literal(_))), Seq(create_row(1), 
create_row(2)))
+    checkTuple(Stack(Seq(2, 1, 2, 3).map(Literal(_))), Seq(create_row(1, 2), 
create_row(3, null)))
+    checkTuple(Stack(Seq(3, 1, 2, 3).map(Literal(_))), Seq(1, 2, 
3).map(create_row(_)))
+    checkTuple(Stack(Seq(4, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3, 
null).map(create_row(_)))
+
+    checkTuple(
+      Stack(Seq(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c").map(Literal(_))),
+      Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, 
"c")))
+
+    assert(Stack(Seq(Literal(1))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1.0))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1), Literal(1), 
Literal(1.0))).checkInputDataTypes().isSuccess)
+    assert(Stack(Seq(Literal(2), Literal(1), 
Literal(1.0))).checkInputDataTypes().isFailure)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/565e18cf/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index d8a0aa4..aedc0a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -23,6 +23,59 @@ import org.apache.spark.sql.test.SharedSQLContext
 class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  test("stack") {
+    val df = spark.range(1)
+
+    // Empty DataFrame suppress the result generation
+    checkAnswer(spark.emptyDataFrame.selectExpr("stack(1, 1, 2, 3)"), Nil)
+
+    // Rows & columns
+    checkAnswer(df.selectExpr("stack(1, 1, 2, 3)"), Row(1, 2, 3) :: Nil)
+    checkAnswer(df.selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) 
:: Nil)
+    checkAnswer(df.selectExpr("stack(3, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) 
:: Nil)
+    checkAnswer(df.selectExpr("stack(4, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) 
:: Row(null) :: Nil)
+
+    // Various column types
+    checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 
'c')"),
+      Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
+
+    // Repeat generation at every input row
+    checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
+      Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
+
+    // The first argument must be a positive constant integer.
+    val m = intercept[AnalysisException] {
+      df.selectExpr("stack(1.1, 1, 2, 3)")
+    }.getMessage
+    assert(m.contains("The number of rows must be a positive constant 
integer."))
+    val m2 = intercept[AnalysisException] {
+      df.selectExpr("stack(-1, 1, 2, 3)")
+    }.getMessage
+    assert(m2.contains("The number of rows must be a positive constant 
integer."))
+
+    // The data for the same column should have the same type.
+    val m3 = intercept[AnalysisException] {
+      df.selectExpr("stack(2, 1, '2.2')")
+    }.getMessage
+    assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != 
Argument 2 (StringType)"))
+
+    // stack on column data
+    val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
+    checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) 
:: Nil)
+
+    val m4 = intercept[AnalysisException] {
+      df2.selectExpr("stack(n, a, b, c)")
+    }.getMessage
+    assert(m4.contains("The number of rows must be a positive constant 
integer."))
+
+    val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
+    val m5 = intercept[AnalysisException] {
+      df3.selectExpr("stack(2, a, b)")
+    }.getMessage
+    assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != 
Argument 2 (DoubleType)"))
+
+  }
+
   test("single explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/565e18cf/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index ebb6711..fdc4c18 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "parse_url", "percentile", "percentile_approx", "reflect", "sentences", 
"stack", "str_to_map",
+    "parse_url", "percentile", "percentile_approx", "reflect", "sentences", 
"str_to_map",
     "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
     "xpath_number", "xpath_short", "xpath_string"
   )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to