Github user wangyum commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21603#discussion_r198146352
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala ---
    @@ -270,6 +270,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
           case sources.Not(pred) =>
             createFilter(schema, pred).map(FilterApi.not)
     
    +      case sources.In(name, values) if canMakeFilterOn(name) && values.length < 20 =>
    --- End diff ---
    
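    This branch rewrites the `In` predicate as a chain of `or(eq, ...)` Parquet predicates, which is why the cost grows with the number of values. A rough sketch of the conversion, assuming an int column for simplicity (`makeInFilter` is an illustrative name; the real code dispatches on the Parquet type):
    ```scala
    import java.lang.{Integer => JInt}

    import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

    // Sketch only: turn `id IN (v1, v2, ...)` into or(eq(col, v1), or(eq(col, v2), ...)).
    def makeInFilter(name: String, values: Array[Any]): FilterPredicate = {
      require(values.nonEmpty, "IN with no values cannot be pushed down")
      val column = FilterApi.intColumn(name)
      values
        .map(v => FilterApi.eq(column, v.asInstanceOf[JInt]): FilterPredicate)
        .reduce(FilterApi.or)
    }
    ```
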
    I have prepared a test case that you can use to verify this:
    ```scala
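      // Runs inside a Spark SQL test suite: spark, sql, withSQLConf and
      // withTempPath are assumed to come from the suite (e.g. SQLTestUtils).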
      test("Benchmark") {
        def benchmark(func: () => Unit): Long = {
          val start = System.currentTimeMillis()
          func()
          val end = System.currentTimeMillis()
          end - start
        }
        // scalastyle:off
        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
          withTempPath { path =>
            Seq(1000000, 10000000).foreach { count =>
              Seq(1048576, 10485760, 104857600).foreach { blockSize =>
                    spark.range(count).toDF().selectExpr("id", "cast(id as string) as d1",
                      "cast(id as double) as d2", "cast(id as float) as d3", "cast(id as int) as d4",
                      "cast(id as decimal(38)) as d5")
                      .coalesce(1).write.mode("overwrite")
                      .option("parquet.block.size", blockSize).parquet(path.getAbsolutePath)
                val df = spark.read.parquet(path.getAbsolutePath)
                println(s"path: ${path.getAbsolutePath}")
                Seq(1000, 100, 10, 1).foreach { ratio =>
                  println(s"##########[ count: $count, blockSize: $blockSize, 
ratio: $ratio ]#########")
                  var i = 1
                  while (i < 300) {
                        val filter = Range(0, i).map(_ => scala.util.Random.nextInt(count / ratio))
                    i += 4
    
                    sql("set spark.sql.parquet.pushdown.inFilterThreshold=1")
                    val vanillaTime = benchmark(() => df.where(s"id 
in(${filter.mkString(",")})").count())
                    sql("set spark.sql.parquet.pushdown.inFilterThreshold=1000")
                    val pushDownTime = benchmark(() => df.where(s"id 
in(${filter.mkString(",")})").count())
    
                        if (pushDownTime > vanillaTime) {
                          println(s"vanilla is better, threshold: ${filter.size}, $pushDownTime, $vanillaTime")
                        } else {
                          println(s"push down is better, threshold: ${filter.size}, $pushDownTime, $vanillaTime")
                        }
                  }
                }
              }
            }
          }
        }
        // scalastyle:on
      }
    ```
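
    In the output, the list size at which "vanilla is better" starts to dominate suggests where the default threshold should sit.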

