This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4339e0c0e4d [SPARK-45707][SQL] Simplify 
`DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg`
4339e0c0e4d is described below

commit 4339e0c0e4d7e502ae6cafa90444cd153017cb1a
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Oct 29 22:42:01 2023 -0700

    [SPARK-45707][SQL] Simplify `DataFrameStatFunctions.countMinSketch` with 
`CountMinSketchAgg`
    
    ### What changes were proposed in this pull request?
    Simplify `DataFrameStatFunctions.countMinSketch` with `CountMinSketchAgg`
    
    ### Why are the changes needed?
    To make the implementation consistent with the SQL functions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, better error messages: an unsupported column type now raises `AnalysisException` instead of `IllegalArgumentException`.
    
    ### How was this patch tested?
    CI, with `DataFrameStatSuite` updated to expect the new exception type.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43560 from zhengruifeng/sql_reimpl_stat_countMinSketch.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/DataFrameStatFunctions.scala  | 44 +++++++---------------
 .../org/apache/spark/sql/DataFrameStatSuite.scala  |  2 +-
 2 files changed, 14 insertions(+), 32 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index de3b100cd6a..f3690773f6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,9 +22,8 @@ import java.{lang => jl, util => ju}
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, 
CountMinSketchAgg}
 import org.apache.spark.sql.execution.stat._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
@@ -483,7 +482,9 @@ final class DataFrameStatFunctions private[sql](df: 
DataFrame) {
    * @since 2.0.0
    */
   def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): 
CountMinSketch = {
-    countMinSketch(col, CountMinSketch.create(depth, width, seed))
+    val eps = 2.0 / width
+    val confidence = 1 - 1 / Math.pow(2, depth)
+    countMinSketch(col, eps, confidence, seed)
   }
 
   /**
@@ -497,35 +498,16 @@ final class DataFrameStatFunctions private[sql](df: 
DataFrame) {
    * @since 2.0.0
    */
   def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): 
CountMinSketch = {
-    countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
-  }
-
-  private def countMinSketch(col: Column, zero: CountMinSketch): 
CountMinSketch = {
-    val singleCol = df.select(col)
-    val colType = singleCol.schema.head.dataType
-
-    val updater: (CountMinSketch, InternalRow) => Unit = colType match {
-      // For string type, we can get bytes of our `UTF8String` directly, and 
call the `addBinary`
-      // instead of `addString` to avoid unnecessary conversion.
-      case StringType => (sketch, row) => 
sketch.addBinary(row.getUTF8String(0).getBytes)
-      case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
-      case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
-      case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
-      case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Count-min Sketch only supports string type and integral types, " +
-            s"and does not support type $colType."
-        )
-    }
-
-    singleCol.queryExecution.toRdd.aggregate(zero)(
-      (sketch: CountMinSketch, row: InternalRow) => {
-        updater(sketch, row)
-        sketch
-      },
-      (sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
+    val countMinSketchAgg = new CountMinSketchAgg(
+      col.expr,
+      Literal(eps, DoubleType),
+      Literal(confidence, DoubleType),
+      Literal(seed, IntegerType)
     )
+    val bytes = df.select(
+      Column(countMinSketchAgg.toAggregateExpression(false))
+    ).head().getAs[Array[Byte]](0)
+    countMinSketchAgg.deserialize(bytes)
   }
 
   /**
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 1dece5c8285..430e3622102 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -508,7 +508,7 @@ class DataFrameStatSuite extends QueryTest with 
SharedSparkSession {
     assert(sketch4.relativeError() === 0.001 +- 1e04)
     assert(sketch4.confidence() === 0.99 +- 5e-3)
 
-    intercept[IllegalArgumentException] {
+    intercept[AnalysisException] {
       df.select($"id" cast DoubleType as "id")
         .stat
         .countMinSketch($"id", depth = 10, width = 20, seed = 42)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to