viirya commented on a change in pull request #32735:
URL: https://github.com/apache/spark/pull/32735#discussion_r643496109



##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
##########
@@ -735,4 +736,100 @@ class HigherOrderFunctionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, 
plusOne), StringType)),
       Seq("[2, 3, 4]", null, "[5, 6]"))
   }
+
+  test("semanticEquals between ArrayAggregate") {
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull 
= false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, 
containsNull = true))
+    val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, 
containsNull = false))
+    val ain = Literal.create(null, ArrayType(IntegerType, containsNull = 
false))
+
+    val agg1_1 = aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10)
+    val agg1_2 = aggregate(ai0, 0, (acc, elem) => acc + elem, acc => 
Literal(10) * acc)
+    assert(agg1_1.semanticEquals(agg1_2))
+
+    val agg2_1 = aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc 
=> acc * 10)
+    val agg2_2 = aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc 
=> Literal(10) * acc)
+    assert(agg2_1.semanticEquals(agg2_2))
+
+    val agg3_1 = aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10)
+    val agg3_2 = aggregate(ai2, 0, (acc, elem) => acc + elem, acc => 
Literal(10) * acc)
+    assert(agg3_1.semanticEquals(agg3_2))
+
+    val agg4_1 = aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10)
+    val agg4_2 = aggregate(ain, 0, (acc, elem) => acc + elem, acc => 
Literal(10) * acc)
+    assert(agg4_1.semanticEquals(agg4_2))
+
+    assert(!agg1_1.semanticEquals(agg2_1))
+    assert(!agg1_1.semanticEquals(agg3_1))
+    assert(!agg1_1.semanticEquals(agg4_1))
+  }
+
+  test("semanticEquals between ArrayTransform") {
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull 
= false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, 
containsNull = true))
+
+    val plusOne_1: Expression => Expression = x => x + 1
+    val plusOne_2: Expression => Expression = x => Literal(1) + x
+    val plusIndex_1: (Expression, Expression) => Expression = (x, i) => x + i
+    val plusIndex_2: (Expression, Expression) => Expression = (x, i) => i + x
+
+    val trans1_1 = transform(ai0, plusOne_1)
+    val trans1_2 = transform(ai0, plusOne_2)
+    val trans1_3 = transform(ai1, plusOne_1)
+    assert(trans1_1.semanticEquals(trans1_2))
+    assert(!trans1_1.semanticEquals(trans1_3))
+
+    val trans2_1 = transform(ai0, plusIndex_1)
+    val trans2_2 = transform(ai0, plusIndex_2)
+    val trans2_3 = transform(ai1, plusIndex_1)
+    assert(trans2_1.semanticEquals(trans2_2))
+    assert(!trans2_1.semanticEquals(trans2_3))
+
+    val trans3_1 = transform(transform(ai0, plusIndex_1), plusOne_1)
+    val trans3_2 = transform(transform(ai0, plusIndex_2), plusOne_2)
+    val trans3_3 = transform(transform(ai1, plusIndex_1), plusOne_1)
+    assert(trans3_1.semanticEquals(trans3_2))
+    assert(!trans3_1.semanticEquals(trans3_3))
+  }
+
+  test("semanticEquals between ArraySort") {
+    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+
+    val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
+
+    assert(arraySort(a0).semanticEquals(arraySort(a0)))
+    assert(arraySort(arrayStruct).semanticEquals(arraySort(arrayStruct)))
+
+    val sort1_1 = arraySort(a0, (left, right) => 
UnaryMinus(ArraySort.comparator(left, right)))
+    val sort1_2 = arraySort(a0, (right, left) => 
UnaryMinus(ArraySort.comparator(right, left)))
+    val sort1_3 = arraySort(a0, (right, left) => 
UnaryMinus(ArraySort.comparator(left, right)))
+    assert(sort1_1.semanticEquals(sort1_2))
+    assert(!sort1_1.semanticEquals(sort1_3))
+  }
+
+  test("semanticEquals between MapFilter") {

Review comment:
       Note that I haven't added `semanticEquals` tests for all higher-order 
functions, only the few shown here. Adding them for every function might be 
too verbose, but let me know if you'd prefer that.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to