Github user liancheng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/13846#discussion_r68753763
  
    --- Diff: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
 ---
    @@ -23,54 +23,111 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
     import org.apache.spark.sql.catalyst.dsl.expressions._
     import org.apache.spark.sql.catalyst.dsl.plans._
     import org.apache.spark.sql.catalyst.encoders.{encoderFor, 
ExpressionEncoder}
    +import org.apache.spark.sql.catalyst.expressions.{BoundReference, 
ReferenceToExpressions}
     import org.apache.spark.sql.catalyst.plans.PlanTest
     import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan}
     import org.apache.spark.sql.catalyst.rules.RuleExecutor
    -import org.apache.spark.sql.types.BooleanType
    +import org.apache.spark.sql.types.{BooleanType, ObjectType}
     
     class TypedFilterOptimizationSuite extends PlanTest {
       object Optimize extends RuleExecutor[LogicalPlan] {
         val batches =
           Batch("EliminateSerialization", FixedPoint(50),
             EliminateSerialization) ::
    -      Batch("EmbedSerializerInFilter", FixedPoint(50),
    -        EmbedSerializerInFilter) :: Nil
    +      Batch("CombineTypedFilters", FixedPoint(50),
    +        CombineTypedFilters) :: Nil
       }
     
       implicit private def productEncoder[T <: Product : TypeTag] = 
ExpressionEncoder[T]()
     
    -  test("back to back filter") {
    +  test("filter after serialize") {
         val input = LocalRelation('_1.int, '_2.int)
    -    val f1 = (i: (Int, Int)) => i._1 > 0
    -    val f2 = (i: (Int, Int)) => i._2 > 0
    +    val f = (i: (Int, Int)) => i._1 > 0
     
    -    val query = input.filter(f1).filter(f2).analyze
    +    val query = input
    +      .deserialize[(Int, Int)]
    +      .serialize[(Int, Int)]
    +      .filter(f).analyze
     
         val optimized = Optimize.execute(query)
     
    -    val expected = input.deserialize[(Int, Int)]
    -      .where(callFunction(f1, BooleanType, 'obj))
    -      .select('obj.as("obj"))
    -      .where(callFunction(f2, BooleanType, 'obj))
    +    val expected = input
    +      .deserialize[(Int, Int)]
    +      .where(callFunction(f, BooleanType, 'obj))
           .serialize[(Int, Int)].analyze
     
         comparePlans(optimized, expected)
       }
     
    -  // TODO: Remove this after we completely fix SPARK-15632 by adding 
optimization rules
    -  // for typed filters.
    -  ignore("embed deserializer in typed filter condition if there is only 
one filter") {
    +  test("filter after serialize with object change") {
    +    val input = LocalRelation('_1.int, '_2.int)
    +    val f = (i: OtherTuple) => i._1 > 0
    +
    +    val query = input
    +      .deserialize[(Int, Int)]
    +      .serialize[(Int, Int)]
    +      .filter(f).analyze
    +    val optimized = Optimize.execute(query)
    +    comparePlans(optimized, query)
    +  }
    +
    +  test("filter before deserialize") {
         val input = LocalRelation('_1.int, '_2.int)
         val f = (i: (Int, Int)) => i._1 > 0
     
    -    val query = input.filter(f).analyze
    +    val query = input
    +      .filter(f)
    +      .deserialize[(Int, Int)]
    +      .serialize[(Int, Int)].analyze
    +
    +    val optimized = Optimize.execute(query)
    +
    +    val expected = input
    +      .deserialize[(Int, Int)]
    +      .where(callFunction(f, BooleanType, 'obj))
    +      .serialize[(Int, Int)].analyze
    +
    +    comparePlans(optimized, expected)
    +  }
    +
    +  test("filter before deserialize with object change") {
    +    val input = LocalRelation('_1.int, '_2.int)
    +    val f = (i: OtherTuple) => i._1 > 0
    +
    +    val query = input
    +      .filter(f)
    +      .deserialize[(Int, Int)]
    +      .serialize[(Int, Int)].analyze
    +    val optimized = Optimize.execute(query)
    +    comparePlans(optimized, query)
    +  }
    +
    +  test("back to back filter") {
    +    val input = LocalRelation('_1.int, '_2.int)
    +    val f1 = (i: (Int, Int)) => i._1 > 0
    +    val f2 = (i: (Int, Int)) => i._2 > 0
    +
    +    val query = input.filter(f1).filter(f2).analyze
     
         val optimized = Optimize.execute(query)
     
         val deserializer = UnresolvedDeserializer(encoderFor[(Int, 
Int)].deserializer)
    -    val condition = callFunction(f, BooleanType, deserializer)
    -    val expected = input.where(condition).select('_1.as("_1"), 
'_2.as("_2")).analyze
    +    val boundReference = BoundReference(0, ObjectType(classOf[(Int, 
Int)]), nullable = false)
    +    val callFunc1 = callFunction(f1, BooleanType, boundReference)
    +    val callFunc2 = callFunction(f2, BooleanType, boundReference)
    +    val condition = ReferenceToExpressions(callFunc2 && callFunc1, 
deserializer :: Nil)
    +    val expected = input.where(condition).analyze
     
         comparePlans(optimized, expected)
       }
    +
    +  test("back to back filter with object change") {
    --- End diff --
    
    Nit: "back to back filters with different object types"


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to