This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 922844fff65 [SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression
922844fff65 is described below

commit 922844fff65ac38fd93bd0c914dcc7e5cf879996
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Oct 17 10:11:36 2023 -0500

    [SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with 'BloomFilterAggregate' expression
    
    ### What changes were proposed in this pull request?
    Simplify the 'DataFrameStatFunctions.bloomFilter' function by reimplementing it with the 'BloomFilterAggregate' expression.
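
    As a minimal sketch of the (unchanged) public API, assuming an active SparkSession named `spark` (e.g. in spark-shell); the data and column name are illustrative, not from this patch:

    ```scala
    import org.apache.spark.util.sketch.BloomFilter

    val df = spark.range(0, 10000).toDF("id")

    // Build a Bloom filter over an integral column: expected item count
    // plus target false-positive probability.
    val filter: BloomFilter = df.stat.bloomFilter("id", 10000L, 0.03)

    filter.mightContain(42L)   // true: inserted values are always reported
    filter.mightContain(-1L)   // usually false; true at roughly the 3% fpp rate
    ```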
    
    ### Why are the changes needed?
    The existing implementation was based on an RDD aggregation; it can be simplified with DataFrame operations.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. When the input parameters or data types are invalid, an `AnalysisException` is thrown instead of an `IllegalArgumentException`.
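
    A sketch of the new failure mode, again assuming an active SparkSession named `spark`; the column and data are made up for illustration:

    ```scala
    import org.apache.spark.sql.AnalysisException

    import spark.implicits._
    val df = Seq(1.5, 2.5).toDF("value")   // DoubleType is not a supported input type

    try {
      df.stat.bloomFilter("value", 100L, 0.03)
    } catch {
      // Previously this surfaced as IllegalArgumentException; it now fails analysis.
      case e: AnalysisException => println(e.getMessage)
    }
    ```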
    
    ### How was this patch tested?
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43391 from zhengruifeng/sql_reimpl_stat_bloomFilter.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../apache/spark/sql/DataFrameStatFunctions.scala  | 68 +++++-----------------
 1 file changed, 14 insertions(+), 54 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9d4f83c53a3..de3b100cd6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,8 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.stat._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
@@ -535,7 +537,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+    bloomFilter(Column(colName), expectedNumItems, fpp)
   }
 
   /**
@@ -547,7 +549,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, -1L, fpp)
+    val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+    bloomFilter(col, expectedNumItems, numBits)
   }
 
   /**
@@ -559,7 +562,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+    bloomFilter(Column(colName), expectedNumItems, numBits)
   }
 
   /**
@@ -571,57 +574,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
-  }
-
-  private def buildBloomFilter(col: Column, expectedNumItems: Long,
-                               numBits: Long,
-                               fpp: Double): BloomFilter = {
-    val singleCol = df.select(col)
-    val colType = singleCol.schema.head.dataType
-
-    require(colType == StringType || colType.isInstanceOf[IntegralType],
-      s"Bloom filter only supports string type and integral types, but got $colType.")
-
-    val updater: (BloomFilter, InternalRow) => Unit = colType match {
-      // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
-      // instead of `putString` to avoid unnecessary conversion.
-      case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
-      case ByteType => (filter, row) => filter.putLong(row.getByte(0))
-      case ShortType => (filter, row) => filter.putLong(row.getShort(0))
-      case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
-      case LongType => (filter, row) => filter.putLong(row.getLong(0))
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Bloom filter only supports string type and integral types, " +
-            s"and does not support type $colType."
-        )
-    }
-
-    singleCol.queryExecution.toRdd.treeAggregate(null.asInstanceOf[BloomFilter])(
-      (filter: BloomFilter, row: InternalRow) => {
-        val theFilter =
-          if (filter == null) {
-            if (fpp.isNaN) {
-              BloomFilter.create(expectedNumItems, numBits)
-            } else {
-              BloomFilter.create(expectedNumItems, fpp)
-            }
-          } else {
-            filter
-          }
-        updater(theFilter, row)
-        theFilter
-      },
-      (filter1, filter2) => {
-        if (filter1 == null) {
-          filter2
-        } else if (filter2 == null) {
-          filter1
-        } else {
-          filter1.mergeInPlace(filter2)
-        }
-      }
+    val bloomFilterAgg = new BloomFilterAggregate(
+      col.expr,
+      Literal(expectedNumItems, LongType),
+      Literal(numBits, LongType)
     )
+    val bytes = df.select(
+      Column(bloomFilterAgg.toAggregateExpression(false))
+    ).head().getAs[Array[Byte]](0)
+    bloomFilterAgg.deserialize(bytes)
   }
 }
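
A note on the fpp overload above: `BloomFilter.optimalNumOfBits` applies the standard Bloom filter sizing formula m = -n * ln(p) / (ln 2)^2, so the fpp-based variant now just derives numBits and delegates to the numBits-based one. A minimal sketch of that arithmetic, assuming the sketch-module method is accessible (as the patch itself relies on):

```scala
import org.apache.spark.util.sketch.BloomFilter

val n = 10000L   // expectedNumItems
val p = 0.03     // target false-positive probability

// m = -n * ln(p) / (ln 2)^2, cast down to a whole number of bits.
val m = (-n * math.log(p) / (math.log(2) * math.log(2))).toLong

println(m)                                   // ~72984 for these inputs
println(BloomFilter.optimalNumOfBits(n, p))  // should agree, modulo rounding
```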

