This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new a4bd7583ce6 [SPARK-42664][CONNECT] Support `bloomFilter` function for `DataFrameStatFunctions`
a4bd7583ce6 is described below

commit a4bd7583ce6d680f0091519007e48894d594b9f6
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Tue Aug 15 19:12:03 2023 +0200

    [SPARK-42664][CONNECT] Support `bloomFilter` function for `DataFrameStatFunctions`
    
    ### What changes were proposed in this pull request?
    This PR uses `BloomFilterAggregate` to implement the `bloomFilter` function for `DataFrameStatFunctions`.
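    
    A minimal usage sketch of the new client API (the `spark` session and the literal values below are illustrative, not part of the change):
    
    ```scala
    // Assumes an existing Spark Connect session `spark`.
    val df = spark.range(1000).toDF("id")
    // Build a Bloom filter over column "id": 1000 expected insertions, ~3% target false positive rate.
    val filter = df.stat.bloomFilter("id", 1000, 0.03)
    // Bloom filters have no false negatives, so every inserted value must be reported as possibly present.
    assert((0L until 1000L).forall(v => filter.mightContain(v)))
    ```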
    
    ### Why are the changes needed?
    Improves Spark Connect JVM client API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    
    - Added new tests
    - Manually checked with Scala 2.13
    
    Closes #42414 from LuciferYang/SPARK-42664-backup.
    
    Authored-by: yangjie01 <yangji...@baidu.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit b9f11143d058ad05dcda2138133471c9500c8b92)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../org/apache/spark/util/sketch/BloomFilter.java  |  4 +-
 .../apache/spark/sql/DataFrameStatFunctions.scala  | 88 +++++++++++++++++++++-
 .../spark/sql/ClientDataFrameStatSuite.scala       | 87 +++++++++++++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |  3 -
 .../sql/connect/planner/SparkConnectPlanner.scala  | 31 ++++++++
 .../aggregate/BloomFilterAggregate.scala           | 43 ++++++++++-
 6 files changed, 248 insertions(+), 8 deletions(-)

diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 5c01841e501..f3c2b05e7af 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -199,9 +199,9 @@ public abstract class BloomFilter {
   * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula.
    *
    * @param n expected insertions (must be positive)
-   * @param p false positive rate (must be 0 < p < 1)
+   * @param p false positive rate (must be 0 &lt; p &lt; 1)
    */
-  private static long optimalNumOfBits(long n, double p) {
+  public static long optimalNumOfBits(long n, double p) {
     return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
   }
 
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 0d4372b8738..4d35b4e8767 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -18,14 +18,16 @@
 package org.apache.spark.sql
 
 import java.{lang => jl, util => ju}
+import java.io.ByteArrayInputStream
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto.{Relation, StatSampleBy}
 import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
 
 /**
  * Statistic functions for `DataFrame`s.
@@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
     }
     CountMinSketch.readFrom(ds.head())
   }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param colName
+   *   name of the column over which the filter is built
+   * @param expectedNumItems
+   *   expected number of items which will be put into the filter.
+   * @param fpp
+   *   expected false positive probability of the filter.
+   * @since 3.5.0
+   */
+  def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
+    buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param col
+   *   the column over which the filter is built
+   * @param expectedNumItems
+   *   expected number of items which will be put into the filter.
+   * @param fpp
+   *   expected false positive probability of the filter.
+   * @since 3.5.0
+   */
+  def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
+    buildBloomFilter(col, expectedNumItems, -1L, fpp)
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param colName
+   *   name of the column over which the filter is built
+   * @param expectedNumItems
+   *   expected number of items which will be put into the filter.
+   * @param numBits
+   *   expected number of bits of the filter.
+   * @since 3.5.0
+   */
+  def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
+    buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param col
+   *   the column over which the filter is built
+   * @param expectedNumItems
+   *   expected number of items which will be put into the filter.
+   * @param numBits
+   *   expected number of bits of the filter.
+   * @since 3.5.0
+   */
+  def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
+    buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
+  }
+
+  private def buildBloomFilter(
+      col: Column,
+      expectedNumItems: Long,
+      numBits: Long,
+      fpp: Double): BloomFilter = {
+    def numBitsValue: Long = if (!fpp.isNaN) {
+      BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+    } else {
+      numBits
+    }
+
+    if (fpp <= 0d || fpp >= 1d) {
+      throw new SparkException("False positive probability must be within range (0.0, 1.0)")
+    }
+    val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue))
+
+    val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
+      builder.getProjectBuilder
+        .setInput(root)
+        .addExpressions(agg.expr)
+    }
+    BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
+  }
 }
 
 private object DataFrameStatFunctions {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
index 62ff21332f1..7035dc99148 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
@@ -176,4 +176,91 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
     assert(sketch.relativeError() === 0.001)
     assert(sketch.confidence() === 0.99 +- 5e-3)
   }
+
+  test("Bloom filter -- Long Column") {
+    val session = spark
+    import session.implicits._
+    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong)
+    val df = data.toDF("id")
+    val negativeValues = Seq(-11, 1021, 32767).map(_.toLong)
+    checkBloomFilter(data, negativeValues, df)
+  }
+
+  test("Bloom filter -- Int Column") {
+    val session = spark
+    import session.implicits._
+    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997)
+    val df = data.toDF("id")
+    val negativeValues = Seq(-11, 1021, 32767)
+    checkBloomFilter(data, negativeValues, df)
+  }
+
+  test("Bloom filter -- Short Column") {
+    val session = spark
+    import session.implicits._
+    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort)
+    val df = data.toDF("id")
+    val negativeValues = Seq(-11, 1021, 32767).map(_.toShort)
+    checkBloomFilter(data, negativeValues, df)
+  }
+
+  test("Bloom filter -- Byte Column") {
+    val session = spark
+    import session.implicits._
+    val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte)
+    val df = data.toDF("id")
+    val negativeValues = Seq(-101, 55, 113).map(_.toByte)
+    checkBloomFilter(data, negativeValues, df)
+  }
+
+  test("Bloom filter -- String Column") {
+    val session = spark
+    import session.implicits._
+    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString)
+    val df = data.toDF("id")
+    val negativeValues = Seq(-11, 1021, 32767).map(_.toString)
+    checkBloomFilter(data, negativeValues, df)
+  }
+
+  private def checkBloomFilter(
+      data: Seq[Any],
+      notContainValues: Seq[Any],
+      df: DataFrame): Unit = {
+    val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
+    assert(filter1.expectedFpp() - 0.03 < 1e-3)
+    assert(data.forall(filter1.mightContain))
+    assert(notContainValues.forall(n => !filter1.mightContain(n)))
+    val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5)
+    assert(filter2.bitSize() == 64 * 5)
+    assert(data.forall(filter2.mightContain))
+    assert(notContainValues.forall(n => !filter2.mightContain(n)))
+  }
+
+  test("Bloom filter -- Wrong dataType Column") {
+    val session = spark
+    import session.implicits._
+    val data = Range(0, 1000).map(_.toDouble)
+    val message = intercept[AnalysisException] {
+      data.toDF("id").stat.bloomFilter("id", 1000, 0.03)
+    }.getMessage
+    assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE"))
+  }
+
+  test("Bloom filter test invalid inputs") {
+    val df = spark.range(1000).toDF("id")
+    val message1 = intercept[SparkException] {
+      df.stat.bloomFilter("id", -1000, 100)
+    }.getMessage
+    assert(message1.contains("Expected insertions must be positive"))
+
+    val message2 = intercept[SparkException] {
+      df.stat.bloomFilter("id", 1000, -100)
+    }.getMessage
+    assert(message2.contains("Number of bits must be positive"))
+
+    val message3 = intercept[SparkException] {
+      df.stat.bloomFilter("id", 1000, -1.0)
+    }.getMessage
+    assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7356d4daa79..8f226eb2f7e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility {
      // DataFrameNaFunctions
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),
 
-      // DataFrameStatFunctions
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
-
       // Dataset
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.Dataset$" // private[sql]
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc77c52ef46..f3e87b7067d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -1731,6 +1732,36 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
         Some(Lead(children.head, children(1), children(2), ignoreNulls))
 
+      case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
+        // [col, expectedNumItems: Long, numBits: Long]
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+
+        // Check expectedNumItems is LongType and value greater than 0L
+        val expectedNumItemsExpr = children(1)
+        val expectedNumItems = expectedNumItemsExpr match {
+          case Literal(l: Long, LongType) => l
+          case _ =>
+            throw InvalidPlanInput("Expected insertions must be long literal.")
+        }
+        if (expectedNumItems <= 0L) {
+          throw InvalidPlanInput("Expected insertions must be positive.")
+        }
+
+        val numBitsExpr = children(2)
+        // Check numBits is LongType and value greater than 0L
+        numBitsExpr match {
+          case Literal(numBits: Long, LongType) =>
+            if (numBits <= 0L) {
+              throw InvalidPlanInput("Number of bits must be positive.")
+            }
+          case _ =>
+            throw InvalidPlanInput("Number of bits must be long literal.")
+        }
+
+        Some(
+          new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
+            .toAggregateExpression())
+
       case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val timeCol = children.head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
index 980785e764c..7cba462ce2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.BloomFilter
 
 /**
@@ -78,7 +79,7 @@ case class BloomFilterAggregate(
             "exprName" -> "estimatedNumItems or numBits"
           )
         )
-      case (LongType, LongType, LongType) =>
+      case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) =>
         if (!estimatedNumItemsExpression.foldable) {
           DataTypeMismatch(
             errorSubClass = "NON_FOLDABLE_INPUT",
@@ -150,6 +151,15 @@ case class BloomFilterAggregate(
     Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
       SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
 
+  // Mark as lazy so that `updater` is not evaluated during tree transformation.
+  private lazy val updater: BloomFilterUpdater = child.dataType match {
+    case LongType => LongUpdater
+    case IntegerType => IntUpdater
+    case ShortType => ShortUpdater
+    case ByteType => ByteUpdater
+    case StringType => BinaryUpdater
+  }
+
   override def first: Expression = child
 
   override def second: Expression = estimatedNumItemsExpression
@@ -174,7 +184,7 @@ case class BloomFilterAggregate(
     if (value == null) {
       return buffer
     }
-    buffer.putLong(value.asInstanceOf[Long])
+    updater.update(buffer, value)
     buffer
   }
 
@@ -224,3 +234,32 @@ object BloomFilterAggregate {
     bloomFilter
   }
 }
+
+private trait BloomFilterUpdater {
+  def update(bf: BloomFilter, v: Any): Boolean
+}
+
+private object LongUpdater extends BloomFilterUpdater with Serializable {
+  override def update(bf: BloomFilter, v: Any): Boolean =
+    bf.putLong(v.asInstanceOf[Long])
+}
+
+private object IntUpdater extends BloomFilterUpdater with Serializable {
+  override def update(bf: BloomFilter, v: Any): Boolean =
+    bf.putLong(v.asInstanceOf[Int])
+}
+
+private object ShortUpdater extends BloomFilterUpdater with Serializable {
+  override def update(bf: BloomFilter, v: Any): Boolean =
+    bf.putLong(v.asInstanceOf[Short])
+}
+
+private object ByteUpdater extends BloomFilterUpdater with Serializable {
+  override def update(bf: BloomFilter, v: Any): Boolean =
+    bf.putLong(v.asInstanceOf[Byte])
+}
+
+private object BinaryUpdater extends BloomFilterUpdater with Serializable {
+  override def update(bf: BloomFilter, v: Any): Boolean =
+    bf.putBinary(v.asInstanceOf[UTF8String].getBytes)
+}

