Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/20915#discussion_r177973464 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala --- @@ -50,6 +51,80 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. */ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets(expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)) + matchedBuckets + } + + expr match { + case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + val valuesSet = list.map(e => e.eval(EmptyRow)) + valuesSet + .map(v => getMatchedBucketBitSet(a, v)) + .fold(new BitSet(numBuckets))(_ | _) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, 
numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + val leftBuckets = getExpressionBuckets(left, bucketColumnName, numBuckets) + val rightBuckets = getExpressionBuckets(right, bucketColumnName, numBuckets) + + // if some expression in OR condition requires all buckets, return an empty BitSet + if (leftBuckets.cardinality() == 0 || rightBuckets.cardinality() == 0) { + new BitSet(numBuckets) + } else { + // return a BitSet that includes all required buckets + leftBuckets | rightBuckets + } + case _ => new BitSet(numBuckets) + } + } + + private def getBuckets(normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + + val matchedBuckets = normalizedFilters + .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) + .fold(new BitSet(numBuckets))(_ | _) + + val numBucketsSelected = if (matchedBuckets.cardinality() != 0) matchedBuckets.cardinality() --- End diff -- This `if`/`else` should match the common Spark style, with braces around each branch when the expression spans multiple lines: ```scala if (matchedBuckets.cardinality() != 0) { matchedBuckets.cardinality() } else { numBuckets } ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org