Repository: spark
Updated Branches:
  refs/heads/master fec67ed7e -> 9b8521e53


[SPARK-25068][SQL] Add exists function.

## What changes were proposed in this pull request?

This PR adds the `exists` function, which tests whether a predicate holds for one or more elements in the array.

```sql
> SELECT exists(array(1, 2, 3), x -> x % 2 == 0);
 true
```

## How was this patch tested?

Added tests.

Closes #22052 from ueshin/issues/SPARK-25068/exists.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Xiao Li <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b8521e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b8521e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b8521e5

Branch: refs/heads/master
Commit: 9b8521e53e56a53b44c02366a99f8a8ee1307bbf
Parents: fec67ed
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Thu Aug 9 14:41:59 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Thu Aug 9 14:41:59 2018 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/higherOrderFunctions.scala      | 47 ++++++++++
 .../expressions/HigherOrderFunctionsSuite.scala | 37 ++++++++
 .../sql-tests/inputs/higher-order-functions.sql |  6 ++
 .../results/higher-order-functions.sql.out      | 18 ++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 96 ++++++++++++++++++++
 6 files changed, 205 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 390debd..15543c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -444,6 +444,7 @@ object FunctionRegistry {
     expression[ArrayTransform]("transform"),
     expression[MapFilter]("map_filter"),
     expression[ArrayFilter]("filter"),
+    expression[ArrayExists]("exists"),
     expression[ArrayAggregate]("aggregate"),
     CreateStruct.registryEntry,
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index d206733..7f8203a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -357,6 +357,53 @@ case class ArrayFilter(
 }
 
 /**
+ * Tests whether a predicate holds for one or more elements in the array.
+ */
+@ExpressionDescription(usage =
+  "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0);
+       true
+  """,
+  since = "2.4.0")
+case class ArrayExists(
+    input: Expression,  // the array whose elements are tested
+    function: Expression)  // the predicate lambda, applied to one element at a time
+  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
+
+  override def nullable: Boolean = input.nullable  // result is NULL only for a NULL array
+
+  override def dataType: DataType = BooleanType
+
+  override def expectingFunctionType: AbstractDataType = BooleanType  // predicate must be boolean
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
+    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
+    copy(function = f(function, elem :: Nil))  // one lambda argument: the element
+  }
+
+  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
+
+  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
+    val arr = value.asInstanceOf[ArrayData]
+    val f = functionForEval
+    var exists = false
+    var i = 0
+    while (i < arr.numElements && !exists) {  // stop at the first matching element
+      elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (f.eval(inputRow).asInstanceOf[Boolean]) {  // a NULL predicate result unboxes to false
+        exists = true
+      }
+      i += 1
+    }
+    exists
+  }
+
+  override def prettyName: String = "exists"
+}
+
+/**
  * Applies a binary operator to a start value and all elements in the array.
  */
 @ExpressionDescription(

http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index f7e84b8..bc7d04c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -202,6 +202,43 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(Seq(1, 3), null, Seq(5)))
   }
 
+  test("ArrayExists") {  // exists: true if any element matches, NULL for a NULL array
+    def exists(expr: Expression, f: Expression => Expression): Expression = {
+      val at = expr.dataType.asInstanceOf[ArrayType]
+      ArrayExists(expr, createLambda(at.elementType, at.containsNull, f))
+    }
+
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
+    val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
+
+    val isEven: Expression => Expression = x => x % 2 === 0
+    val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
+
+    checkEvaluation(exists(ai0, isEven), true)
+    checkEvaluation(exists(ai0, isNullOrOdd), true)
+    checkEvaluation(exists(ai1, isEven), false)  // NULL % 2 === 0 is NULL: no match
+    checkEvaluation(exists(ai1, isNullOrOdd), true)
+    checkEvaluation(exists(ain, isEven), null)  // NULL array -> NULL
+    checkEvaluation(exists(ain, isNullOrOdd), null)
+
+    val as0 =
+      Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
+    val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true))
+    val asn = Literal.create(null, ArrayType(StringType, containsNull = false))
+
+    val startsWithA: Expression => Expression = x => x.startsWith("a")
+
+    checkEvaluation(exists(as0, startsWithA), true)
+    checkEvaluation(exists(as1, startsWithA), false)
+    checkEvaluation(exists(asn, startsWithA), null)
+
+    val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
+      ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+    checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)),  // nested inside transform
+      Seq(true, null, true))
+  }
+
   test("ArrayAggregate") {
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull 
= false))
     val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, 
containsNull = true))

http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 136396d..ce1d0da 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -45,3 +45,9 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as
 
 -- Aggregate a null array
 select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) 
as v;
+
+-- Check for element existence
+select exists(ys, y -> y > 30) as v from nested;
+
+-- Check for element existence in a null array
+select exists(cast(null as array<int>), y -> y > 30) as v;

http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index e6f62f2..e18abce 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -145,3 +145,21 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
 struct<v:int>
 -- !query 14 output
 NULL
+
+
+-- !query 15
+select exists(ys, y -> y > 30) as v from nested
+-- !query 15 schema
+struct<v:boolean>
+-- !query 15 output
+false
+true
+true
+
+
+-- !query 16
+select exists(cast(null as array<int>), y -> y > 30) as v
+-- !query 16 schema
+struct<v:boolean>
+-- !query 16 output
+NULL

http://git-wip-us.apache.org/repos/asf/spark/blob/9b8521e5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 24091f2..2c4238e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1996,6 +1996,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires 
boolean type"))
   }
 
+  test("exists function - array for primitive type not containing null") {
+    val df = Seq(
+      Seq(1, 9, 8, 7),  // 8 is even -> true
+      Seq(5, 9, 7),  // all odd -> false
+      Seq.empty,  // empty array -> false
+      null  // NULL array -> NULL
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {  // one expected row per input row
+      checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"),
+        Seq(
+          Row(true),
+          Row(false),
+          Row(false),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("exists function - array for primitive type containing null") {
+    val df = Seq[Seq[Integer]](
+      Seq(1, 9, 8, null, 7),  // 8 is even -> true
+      Seq(5, null, null, 9, 7, null),  // no even value; NULL elements do not match -> false
+      Seq.empty,  // empty array -> false
+      null  // NULL array -> NULL
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {  // one expected row per input row
+      checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"),
+        Seq(
+          Row(true),
+          Row(false),
+          Row(false),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
+
+  test("exists function - array for non-primitive type") {
+    val df = Seq(
+      Seq("c", "a", "b"),  // no NULL element -> false
+      Seq("b", null, "c", null),  // contains a NULL element -> true
+      Seq.empty,  // empty array -> false
+      null  // NULL array -> NULL
+    ).toDF("s")
+
+    def testNonPrimitiveType(): Unit = {  // one expected row per input row
+      checkAnswer(df.selectExpr("exists(s, x -> x is null)"),
+        Seq(
+          Row(false),
+          Row(true),
+          Row(false),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testNonPrimitiveType()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testNonPrimitiveType()
+  }
+
+  test("exists function - invalid") {  // analysis-time errors for malformed exists() calls
+    val df = Seq(
+      (Seq("c", "a", "b"), 1),
+      (Seq("b", null, "c", null), 2),
+      (Seq.empty, 3),
+      (null, 4)
+    ).toDF("s", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("exists(s, (x, y) -> x + y)")  // lambda takes two arguments
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("exists(i, x -> x)")  // first argument is an int, not an array
+    }
+    assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("exists(s, x -> x)")  // lambda returns a string, not a boolean
+    }
+    assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
+  }
+
   test("aggregate function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to