mbutrovich commented on code in PR #4390:
URL: https://github.com/apache/datafusion-comet/pull/4390#discussion_r3284098242


##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -272,16 +272,59 @@ impl SparkBloomFilter {
     }
 
     pub fn state_as_bytes(&self) -> Vec<u8> {
-        self.bits.to_bytes()
+        self.spark_serialization()
     }
 
     pub fn merge_filter(&mut self, other: &[u8]) {

Review Comment:
   `merge_filter` returns `()` and panics via `assert_eq!` on every header 
mismatch. Its only caller is `Accumulator::merge_batch` in 
`bloom_filter_agg.rs:176`, which already returns `Result`. Returning each 
mismatch as `DataFusionError::Internal` and propagating it through that 
`Result` would let a corrupt or truncated intermediate buffer surface as a 
query failure rather than crashing the executor process.
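
A minimal sketch of what a fallible header check could look like. It is illustrative only: `MergeError`, `read_i32_be`, `check_field`, and `check_v1_header` are hypothetical names, and `MergeError::Internal` stands in for `DataFusionError::Internal` so the snippet compiles without the datafusion crate. It covers only the V1 field order shown in the diff (version, num_hash_functions, num_words).

```rust
use std::fmt;

/// Hypothetical stand-in for `DataFusionError::Internal` (assumption, not the real type).
#[derive(Debug, PartialEq)]
pub enum MergeError {
    Internal(String),
}

/// Reads a big-endian i32 at `*offset` and advances the offset.
/// Returns an error instead of panicking when the buffer is too short.
fn read_i32_be(buf: &[u8], offset: &mut usize) -> Result<i32, MergeError> {
    let end = *offset + 4;
    let slice = buf.get(*offset..end).ok_or_else(|| {
        MergeError::Internal(format!(
            "BloomFilter merge: buffer truncated at offset {}",
            *offset
        ))
    })?;
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(slice);
    *offset = end;
    Ok(i32::from_be_bytes(bytes))
}

/// Replaces an `assert_eq!` with a check that reports the mismatch as an error.
fn check_field<T: PartialEq + fmt::Display>(
    name: &str,
    got: T,
    expected: T,
) -> Result<(), MergeError> {
    if got != expected {
        return Err(MergeError::Internal(format!(
            "BloomFilter merge: {} mismatch (got {}, expected {})",
            name, got, expected
        )));
    }
    Ok(())
}

/// Validates a V1-style header and returns the offset of the first data word.
pub fn check_v1_header(
    other: &[u8],
    version: i32,
    num_hash: i32,
    num_words: i32,
) -> Result<usize, MergeError> {
    let mut offset = 0;
    check_field("version", read_i32_be(other, &mut offset)?, version)?;
    check_field("num_hash_functions", read_i32_be(other, &mut offset)?, num_hash)?;
    check_field("num_words", read_i32_be(other, &mut offset)?, num_words)?;
    Ok(offset)
}
```

With this shape, `merge_batch` could forward the error with `?` instead of relying on the process surviving a panic.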



##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -272,16 +272,59 @@ impl SparkBloomFilter {
     }
 
     pub fn state_as_bytes(&self) -> Vec<u8> {
-        self.bits.to_bytes()
+        self.spark_serialization()
     }
 
     pub fn merge_filter(&mut self, other: &[u8]) {
+        let mut offset = 0;
+
+        let version_int = read_num_be_bytes!(i32, 4, other[offset..]);
+        offset += 4;
+        assert_eq!(
+            version_int,
+            self.version.to_int(),
+            "BloomFilter merge: version mismatch (got {}, expected {})",
+            version_int,
+            self.version.to_int(),
+        );
+
+        let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32;
+        offset += 4;
+        assert_eq!(
+            num_hash, self.num_hash_functions,
+            "BloomFilter merge: num_hash_functions mismatch (got {}, expected {})",
+            num_hash, self.num_hash_functions,
+        );
+
+        if let SparkBloomFilterVersion::V2 = self.version {
+            let seed = read_num_be_bytes!(i32, 4, other[offset..]);
+            offset += 4;
+            assert_eq!(
+                seed, self.seed,
+                "BloomFilter merge: seed mismatch (got {}, expected {})",
+                seed, self.seed,
+            );
+        }
+
+        let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize;
+        offset += 4;
         assert_eq!(
-            other.len(),
-            self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            num_words,
+            self.bits.word_size(),
+            "BloomFilter merge: num_words mismatch (got {}, expected {})",
+            num_words,
+            self.bits.word_size(),
         );
-        self.bits.merge_bits(other);
+
+        let words = self.bits.data();

Review Comment:
   This clones `self.bits.data()` into `words`, builds a new `Vec` of the same 
size, and then constructs a fresh `SparkBitArray` whose `new` re-scans the 
words to recompute `bit_count`. Two allocations and two passes per merge. An 
in-place variant that ORs into `self.bits.data` directly and accumulates 
`bit_count` in the same loop would avoid both.
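
A rough sketch of the in-place variant, under stated assumptions: `BitArray` is a hypothetical simplified stand-in for `SparkBitArray`, its field names and `u64` word type are guesses, and the incoming words are assumed to be 8-byte big-endian values, consistent with the `read_num_be_bytes!` calls in the diff.

```rust
/// Hypothetical minimal bit array (assumption: not the real `SparkBitArray` layout).
pub struct BitArray {
    data: Vec<u64>,
    bit_count: usize,
}

impl BitArray {
    /// Builds the array and counts set bits once, up front.
    pub fn new(data: Vec<u64>) -> Self {
        let bit_count = data.iter().map(|w| w.count_ones() as usize).sum();
        Self { data, bit_count }
    }

    /// ORs big-endian words from `other` directly into `self.data` and
    /// recomputes `bit_count` in the same pass: no extra Vec, no second scan.
    pub fn merge_be_words(&mut self, other: &[u8]) -> Result<(), String> {
        if other.len() != self.data.len() * 8 {
            return Err(format!(
                "word buffer length mismatch (got {} bytes, expected {})",
                other.len(),
                self.data.len() * 8
            ));
        }
        let mut bit_count = 0usize;
        for (word, chunk) in self.data.iter_mut().zip(other.chunks_exact(8)) {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(chunk);
            *word |= u64::from_be_bytes(bytes);
            bit_count += word.count_ones() as usize;
        }
        self.bit_count = bit_count;
        Ok(())
    }

    pub fn bit_count(&self) -> usize {
        self.bit_count
    }

    pub fn data(&self) -> &[u64] {
        &self.data
    }
}
```

The length check runs before any word is modified, so a rejected buffer leaves the filter unchanged.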



##########
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala:
##########
@@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec
   override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
     CometConf.COMET_EXEC_AGGREGATE_ENABLED)
 
+  override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {

Review Comment:
   The body here is identical to `CometHashAggregateExec.getSupportLevel` at 
`operators.scala:1658-1670`, including the conf names. That is fine for the 
test-knob purpose called out in the comment, but 
`COMET_ENABLE_PARTIAL_HASH_AGGREGATE` and `COMET_ENABLE_FINAL_HASH_AGGREGATE` 
now gate both `HashAggregateExec` and `ObjectHashAggregateExec`. As a 
follow-up, consider renaming to `COMET_ENABLE_PARTIAL_AGGREGATE` / 
`COMET_ENABLE_FINAL_AGGREGATE` so the conf names match the scope.



##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -396,4 +439,96 @@ mod tests {
         buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
         let _ = SparkBloomFilter::from(buf.as_slice());
     }
+
+    /// Two V1 filters with identical parameters. Populate the first, serialize via
+    /// state_as_bytes, merge into the empty second, and verify the second contains
+    /// everything the first did. Exercises the aggregator state → merge_batch path.
+    #[test]
+    fn state_round_trip_v1_merge() {
+        let num_bits = 1024;
+        let num_hash = optimal_num_hash_functions(100, num_bits);
+        let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
+        for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
+            a.put_long(v);
+        }
+
+        let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
+        b.merge_filter(&a.state_as_bytes());

Review Comment:
   Nice that these go through `state_as_bytes` then `merge_filter`. The 
pre-existing `v1_round_trip` and `v2_round_trip` tests above use 
`SparkBloomFilter::from`, which was always header-aware, so they would not have 
caught the bug this PR fixes. The new cases exercise exactly the right 
round-trip path.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

