mbutrovich commented on code in PR #4390:
URL: https://github.com/apache/datafusion-comet/pull/4390#discussion_r3284098242
##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -272,16 +272,59 @@ impl SparkBloomFilter {
}
pub fn state_as_bytes(&self) -> Vec<u8> {
- self.bits.to_bytes()
+ self.spark_serialization()
}
pub fn merge_filter(&mut self, other: &[u8]) {
Review Comment:
`merge_filter` is `pub fn ... -> ()` and panics via `assert_eq!` on every
header mismatch. Its only caller is `Accumulator::merge_batch` in
`bloom_filter_agg.rs:176`, which already returns `Result`. Threading these
through as `DataFusionError::Internal` would let a corrupt or truncated
intermediate buffer surface as a query failure rather than crashing the
executor process.
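A minimal sketch of the fallible validation. It assumes the header layout the diff below reads: big-endian `i32` version, `num_hash_functions`, a seed present only for V2, then `num_words`, followed by 8 bytes per word. `MergeError`, `read_i32_be`, `check`, and `validate_header` are hypothetical names. In Comet the error would presumably be mapped to `DataFusionError::Internal` inside `merge_batch`, which is not shown here.

```rust
use std::fmt::Display;

/// Hypothetical error type; in Comet this would become DataFusionError::Internal.
#[derive(Debug, PartialEq)]
pub struct MergeError(pub String);

/// Read a big-endian i32 at `*offset` and advance it, returning an error
/// (instead of panicking on an out-of-bounds slice) when the buffer is short.
fn read_i32_be(buf: &[u8], offset: &mut usize) -> Result<i32, MergeError> {
    let end = *offset + 4;
    let b = buf.get(*offset..end).ok_or_else(|| {
        MergeError(format!(
            "BloomFilter merge: truncated header (need {} bytes, have {})",
            end,
            buf.len()
        ))
    })?;
    *offset = end;
    Ok(i32::from_be_bytes([b[0], b[1], b[2], b[3]]))
}

/// Turn a header mismatch into an error instead of an assert_eq! panic.
fn check<T: PartialEq + Display>(field: &str, got: T, expected: T) -> Result<(), MergeError> {
    if got == expected {
        Ok(())
    } else {
        Err(MergeError(format!(
            "BloomFilter merge: {} mismatch (got {}, expected {})",
            field, got, expected
        )))
    }
}

/// Validate `buf` against this filter's parameters and return the offset
/// where the word payload begins. `seed` is Some(_) only for a V2 filter.
pub fn validate_header(
    buf: &[u8],
    version: i32,
    num_hash: u32,
    seed: Option<i32>,
    num_words: usize,
) -> Result<usize, MergeError> {
    let mut offset = 0;
    check("version", read_i32_be(buf, &mut offset)?, version)?;
    check("num_hash_functions", read_i32_be(buf, &mut offset)? as u32, num_hash)?;
    if let Some(expected_seed) = seed {
        check("seed", read_i32_be(buf, &mut offset)?, expected_seed)?;
    }
    check("num_words", read_i32_be(buf, &mut offset)? as usize, num_words)?;
    // Reject a truncated or oversized payload before any word is read.
    check("byte length", buf.len(), offset + num_words * 8)?;
    Ok(offset)
}
```

With this shape `merge_filter` can return `Result` and `merge_batch` only needs a `?` plus an error mapping.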
##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -272,16 +272,59 @@ impl SparkBloomFilter {
}
pub fn state_as_bytes(&self) -> Vec<u8> {
- self.bits.to_bytes()
+ self.spark_serialization()
}
pub fn merge_filter(&mut self, other: &[u8]) {
+ let mut offset = 0;
+
+ let version_int = read_num_be_bytes!(i32, 4, other[offset..]);
+ offset += 4;
+ assert_eq!(
+ version_int,
+ self.version.to_int(),
+ "BloomFilter merge: version mismatch (got {}, expected {})",
+ version_int,
+ self.version.to_int(),
+ );
+
+ let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32;
+ offset += 4;
+ assert_eq!(
+ num_hash, self.num_hash_functions,
+ "BloomFilter merge: num_hash_functions mismatch (got {}, expected {})",
+ num_hash, self.num_hash_functions,
+ );
+
+ if let SparkBloomFilterVersion::V2 = self.version {
+ let seed = read_num_be_bytes!(i32, 4, other[offset..]);
+ offset += 4;
+ assert_eq!(
+ seed, self.seed,
+ "BloomFilter merge: seed mismatch (got {}, expected {})",
+ seed, self.seed,
+ );
+ }
+
+ let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize;
+ offset += 4;
assert_eq!(
- other.len(),
- self.bits.byte_size(),
- "Cannot merge SparkBloomFilters with different lengths."
+ num_words,
+ self.bits.word_size(),
+ "BloomFilter merge: num_words mismatch (got {}, expected {})",
+ num_words,
+ self.bits.word_size(),
);
- self.bits.merge_bits(other);
+
+ let words = self.bits.data();
Review Comment:
This clones `self.bits.data()` into `words`, builds a new `Vec` of the same
size, and then constructs a fresh `SparkBitArray` whose `new` re-scans the
words to recompute `bit_count`. Two allocations and two passes per merge. An
in-place variant that ORs into `self.bits.data` directly and accumulates
`bit_count` in the same loop would avoid both.
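A sketch of that in-place variant. It assumes the words are 64-bit values serialized big-endian (as Spark's `BitArray` writes its longs) and that the header has already been validated. `BitArray` and `merge_be_bytes` are hypothetical stand-ins, not Comet's `SparkBitArray` API.

```rust
/// Hypothetical stand-in for SparkBitArray: 64-bit words plus a cached popcount.
pub struct BitArray {
    data: Vec<u64>,
    bit_count: usize,
}

impl BitArray {
    pub fn new(data: Vec<u64>) -> Self {
        let bit_count = data.iter().map(|w| w.count_ones() as usize).sum::<usize>();
        BitArray { data, bit_count }
    }

    pub fn data(&self) -> &[u64] {
        &self.data
    }

    pub fn bit_count(&self) -> usize {
        self.bit_count
    }

    /// OR a big-endian word payload into `self.data` in place and recompute
    /// `bit_count` in the same pass: no clone, no new Vec, no second scan.
    pub fn merge_be_bytes(&mut self, payload: &[u8]) -> Result<(), String> {
        // Check the length up front so a bad payload leaves `self` untouched.
        if payload.len() != self.data.len() * 8 {
            return Err(format!(
                "expected {} payload bytes, got {}",
                self.data.len() * 8,
                payload.len()
            ));
        }
        let mut bit_count = 0usize;
        for (word, chunk) in self.data.iter_mut().zip(payload.chunks_exact(8)) {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(chunk);
            *word |= u64::from_be_bytes(bytes);
            bit_count += word.count_ones() as usize;
        }
        self.bit_count = bit_count;
        Ok(())
    }
}
```

Decoding each word straight from the payload slice also avoids materializing the other side's words into an intermediate `Vec` before ORing.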
##########
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala:
##########
@@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)
+ override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
Review Comment:
The body here is identical to `CometHashAggregateExec.getSupportLevel` at
`operators.scala:1658-1670`, including the conf names. That is fine for the
test-knob purpose called out in the comment, but
`COMET_ENABLE_PARTIAL_HASH_AGGREGATE` and `COMET_ENABLE_FINAL_HASH_AGGREGATE`
now gate both `HashAggregateExec` and `ObjectHashAggregateExec`. As a
follow-up, consider renaming to `COMET_ENABLE_PARTIAL_AGGREGATE` /
`COMET_ENABLE_FINAL_AGGREGATE` so the conf names match the scope.
##########
native/spark-expr/src/bloom_filter/spark_bloom_filter.rs:
##########
@@ -396,4 +439,96 @@ mod tests {
buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
let _ = SparkBloomFilter::from(buf.as_slice());
}
+
+ /// Two V1 filters with identical parameters. Populate the first, serialize via
+ /// state_as_bytes, merge into the empty second, and verify the second contains
+ /// everything the first did. Exercises the aggregator state → merge_batch path.
+ #[test]
+ fn state_round_trip_v1_merge() {
+ let num_bits = 1024;
+ let num_hash = optimal_num_hash_functions(100, num_bits);
+ let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
+ for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
+ a.put_long(v);
+ }
+
+ let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
+ b.merge_filter(&a.state_as_bytes());
Review Comment:
Nice that these go through `state_as_bytes` and then `merge_filter`. The
pre-existing `v1_round_trip` and `v2_round_trip` tests above use
`SparkBloomFilter::from`, which was always header-aware, so they would not have
caught the bug this PR fixes. The new cases exercise exactly the right
round-trip path.
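The shape of such a test can be sketched against a simplified model. `FilterState` is hypothetical. The byte layout (big-endian `i32` header fields with the seed present only for V2, then big-endian 64-bit words) is an assumption based on the header reads in this PR, and bits are set directly rather than through a hash function. The property being checked is that a header-prefixed buffer round-trips while a raw, header-less word buffer is rejected.

```rust
/// Hypothetical, simplified filter state used only to illustrate the
/// state_as_bytes -> merge_filter round trip; not Comet's SparkBloomFilter.
#[derive(Debug, Clone, PartialEq)]
pub struct FilterState {
    pub version: i32,
    pub num_hash: i32,
    pub seed: Option<i32>, // Some(_) models a V2 filter
    pub words: Vec<u64>,
}

impl FilterState {
    pub fn empty(version: i32, num_hash: i32, seed: Option<i32>, num_words: usize) -> Self {
        FilterState { version, num_hash, seed, words: vec![0; num_words] }
    }

    /// Assumed layout: version, num_hash, [seed], num_words as big-endian i32,
    /// followed by each word as a big-endian 64-bit value.
    pub fn state_as_bytes(&self) -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&self.version.to_be_bytes());
        out.extend_from_slice(&self.num_hash.to_be_bytes());
        if let Some(seed) = self.seed {
            out.extend_from_slice(&seed.to_be_bytes());
        }
        out.extend_from_slice(&(self.words.len() as i32).to_be_bytes());
        for w in &self.words {
            out.extend_from_slice(&w.to_be_bytes());
        }
        out
    }

    /// Header-aware merge: check the total length and every header field,
    /// then OR the payload words into this filter.
    pub fn merge_filter(&mut self, other: &[u8]) -> Result<(), String> {
        let mut header = vec![self.version, self.num_hash];
        if let Some(seed) = self.seed {
            header.push(seed);
        }
        header.push(self.words.len() as i32);
        let payload_start = header.len() * 4;
        if other.len() != payload_start + self.words.len() * 8 {
            return Err(format!("unexpected buffer length {}", other.len()));
        }
        for (i, expected) in header.iter().enumerate() {
            let b = &other[i * 4..i * 4 + 4];
            let got = i32::from_be_bytes([b[0], b[1], b[2], b[3]]);
            if got != *expected {
                return Err(format!("header field {} mismatch (got {}, expected {})", i, got, expected));
            }
        }
        for (w, chunk) in self.words.iter_mut().zip(other[payload_start..].chunks_exact(8)) {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(chunk);
            *w |= u64::from_be_bytes(bytes);
        }
        Ok(())
    }
}
```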
--
This is an automated message from the Apache Git Service.