Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20915#discussion_r192953586
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala ---
    @@ -50,6 +51,84 @@ import org.apache.spark.sql.execution.SparkPlan
      *     and add it.  Proceed to the next file.
      */
     object FileSourceStrategy extends Strategy with Logging {
    +
    +  // should prune buckets iff num buckets is greater than 1 and there is only one bucket column
    +  private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
    +    bucketSpec match {
    +      case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1
    +      case None => false
    +    }
    +  }
    +
    +  private def getExpressionBuckets(expr: Expression,
    +                                   bucketColumnName: String,
    +                                   numBuckets: Int): BitSet = {
    +
    +    def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = {
    +      val matchedBuckets = new BitSet(numBuckets)
    +      matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v))
    +      matchedBuckets
    +    }
    +
    +    expr match {
    +      case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
    +        getMatchedBucketBitSet(a, v)
    +      case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
    +        getMatchedBucketBitSet(a, v)
    +      case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
    +        getMatchedBucketBitSet(a, v)
    +      case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
    +        getMatchedBucketBitSet(a, v)
    +      case expressions.In(a: Attribute, list)
    +        if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
    +        val valuesSet = list.map(e => e.eval(EmptyRow))
    +        valuesSet
    +          .map(v => getMatchedBucketBitSet(a, v))
    +          .fold(new BitSet(numBuckets))(_ | _)
    --- End diff --
    
    Can't we create one bit set for all the matched buckets, instead of creating many bit sets and merging them?
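    
    For illustration, a minimal sketch of that idea for the `In` case, reusing only identifiers already visible in the diff (`BitSet`, `BucketingUtils.getBucketIdFromValue`, `EmptyRow`); the exact shape of the change is up to the author:
    
        // Sketch only: allocate a single BitSet and set one bit per IN-list value,
        // instead of building a BitSet per value and OR-ing them together.
        case expressions.In(a: Attribute, list)
          if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
          val matchedBuckets = new BitSet(numBuckets)
          list.foreach { e =>
            matchedBuckets.set(BucketingUtils.getBucketIdFromValue(a, numBuckets, e.eval(EmptyRow)))
          }
          matchedBuckets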


---
