This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 922844fff65 [SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression
922844fff65 is described below

commit 922844fff65ac38fd93bd0c914dcc7e5cf879996
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Oct 17 10:11:36 2023 -0500

    [SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression

    ### What changes were proposed in this pull request?
    Simplify the 'DataFrameStatFunctions.bloomFilter' function with the 'BloomFilterAggregate' expression.

    ### Why are the changes needed?
    The existing implementation was based on RDD operations; it can be simplified with DataFrame operations.

    ### Does this PR introduce _any_ user-facing change?
    Yes. When the input parameters or data types are invalid, an `AnalysisException` is now thrown instead of an `IllegalArgumentException`.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #43391 from zhengruifeng/sql_reimpl_stat_bloomFilter.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../apache/spark/sql/DataFrameStatFunctions.scala | 68 +++++----------------
 1 file changed, 14 insertions(+), 54 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9d4f83c53a3..de3b100cd6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,8 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.stat._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
@@ -535,7 +537,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+    bloomFilter(Column(colName), expectedNumItems, fpp)
   }
 
   /**
@@ -547,7 +549,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, -1L, fpp)
+    val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+    bloomFilter(col, expectedNumItems, numBits)
   }
 
   /**
@@ -559,7 +562,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+    bloomFilter(Column(colName), expectedNumItems, numBits)
   }
 
   /**
@@ -571,57 +574,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
-  }
-
-  private def buildBloomFilter(col: Column, expectedNumItems: Long,
-      numBits: Long,
-      fpp: Double): BloomFilter = {
-    val singleCol = df.select(col)
-    val colType = singleCol.schema.head.dataType
-
-    require(colType == StringType || colType.isInstanceOf[IntegralType],
-      s"Bloom filter only supports string type and integral types, but got $colType.")
-
-    val updater: (BloomFilter, InternalRow) => Unit = colType match {
-      // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
-      // instead of `putString` to avoid unnecessary conversion.
-      case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
-      case ByteType => (filter, row) => filter.putLong(row.getByte(0))
-      case ShortType => (filter, row) => filter.putLong(row.getShort(0))
-      case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
-      case LongType => (filter, row) => filter.putLong(row.getLong(0))
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Bloom filter only supports string type and integral types, " +
-            s"and does not support type $colType."
-        )
-    }
-
-    singleCol.queryExecution.toRdd.treeAggregate(null.asInstanceOf[BloomFilter])(
-      (filter: BloomFilter, row: InternalRow) => {
-        val theFilter =
-          if (filter == null) {
-            if (fpp.isNaN) {
-              BloomFilter.create(expectedNumItems, numBits)
-            } else {
-              BloomFilter.create(expectedNumItems, fpp)
-            }
-          } else {
-            filter
-          }
-        updater(theFilter, row)
-        theFilter
-      },
-      (filter1, filter2) => {
-        if (filter1 == null) {
-          filter2
-        } else if (filter2 == null) {
-          filter1
-        } else {
-          filter1.mergeInPlace(filter2)
-        }
-      }
+    val bloomFilterAgg = new BloomFilterAggregate(
+      col.expr,
+      Literal(expectedNumItems, LongType),
+      Literal(numBits, LongType)
     )
+    val bytes = df.select(
+      Column(bloomFilterAgg.toAggregateExpression(false))
+    ).head().getAs[Array[Byte]](0)
+    bloomFilterAgg.deserialize(bytes)
   }
 }
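For context, here is a minimal sketch of how the public API touched by this patch is typically used. It is not part of the commit; the column name and sizing parameters are illustrative.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.util.sketch.BloomFilter

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Build a filter over column "id": 1000 expected items and a 3% target
    // false-positive probability (the fpp overload rewritten by this patch).
    val df = spark.range(0, 1000).toDF("id")
    val bf: BloomFilter = df.stat.bloomFilter("id", 1000L, 0.03)

    assert(bf.mightContainLong(500L)) // Bloom filters have no false negatives
    println(bf.expectedFpp())         // observed false-positive probability

After this change, the same call runs as a single DataFrame aggregation (BloomFilterAggregate) instead of a treeAggregate over the underlying RDD, and the serialized filter is deserialized once on the driver.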
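Note also that the fpp overload now delegates to the numBits overload via BloomFilter.optimalNumOfBits, which follows the standard Bloom filter sizing formula m = -n * ln(p) / (ln 2)^2. A rough equivalent of that computation, for illustration only (rounding may differ from Spark's helper):

    val n = 1000L // expected insertions
    val p = 0.03  // target false-positive probability
    // Standard sizing formula: m = -n * ln(p) / (ln 2)^2 bits
    val numBits = (-n * math.log(p) / (math.log(2) * math.log(2))).toLong
    // For these inputs, numBits is roughly 7300.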