westonpace commented on code in PR #35514:
URL: https://github.com/apache/arrow/pull/35514#discussion_r1222497897


##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -15,15 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#include "arrow/python/udf.h"
+#include <iostream>

Review Comment:
   ```suggestion
   ```



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());

Review Comment:
   This `Py_INCREF` seems redundant given you already have one 
[here](https://github.com/apache/arrow/pull/35514/files#diff-d269c07da082f1b2d9b8f8628effd64aff74c163e325d253c2246410e0d1c3d0R381).



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {

Review Comment:
   ```suggestion
         : agg_cb(std::move(agg_cb)), agg_function(agg_function), 
output_type(std::move(output_type)) {
   ```
   Minor nit



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(std::move(field("", input_types[i])));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfScalarAggregatorImpl() {
+    if (_Py_IsFinalizing()) {
+      agg_function->detach();
+    }
+  }
+
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {

Review Comment:
   ```suggestion
     Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& 
batch) override {
   ```



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(std::move(field("", input_types[i])));

Review Comment:
   ```suggestion
         fields.push_back(field("", input_types[i]));
   ```
   No need to call `std::move` on something that doesn't have a name: `field(...)` 
returns a temporary (an rvalue), so `push_back` already moves it into the vector.



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(std::move(field("", input_types[i])));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfScalarAggregatorImpl() {

Review Comment:
   ```suggestion
     ~PythonUdfScalarAggregatorImpl() override {
   ```



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(std::move(field("", input_types[i])));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfScalarAggregatorImpl() {
+    if (_Py_IsFinalizing()) {
+      agg_function->detach();
+    }
+  }
+
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
+    ARROW_ASSIGN_OR_RAISE(
+        auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, 
ctx->memory_pool()));
+    values.push_back(std::move(rb));
+    return Status::OK();
+  }
+
+  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) {
+    auto& other_values = 
checked_cast<PythonUdfScalarAggregatorImpl&>(src).values;
+    values.insert(values.end(), std::make_move_iterator(other_values.begin()),
+                  std::make_move_iterator(other_values.end()));
+
+    other_values.erase(other_values.begin(), other_values.end());
+    return Status::OK();
+  }
+
+  Status Finalize(compute::KernelContext* ctx, Datum* out) {

Review Comment:
   ```suggestion
     Status Finalize(compute::KernelContext* ctx, Datum* out) override {
   ```
   
   These `override` suggestions are pretty minor but can be nice. With 
`override`, the compiler reports an error if the parent class's signature 
changes, which makes it much easier to track down what went wrong.



##########
python/pyarrow/conftest.py:
##########
@@ -278,3 +278,59 @@ def unary_function(ctx, x):
                                 {"array": pa.int64()},
                                 pa.int64())
     return unary_function, func_name
+
+
[email protected](scope="session")
+def unary_agg_func_fixture():
+    """
+    Register a unary aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, x):
+        return pa.scalar(np.nanmean(x))
+
+    func_name = "y=avg(x)"
+    func_doc = {"summary": "y=avg(x)",
+                "description": "find mean of x"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.float64(),
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
+
+
[email protected](scope="session")
+def varargs_agg_func_fixture():
+    """
+    Register a unary aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, *args):
+        sum = 0.0
+        for arg in args:
+            sum += np.nanmean(arg)
+        return pa.scalar(sum)
+
+    func_name = "y=sum_mean(x...)"
+    func_doc = {"summary": "Varargs aggregate",
+                "description": "Varargs aggregate"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                       "y": pa.float64()

Review Comment:
   OK, so the test case is verifying that the Python function can accept 
`*args` if needed (even though the arguments are still listed individually when registering)?



##########
python/pyarrow/_compute.pyx:
##########
@@ -2665,11 +2665,19 @@ cdef get_register_tabular_function():
     return reg
 
 
+cdef get_register_aggregate_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterAggregateFunction
+    return reg
+
+
 def register_scalar_function(func, function_name, function_doc, in_types, 
out_type,
                              func_registry=None):
     """
     Register a user-defined scalar function.
 
+    This API is EXPERIMENTAL.

Review Comment:
   That's fine.  UDFs are also marked experimental in the [pyarrow 
docs](https://github.com/apache/arrow/blob/407423394cb988669a0b5716899810a924c0aa4c/docs/source/python/compute.rst?plain=1#L390).



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -101,6 +125,103 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> 
input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(std::move(field("", input_types[i])));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfScalarAggregatorImpl() {
+    if (_Py_IsFinalizing()) {
+      agg_function->detach();
+    }
+  }
+
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
+    ARROW_ASSIGN_OR_RAISE(
+        auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, 
ctx->memory_pool()));
+    values.push_back(std::move(rb));
+    return Status::OK();
+  }
+
+  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) {

Review Comment:
   ```suggestion
     Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) 
override {
   ```



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -65,6 +69,26 @@ struct PythonUdfKernelInit {
   std::shared_ptr<OwnedRefNoGIL> function;
 };
 
+struct ScalarUdfAggregator : public compute::KernelState {
+  virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& 
batch) = 0;
+  virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& 
src) = 0;
+  virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
+};
+
+arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const 
compute::ExecSpan& batch) {
+  return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, 
compute::KernelState&& src,
+                                compute::KernelState* dst) {
+  return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, 
std::move(src));

Review Comment:
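   Yep. The `move` is indeed needed: `src` is a named rvalue reference, which is 
itself an lvalue, so it must be cast with `std::move` to bind to `KernelState&&` 
again. C++ is confusing :cold_sweat: 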



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to