Repository: spark
Updated Branches:
  refs/heads/master 8bb0df2c6 -> d5bec48b9


[SPARK-23919][SQL] Add array_position function

## What changes were proposed in this pull request?

This PR adds the SQL function `array_position`. Its behavior follows that of
the Presto function of the same name.

The function returns the 1-based position of the first occurrence of the given
element in the given array as a BigInt, or 0 if the element is not found.

## How was this patch tested?

Added unit tests (UTs).

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #21037 from kiszk/SPARK-23919.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d5bec48b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d5bec48b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d5bec48b

Branch: refs/heads/master
Commit: d5bec48b9cb225c19b43935c07b24090c51cacce
Parents: 8bb0df2
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Thu Apr 19 11:59:17 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Thu Apr 19 11:59:17 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 17 ++++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/collectionOperations.scala      | 56 ++++++++++++++++++++
 .../CollectionExpressionsSuite.scala            | 22 ++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 14 +++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 34 ++++++++++++
 6 files changed, 144 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d3bb0a5..36dcabc 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1845,6 +1845,23 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), 
value))
 
 
+@since(2.4)
+def array_position(col, value):
+    """
+    Collection function: Locates the position of the first occurrence of the given value
+    in the given array. Returns null if either of the arguments is null.
+
+    .. note:: The position is not zero based, but 1 based index. Returns 0 if the given
+        value could not be found in the array.
+
+    :param col: the array column to search
+    :param value: the value to look for in the array
+
+    >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+    >>> df.select(array_position(df.data, "a")).collect()
+    [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
+    """
+    sc = SparkContext._active_spark_context
+    # Delegate to the JVM-side `functions.array_position` and wrap the result in a Column.
+    return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 38c874a..74095fe 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -402,6 +402,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),

http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 76b71f5..e6a05f5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends 
UnaryExpression with ImplicitCast
 
   override def prettyName: String = "array_max"
 }
+
+
+/**
+ * Returns the position of the first occurrence of element in the given array as long.
+ * Returns 0 if the given value could not be found in the array. Returns null if either of
+ * the arguments is null.
+ *
+ * NOTE: the position is a 1-based index, not a zero-based one. The first element in the
+ *       array has index 1.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element
+      in the array as long, or 0 if element is not found.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(3, 2, 1), 1);
+       3
+  """,
+  since = "2.4.0")
+case class ArrayPosition(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = LongType
+
+  // The second argument is expected to have the element type of the first. When the first
+  // argument is not an array, still require an array there so that a regular input-type
+  // mismatch is reported instead of a ClassCastException from an unconditional cast.
+  override def inputTypes: Seq[AbstractDataType] = left.dataType match {
+    case ArrayType(elementType, _) => Seq(ArrayType, elementType)
+    case _ => Seq(ArrayType, right.dataType)
+  }
+
+  // Interpreted evaluation: scans the array in order and stops at the first match. Null
+  // elements never match, because `value` is non-null whenever this method is called.
+  override def nullSafeEval(arr: Any, value: Any): Any = {
+    val elements = arr.asInstanceOf[ArrayData]
+    val elementType = right.dataType
+    val size = elements.numElements()
+    var pos = 0L
+    var i = 0
+    while (pos == 0L && i < size) {
+      if (!elements.isNullAt(i) && sameValue(elements.get(i, elementType), value)) {
+        pos = i + 1L
+      }
+      i += 1
+    }
+    pos
+  }
+
+  // Binary values are byte arrays, for which `==` is reference equality, so their contents
+  // are compared instead. Every other value is compared with `==`.
+  // NOTE(review): floating-point NaN is compared with `==` here; confirm this agrees with
+  // the comparison emitted by ctx.genEqual for float/double in doGenCode.
+  private def sameValue(element: Any, value: Any): Boolean = (element, value) match {
+    case (e: Array[Byte], v: Array[Byte]) => java.util.Arrays.equals(e, v)
+    case _ => element == value
+  }
+
+  override def prettyName: String = "array_position"
+
+  // Generated code: the same linear scan as nullSafeEval, written as a Java loop.
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (arr, value) => {
+      val pos = ctx.freshName("arrayPosition")
+      val i = ctx.freshName("i")
+      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+      // Parenthesized so that a compound equality expression (e.g. one containing `||`)
+      // cannot escape the null guard it is combined with below.
+      val isMatch = s"(${ctx.genEqual(right.dataType, value, getValue)})"
+      s"""
+         |int $pos = 0;
+         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |  if (!$arr.isNullAt($i) && $isMatch) {
+         |    $pos = $i + 1;
+         |    break;
+         |  }
+         |}
+         |${ev.value} = (long) $pos;
+       """.stripMargin
+    })
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 517639d..916cd3b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     checkEvaluation(Reverse(as7), null)
     checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
   }
+
+  test("Array Position") {
+    // Fixtures: arrays holding null elements, an array of only null, and a null array.
+    val intsWithNull = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
+    val stringsWithNull = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val onlyNull = Literal.create(Seq(null), ArrayType(LongType))
+    val nullArray = Literal.create(null, ArrayType(StringType))
+
+    // Integer elements: found after a null, found first, not found, null search value.
+    checkEvaluation(ArrayPosition(intsWithNull, Literal(3)), 4L)
+    checkEvaluation(ArrayPosition(intsWithNull, Literal(1)), 1L)
+    checkEvaluation(ArrayPosition(intsWithNull, Literal(0)), 0L)
+    checkEvaluation(ArrayPosition(intsWithNull, Literal.create(null, IntegerType)), null)
+
+    // String elements: the empty string is found, "a" is not, null search value.
+    checkEvaluation(ArrayPosition(stringsWithNull, Literal("")), 2L)
+    checkEvaluation(ArrayPosition(stringsWithNull, Literal("a")), 0L)
+    checkEvaluation(ArrayPosition(stringsWithNull, Literal.create(null, StringType)), null)
+
+    // An array whose only element is null never matches a non-null value.
+    checkEvaluation(ArrayPosition(onlyNull, Literal(1L)), 0L)
+    checkEvaluation(ArrayPosition(onlyNull, Literal.create(null, LongType)), null)
+
+    // A null array gives null whatever the search value is.
+    checkEvaluation(ArrayPosition(nullArray, Literal("")), null)
+    checkEvaluation(ArrayPosition(nullArray, Literal.create(null, StringType)), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a55a800..3a09ec4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3039,6 +3039,20 @@ object functions {
   }
 
   /**
+   * Locates the position of the first occurrence of the value in the given array as long.
+   * Returns null if either of the arguments is null.
+   *
+   * @note The position is not zero based, but 1 based index. Returns 0 if value
+   * could not be found in array.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_position(column: Column, value: Any): Column = withExpr {
+    // The search value is wrapped in a literal expression before building ArrayPosition.
+    ArrayPosition(column.expr, Literal(value))
+  }
+
+  /**
    * Creates a new row for each element in the given array or map column.
    *
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/d5bec48b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 74c42f2..13161e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext {
     }
   }
 
+  test("array position function") {
+    val input = Seq((Seq[Int](1, 2), "x"), (Seq[Int](), "x")).toDF("a", "b")
+    val arrayCol = input("a")
+
+    // The value 1 is at position 1 in the first row; the empty array in the second yields 0.
+    val foundThenMissing = Seq(Row(1L), Row(0L))
+    checkAnswer(input.select(array_position(arrayCol, 1)), foundThenMissing)
+    checkAnswer(input.selectExpr("array_position(a, 1)"), foundThenMissing)
+
+    // A null search value gives a null result for every row.
+    val allNull = Seq(Row(null), Row(null))
+    checkAnswer(input.select(array_position(arrayCol, null)), allNull)
+    checkAnswer(input.selectExpr("array_position(a, null)"), allNull)
+
+    // Arguments built from nested expressions; both rows evaluate the same literals.
+    val firstPosition = Seq(Row(1L), Row(1L))
+    checkAnswer(
+      input.selectExpr("array_position(array(array(1), null)[0], 1)"),
+      firstPosition)
+    checkAnswer(
+      input.selectExpr("array_position(array(1, null), array(1, null)[0])"),
+      firstPosition)
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), 
(false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to