Jimexist commented on a change in pull request #375:
URL: https://github.com/apache/arrow-datafusion/pull/375#discussion_r638361758



##########
File path: datafusion/src/physical_plan/windows.rs
##########
@@ -186,10 +273,260 @@ impl ExecutionPlan for WindowAggExec {
             ));
         }
 
-        // let input = self.input.execute(0).await?;
+        let input = self.input.execute(partition).await?;
 
-        Err(DataFusionError::NotImplemented(
-            "WindowAggExec::execute".to_owned(),
-        ))
+        let stream = Box::pin(WindowAggStream::new(
+            self.schema.clone(),
+            self.window_expr.clone(),
+            input,
+        ));
+        Ok(stream)
+    }
+}
+
+pin_project! {
+    /// stream for window aggregation plan
+    pub struct WindowAggStream {
+        schema: SchemaRef,
+        #[pin]
+        output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
+        finished: bool,
+    }
+}
+
+type WindowAccumulatorItem = Box<dyn WindowAccumulator>;
+
+fn window_expressions(
+    window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
+    Ok(window_expr
+        .iter()
+        .map(|expr| expr.expressions())
+        .collect::<Vec<_>>())
+}
+
+fn window_aggregate_batch(
+    batch: &RecordBatch,
+    window_accumulators: &mut [WindowAccumulatorItem],
+    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+) -> Result<Vec<Option<Vec<ScalarValue>>>> {
+    // 1.1 iterate accumulators and respective expressions together
+    // 1.2 evaluate expressions
+    // 1.3 update / merge window accumulators with the expressions' values
+
+    // 1.1
+    window_accumulators
+        .iter_mut()
+        .zip(expressions)
+        .map(|(window_acc, expr)| {
+            // 1.2
+            let values = &expr
+                .iter()
+                .map(|e| e.evaluate(batch))
+                .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+                .collect::<Result<Vec<_>>>()?;
+
+            window_acc.scan_batch(values)
+        })
+        .into_iter()
+        .collect::<Result<Vec<_>>>()
+}
+
+/// returns a vector of optional scalar values, one per window accumulator,
+/// holding the final value produced by each accumulator's `evaluate` call
+fn finalize_window_aggregation(
+    window_accumulators: &[WindowAccumulatorItem],
+) -> Result<Vec<Option<ScalarValue>>> {
+    window_accumulators
+        .iter()
+        .map(|window_accumulator| window_accumulator.evaluate())
+        .collect::<Result<Vec<_>>>()
+}
+
+fn create_window_accumulators(
+    window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Vec<WindowAccumulatorItem>> {
+    window_expr
+        .iter()
+        .map(|expr| expr.create_accumulator())
+        .collect::<Result<Vec<_>>>()
+}
+
+async fn compute_window_aggregate(
+    schema: SchemaRef,
+    window_expr: Vec<Arc<dyn WindowExpr>>,
+    mut input: SendableRecordBatchStream,
+) -> ArrowResult<RecordBatch> {
+    let mut window_accumulators = create_window_accumulators(&window_expr)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    let expressions = window_expressions(&window_expr)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    let expressions = Arc::new(expressions);
+
+    // TODO: each element should have some size hint
+    let mut accumulator: Vec<Vec<ScalarValue>> =
+        iter::repeat(vec![]).take(window_expr.len()).collect();
+
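+    // buffer the input batches so their columns can be re-emitted alongside the
+    // computed window columns once the input stream is exhausted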
+    let mut original_batches: Vec<RecordBatch> = vec![];
+
+    let mut total_num_rows = 0;
+
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        total_num_rows += batch.num_rows();
+        original_batches.push(batch.clone());
+
+        let batch_aggregated =
+            window_aggregate_batch(&batch, &mut window_accumulators, &expressions)
+                .map_err(DataFusionError::into_arrow_external_error)?;
+        accumulator.iter_mut().zip(batch_aggregated).for_each(
+            |(acc_for_window, window_batch)| {
+                if let Some(data) = window_batch {
+                    acc_for_window.extend(data);
+                }
+            },
+        );
+    }
+
+    let aggregated_mapped = finalize_window_aggregation(&window_accumulators)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    let mut columns: Vec<ArrayRef> = accumulator
+        .iter()
+        .zip(aggregated_mapped)
+        .map(|(acc, agg)| {
+            let arr: ArrayRef = match (acc, agg) {
+                (acc, Some(scalar_value)) if acc.is_empty() => {
+                    scalar_value.to_array_of_size(total_num_rows)
+                }
+                (acc, None) if !acc.is_empty() => {
+                    return Err(DataFusionError::NotImplemented(
+                        "built in window function not yet 
implemented".to_owned(),
+                    ))
+                }
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "invalid window function behavior".to_owned(),
+                    ))
+                }
+            };
+            Ok(arr)
+        })
+        .collect::<Result<Vec<ArrayRef>>>()
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
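+    // append the pass-through input columns (the output fields that follow the
+    // window expression columns), concatenated across all buffered batches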
+    for i in 0..(schema.fields().len() - window_expr.len()) {
+        let col = concat(
+            &original_batches
+                .iter()
+                .map(|batch| batch.column(i).as_ref())
+                .collect::<Vec<_>>(),
+        )?;
+        columns.push(col);
     }
+
+    RecordBatch::try_new(schema.clone(), columns)
+}
+
+impl WindowAggStream {
+    /// Create a new WindowAggStream
+    pub fn new(
+        schema: SchemaRef,
+        window_expr: Vec<Arc<dyn WindowExpr>>,
+        input: SendableRecordBatchStream,
+    ) -> Self {
+        let (tx, rx) = futures::channel::oneshot::channel();
+        let schema_clone = schema.clone();
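+        // compute the whole window aggregation eagerly on a separate task and
+        // deliver the single result batch back through a oneshot channel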
+        tokio::spawn(async move {
+            let result = compute_window_aggregate(schema_clone, window_expr, input).await;
+            tx.send(result)
+        });
+
+        Self {
+            output: rx,
+            finished: false,
+            schema,
+        }
+    }
+}
+
+impl Stream for WindowAggStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if self.finished {
+            return Poll::Ready(None);
+        }
+
+        // is the output ready?
+        let this = self.project();
+        let output_poll = this.output.poll(cx);
+
+        match output_poll {
+            Poll::Ready(result) => {
+                *this.finished = true;
+                // check for error in receiving channel and unwrap actual result
+                let result = match result {
+                    Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving
+                    Ok(result) => Some(result),
+                };
+                Poll::Ready(result)
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl RecordBatchStream for WindowAggStream {
+    /// Get the schema
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    // use super::*;

Review comment:
       also there are integration tests in `sql.rs`
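
       For reference, a window-function test in `sql.rs` would look roughly like the sketch below; the helper names (`register_aggregate_csv`, `execute`) and the exact query are illustrative assumptions, not the tests actually added in this PR:

```rust
// rough sketch only: helper names and the query are illustrative assumptions
#[tokio::test]
async fn csv_query_window_empty_over() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // assumed helper that registers the aggregate_test_100 CSV test table
    register_aggregate_csv(&mut ctx)?;
    let sql = "SELECT count(c9) OVER (), max(c9) OVER () FROM aggregate_test_100 LIMIT 1";
    // assumed helper that runs the query and renders the result rows as strings
    let actual = execute(&mut ctx, sql).await;
    // OVER () spans the whole table, so with LIMIT 1 a single row comes back
    assert_eq!(actual.len(), 1);
    Ok(())
}
```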



