tobixdev commented on code in PR #18254:
URL: https://github.com/apache/datafusion/pull/18254#discussion_r2465454174
##########
datafusion/physical-plan/src/recursive_query.rs:
##########
@@ -268,8 +273,10 @@ struct RecursiveQueryStream {
buffer: Vec<RecordBatch>,
/// Tracks the memory used by the buffer
reservation: MemoryReservation,
+ /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables
+ distinct_deduplicator: Option<DistinctDeduplicator>,
// /// Metrics.
Review Comment:
I think the leading `//` was not intended when the original code was submitted.
Maybe we can make this a regular doc comment now that we're changing it.
##########
datafusion/physical-plan/src/recursive_query.rs:
##########
@@ -434,5 +452,55 @@ impl RecordBatchStream for RecursiveQueryStream {
}
}
+/// Deduplicator based on a hash table.
+struct DistinctDeduplicator {
+ /// Grouped rows used for distinct
+ group_values: Box<dyn GroupValues>,
+ reservation: MemoryReservation,
+ intern_output_buffer: Vec<usize>,
+}
+
+impl DistinctDeduplicator {
+ fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
+ let group_values = new_group_values(schema, &GroupOrdering::None)?;
+ let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
+ .register(task_context.memory_pool());
+ Ok(Self {
+ group_values,
+ reservation,
+ intern_output_buffer: Vec::new(),
+ })
+ }
+
+ fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+ // We use the hash table to allocate new group ids.
+ // If they are new, i.e., if they have ids >= length before interning, we keep them.
+ // We also detect duplicates by enforcing that group ids are increasing.
+ let size_before = self.group_values.len();
+ self.intern_output_buffer.reserve(batch.num_rows());
+ self.group_values
+ .intern(batch.columns(), &mut self.intern_output_buffer)?;
+ let mask = are_increasing_mask(&self.intern_output_buffer, size_before);
+ self.intern_output_buffer.clear();
+ // We update the reservation to reflect the new size of the hash table.
+ self.reservation.try_resize(self.group_values.size())?;
+ Ok(filter_record_batch(batch, &mask)?)
+ }
+}
+
+/// Return a mask, each element true if the value is greater than all previous ones and greater or equal than the min_value
+fn are_increasing_mask(values: &[usize], mut min_value: usize) -> BooleanArray {
Review Comment:
I think I understood what this function does, but I had a hard time with
`min_value`. Maybe we can be more explicit here. Just some suggestions:
input parameter: `min_value` -> `highest_group_id`
```rust
// Always update the min_value to do de-duplication within a record batch.
let mut min_value = highest_group_id;
```
Maybe integrating the comment into the doc comment for `are_increasing_mask`
is also more than enough.
I think this assumes that the group ids are assigned in-order within the
record batch but I think this is a valid assumption. Maybe someone more
familiar with the aggregation infrastructure has more information on that.
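To make the naming suggestion concrete, here is a minimal sketch of how the function could read. It returns a plain `Vec<bool>` instead of a `BooleanArray` so that it compiles without arrow; that substitution, the parameter name `group_ids`, and the body itself are my assumptions, not the code in the PR. It relies on the same assumption as above, namely that group ids are handed out in order within a batch.

```rust
/// Returns a mask with one entry per group id. An entry is `true` if the id
/// is new, i.e. not below `highest_group_id` (the number of groups in the
/// hash table before interning) and greater than every id seen so far in
/// this batch.
///
/// NOTE: illustrative sketch only; `Vec<bool>` stands in for `BooleanArray`.
fn are_increasing_mask(group_ids: &[usize], highest_group_id: usize) -> Vec<bool> {
    // Always update min_value to de-duplicate within a record batch.
    let mut min_value = highest_group_id;
    group_ids
        .iter()
        .map(|&id| {
            if id >= min_value {
                // First occurrence of a new group: keep it and raise the bar.
                min_value = id + 1;
                true
            } else {
                // Either seen in an earlier batch or earlier in this batch.
                false
            }
        })
        .collect()
}
```

With three groups already in the table, the ids `[0, 3, 3, 1, 4]` would keep only the first `3` and the `4`.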
##########
datafusion/physical-plan/src/recursive_query.rs:
##########
@@ -434,5 +452,55 @@ impl RecordBatchStream for RecursiveQueryStream {
}
}
+/// Deduplicator based on a hash table.
+struct DistinctDeduplicator {
+ /// Grouped rows used for distinct
+ group_values: Box<dyn GroupValues>,
+ reservation: MemoryReservation,
+ intern_output_buffer: Vec<usize>,
+}
+
+impl DistinctDeduplicator {
+ fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
+ let group_values = new_group_values(schema, &GroupOrdering::None)?;
+ let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
+ .register(task_context.memory_pool());
+ Ok(Self {
+ group_values,
+ reservation,
+ intern_output_buffer: Vec::new(),
+ })
+ }
+
+ fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+ // We use the hash table to allocate new group ids.
Review Comment:
I think we can make a version of that comment the doc comment for
`DistinctDeduplicator::deduplicate`
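As a rough idea of what such a doc comment could say, here is a toy version of the deduplicator. `ToyDeduplicator`, the `Vec<i64>` row type, and the `HashMap` standing in for `GroupValues` are all hypothetical simplifications chosen so the sketch runs with the standard library only; the wording of the doc comment is the part that might carry over.

```rust
use std::collections::HashMap;

/// Toy stand-in for `DistinctDeduplicator` (hypothetical, std-only).
/// The `HashMap` plays the role of `GroupValues`: it maps a row to its group id.
struct ToyDeduplicator {
    group_ids: HashMap<Vec<i64>, usize>,
}

impl ToyDeduplicator {
    fn new() -> Self {
        Self { group_ids: HashMap::new() }
    }

    /// Returns the rows of `batch` that have not been seen before, either in
    /// an earlier batch or earlier in this batch.
    ///
    /// Every row is interned into the hash table, which allocates group ids
    /// in increasing order. A row is kept only if its id is at least the
    /// table size before interning (new across batches) and greater than
    /// every id seen so far in this batch (new within the batch).
    fn deduplicate(&mut self, batch: &[Vec<i64>]) -> Vec<Vec<i64>> {
        // Smallest id that still counts as "new" for this batch.
        let mut next_new_id = self.group_ids.len();
        let mut kept = Vec::new();
        for row in batch {
            // Intern the row: reuse its id if known, otherwise allocate one.
            let fresh_id = self.group_ids.len();
            let id = *self.group_ids.entry(row.clone()).or_insert(fresh_id);
            if id >= next_new_id {
                next_new_id = id + 1;
                kept.push(row.clone());
            }
        }
        kept
    }
}
```

Unlike the PR, this sketch interns and filters in one loop and does no memory accounting; it is only meant to show the behaviour the doc comment would describe.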
##########
datafusion/sqllogictest/test_files/cte.slt:
##########
@@ -1049,6 +1053,68 @@ physical_plan
05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
06)------WorkTableExec: name=r
+# setup
+statement ok
+CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');
+
+# transitive closure with loop
+query II
+WITH RECURSIVE trans AS (
Review Comment:
🥳
##########
datafusion/physical-plan/src/recursive_query.rs:
##########
@@ -293,21 +304,28 @@ impl RecursiveQueryStream {
schema,
buffer: vec![],
reservation,
- _baseline_metrics: baseline_metrics,
- }
+ distinct_deduplicator,
+ baseline_metrics,
+ })
}
/// Push a clone of the given batch to the in memory buffer, and then return
/// a poll with it.
fn push_batch(
mut self: std::pin::Pin<&mut Self>,
- batch: RecordBatch,
+ mut batch: RecordBatch,
) -> Poll<Option<Result<RecordBatch>>> {
+ let baseline_metrics = self.baseline_metrics.clone();
+ if let Some(deduplicator) = &mut self.distinct_deduplicator {
Review Comment:
Thanks for adding metrics!
I think we could also move the metrics part to `<RecursiveQueryStream as
Stream>::poll_next` as there is already a `TODO` for doing so. I believe it
would be fine to update the time metrics there even if there is no
deduplication going on but there might be different opinions.
Anyways, I think this is an improvement over the status quo.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]