westonpace commented on code in PR #14527:
URL: https://github.com/apache/arrow/pull/14527#discussion_r1023541046
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
Review Comment:
I'm not much of a Python expert, but getters and setters seem like overkill
here. Are they needed?
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
Review Comment:
```suggestion
return State(0)
```
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null)
+
+ def merge(ctx, other_state):
+ merged_state_val = ctx.state.non_null + other_state.non_null
+ return State(merged_state_val)
+
+ def finalize(ctx):
+ return pa.array([ctx.state.non_null])
+
+ func_name = "simple_count"
+ unary_doc = {"summary": "count function",
+ "description": "test agg count function"}
+
+ pc.register_scalar_aggregate_function(init,
+ consume,
+ merge,
+ finalize,
+ func_name,
+ unary_doc,
+ {"array": pa.int64()},
+ pa.int64())
+
+ assert pc.call_function(func_name, [pa.array(
+ [10, 20, None, 30, None, 40])]) == pa.array([4])
+
+
+def test_aggregate_udf_with_custom_state_multi_attr():
+ class State:
+ def __init__(self, non_null=0, null=0):
+ self._non_null = non_null
+ self._null = null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ @property
+ def null(self):
+ return self._null
+
+ @null.setter
+ def null(self, value):
+ self._null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "non_null: " + str(self.non_null) \
+ + ", null: " + str(self.null)
+
+ def init():
+ state = State(0, 0)
+ return state
+
+ def consume(ctx, x):
+ null = 0
+ non_null = 0
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ null = len(x) - non_null
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ else:
+ null = 1
+ non_null = non_null + ctx.state.non_null
Review Comment:
```suggestion
non_null = non_null + ctx.state.non_null
null = null + ctx.state.null
```
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null)
+
+ def merge(ctx, other_state):
+ merged_state_val = ctx.state.non_null + other_state.non_null
+ return State(merged_state_val)
+
+ def finalize(ctx):
+ return pa.array([ctx.state.non_null])
+
+ func_name = "simple_count"
+ unary_doc = {"summary": "count function",
+ "description": "test agg count function"}
+
+ pc.register_scalar_aggregate_function(init,
+ consume,
+ merge,
+ finalize,
+ func_name,
+ unary_doc,
+ {"array": pa.int64()},
+ pa.int64())
+
+ assert pc.call_function(func_name, [pa.array(
+ [10, 20, None, 30, None, 40])]) == pa.array([4])
+
+
+def test_aggregate_udf_with_custom_state_multi_attr():
+ class State:
+ def __init__(self, non_null=0, null=0):
+ self._non_null = non_null
+ self._null = null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ @property
+ def null(self):
+ return self._null
+
+ @null.setter
+ def null(self, value):
+ self._null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "non_null: " + str(self.non_null) \
+ + ", null: " + str(self.null)
+
+ def init():
+ state = State(0, 0)
+ return state
+
+ def consume(ctx, x):
+ null = 0
+ non_null = 0
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ null = len(x) - non_null
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ else:
+ null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null, null)
+
+ def merge(ctx, other_state):
+ merged_st_non_null = ctx.state.non_null + other_state.non_null
+ merged_st_null = ctx.state.null + other_state.null
+ return State(merged_st_non_null, merged_st_null)
+
+ def finalize(ctx):
+ print(ctx.state)
+ return pa.array([ctx.state.non_null, ctx.state.null])
Review Comment:
An aggregate UDF must always return an array of length 1. Do you perhaps
want something like...
```
return pa.array([{"non_null": ctx.state.non_null, "null": ctx.state.null}])
```
##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -120,6 +128,218 @@ Status RegisterScalarFunction(PyObject* user_function,
ScalarUdfWrapperCallback
return Status::OK();
}
+// Scalar Aggregate Functions
+
+struct ScalarUdfAggregator : public compute::KernelState {
+ virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan&
batch) = 0;
+ virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&&
src) = 0;
+ virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
+};
+
+arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const
compute::ExecSpan& batch) {
+ return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status AggregateUdfMerge(compute::KernelContext* ctx,
compute::KernelState&& src,
+ compute::KernelState* dst) {
+ return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx,
std::move(src));
+}
+
+arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum*
out) {
+ return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
+}
+
+ScalarAggregateUdfContext::~ScalarAggregateUdfContext() {
+ if (_Py_IsFinalizing()) {
+ Py_DECREF(this->state);
+ }
+}
+
+struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator {
+
+ ScalarAggregateInitUdfWrapperCallback init_cb;
+ ScalarAggregateConsumeUdfWrapperCallback consume_cb;
+ ScalarAggregateMergeUdfWrapperCallback merge_cb;
+ ScalarAggregateFinalizeUdfWrapperCallback finalize_cb;
+ std::shared_ptr<OwnedRefNoGIL> init_function;
+ std::shared_ptr<OwnedRefNoGIL> consume_function;
+ std::shared_ptr<OwnedRefNoGIL> merge_function;
+ std::shared_ptr<OwnedRefNoGIL> finalize_function;
+ std::shared_ptr<DataType> output_type;
+
+
+ PythonScalarUdfAggregatorImpl(ScalarAggregateInitUdfWrapperCallback init_cb,
+ ScalarAggregateConsumeUdfWrapperCallback consume_cb,
+ ScalarAggregateMergeUdfWrapperCallback merge_cb,
+ ScalarAggregateFinalizeUdfWrapperCallback finalize_cb,
+ std::shared_ptr<OwnedRefNoGIL> init_function,
+ std::shared_ptr<OwnedRefNoGIL> consume_function,
+ std::shared_ptr<OwnedRefNoGIL> merge_function,
+ std::shared_ptr<OwnedRefNoGIL> finalize_function,
+ const std::shared_ptr<DataType>& output_type) : init_cb(init_cb),
+ consume_cb(consume_cb),
+ merge_cb(merge_cb),
+ finalize_cb(finalize_cb),
+ init_function(init_function),
+ consume_function(consume_function),
+ merge_function(merge_function),
+ finalize_function(finalize_function),
+ output_type(output_type) {
+ Init(init_cb, init_function);
+ }
+
+ ~PythonScalarUdfAggregatorImpl() {
+ if (_Py_IsFinalizing()) {
+ init_function->detach();
+ consume_function->detach();
+ merge_function->detach();
+ finalize_function->detach();
+ }
+ }
+
+ void Init(ScalarAggregateInitUdfWrapperCallback& init_cb ,
std::shared_ptr<OwnedRefNoGIL>& init_function) {
+ auto st = SafeCallIntoPython([&]() -> Status {
+ OwnedRef result(init_cb(init_function->obj()));
+ PyObject* init_res = result.obj();
+ Py_INCREF(init_res);
+ this->udf_context_ = ScalarAggregateUdfContext{default_memory_pool(),
std::move(init_res)};
+ this->owned_state_.reset(result.obj());
+ RETURN_NOT_OK(CheckPyError());
+ return Status::OK();
+ });
+ if (!st.ok()) {
+ throw std::runtime_error(st.ToString());
Review Comment:
Is throwing an error the right thing to do?
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
+ and returns the state object.
+ consume_func : callable
+ A callable implementing the user-defined consume function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined.
+ To define a varargs function, pass a callable that takes
+ varargs. The last in_type will be the type of all varargs
+ arguments.
Review Comment:
How do varargs work? Do I define a `*args`?
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
Review Comment:
Can a unary aggregate ever be called with a scalar?
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
+ and returns the state object.
+ consume_func : callable
+ A callable implementing the user-defined consume function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined.
+ To define a varargs function, pass a callable that takes
+ varargs. The last in_type will be the type of all varargs
+ arguments.
+
+ This function returns the updated state after consuming the
+ received data.
+ merge_func: callable
Review Comment:
The concept of a "merge function" is not going to be obvious to most users.
It is very possible for an engine to be defined that does not have to worry
about the concept of a merge. I think we need to describe some background here
for why a merge is needed in the first place. Something like:
Aggregates may be calculated across many threads in parallel. Each thread
will call the init function once to generate a state for that thread. Once all
values have been consumed, the states from each thread will be merged
together to get the final result state. The merge function should take two
states and combine them.
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
+ and returns the state object.
+ consume_func : callable
+ A callable implementing the user-defined consume function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined.
+ To define a varargs function, pass a callable that takes
+ varargs. The last in_type will be the type of all varargs
+ arguments.
+
+ This function returns the updated state after consuming the
+ received data.
+ merge_func: callable
+ A callable implementing the user-defined merge function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, the second argument it takes is an state object.
+ This object holds the state with which the current state
+ must be merged with. The current state can be retrieved from
+ the context object which can be acessed by `context.state`.
+ The state doesn't need to be set in the Python side and it is
+ autonomously handled in the C++ backend. The updated state must
+ be returned from this function.
+ finalize_func: callable
+ A callable implementing the user-defined finalize function.
+ The first argument is the context argument of type
+ ScalarUdfContext.
+ Using the context argument the state can be extracted and return
Review Comment:
"the state can be extracted" is not very obvious. Maybe something like:
The purpose of the finalize function is to transform the state (which is
available in the context argument) into an array. This array will be the final
result of the aggregation.
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null)
+
+ def merge(ctx, other_state):
+ merged_state_val = ctx.state.non_null + other_state.non_null
+ return State(merged_state_val)
+
+ def finalize(ctx):
+ return pa.array([ctx.state.non_null])
+
+ func_name = "simple_count"
Review Comment:
Maybe `valid_count` or `non_null_count`?
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
Review Comment:
```suggestion
A scalar aggregate function is a set of 4 functions which are
called at different times during the calculation of the aggregate.
```
I think we are missing a basic description first. Something like:
```
An aggregate function reduces a column of values into a single aggregate
result.
```
##########
cpp/examples/arrow/udf_example.cc:
##########
@@ -83,15 +83,154 @@ arrow::Status Execute() {
ARROW_ASSIGN_OR_RAISE(auto res, cp::CallFunction(name, {x, y, z}));
auto res_array = res.make_array();
- std::cout << "Result" << std::endl;
+ std::cout << "Scalar UDF Result" << std::endl;
std::cout << res_array->ToString() << std::endl;
return arrow::Status::OK();
}
+// User-defined Scalar Aggregate Function Example
+struct ScalarUdfAggregator : public cp::KernelState {
+ virtual arrow::Status Consume(cp::KernelContext* ctx, const cp::ExecSpan&
batch) = 0;
+ virtual arrow::Status MergeFrom(cp::KernelContext* ctx, cp::KernelState&&
src) = 0;
+ virtual arrow::Status Finalize(cp::KernelContext* ctx, arrow::Datum* out) =
0;
+};
+
+class SimpleCountFunctionOptionsType : public cp::FunctionOptionsType {
+ const char* type_name() const override { return
"SimpleCountFunctionOptionsType"; }
+ std::string Stringify(const cp::FunctionOptions&) const override {
+ return "SimpleCountFunctionOptionsType";
+ }
+ bool Compare(const cp::FunctionOptions&, const cp::FunctionOptions&) const
override {
+ return true;
+ }
+ std::unique_ptr<cp::FunctionOptions> Copy(const cp::FunctionOptions&) const
override;
+};
+
+cp::FunctionOptionsType* GetSimpleCountFunctionOptionsType() {
+ static SimpleCountFunctionOptionsType options_type;
+ return &options_type;
+}
+
+class SimpleCountOptions : public cp::FunctionOptions {
+ public:
+ SimpleCountOptions() :
cp::FunctionOptions(GetSimpleCountFunctionOptionsType()) {}
+ static constexpr char const kTypeName[] = "SimpleCountOptions";
+ static SimpleCountOptions Defaults() { return SimpleCountOptions{}; }
+};
+
+std::unique_ptr<cp::FunctionOptions> SimpleCountFunctionOptionsType::Copy(
+ const cp::FunctionOptions&) const {
+ return std::make_unique<SimpleCountOptions>();
+}
+
+const cp::FunctionDoc simple_count_doc{
+ "SimpleCount the number of null / non-null values",
+ ("By default, only non-null values are counted.\n"
+ "This can be changed through SimpleCountOptions."),
Review Comment:
Well, this isn't true quite yet :)
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null)
+
+ def merge(ctx, other_state):
+ merged_state_val = ctx.state.non_null + other_state.non_null
+ return State(merged_state_val)
+
+ def finalize(ctx):
+ return pa.array([ctx.state.non_null])
+
+ func_name = "simple_count"
+ unary_doc = {"summary": "count function",
+ "description": "test agg count function"}
+
+ pc.register_scalar_aggregate_function(init,
+ consume,
+ merge,
+ finalize,
+ func_name,
+ unary_doc,
+ {"array": pa.int64()},
+ pa.int64())
+
+ assert pc.call_function(func_name, [pa.array(
+ [10, 20, None, 30, None, 40])]) == pa.array([4])
+
+
+def test_aggregate_udf_with_custom_state_multi_attr():
+ class State:
+ def __init__(self, non_null=0, null=0):
+ self._non_null = non_null
+ self._null = null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ @property
+ def null(self):
+ return self._null
+
+ @null.setter
+ def null(self, value):
+ self._null = value
Review Comment:
Same comment on getters and setters
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
Review Comment:
Is `callable` the right word for Python? I think the Python docs refer to
this kind of thing as a `function`:
https://docs.python.org/3/library/functions.html#map
##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def test_aggregate_udf_with_custom_state():
+ class State:
+ def __init__(self, non_null=0):
+ self._non_null = non_null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "count: " + str(self.non_null)
+
+ def init():
+ state = State(0)
+ return state
+
+ def consume(ctx, x):
+ if isinstance(x, pa.Array):
+ non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py()
+ elif isinstance(x, pa.Scalar):
+ if x.as_py():
+ non_null = 1
+ non_null = non_null + ctx.state.non_null
+ return State(non_null)
+
+ def merge(ctx, other_state):
+ merged_state_val = ctx.state.non_null + other_state.non_null
+ return State(merged_state_val)
+
+ def finalize(ctx):
+ return pa.array([ctx.state.non_null])
+
+ func_name = "simple_count"
+ unary_doc = {"summary": "count function",
+ "description": "test agg count function"}
+
+ pc.register_scalar_aggregate_function(init,
+ consume,
+ merge,
+ finalize,
+ func_name,
+ unary_doc,
+ {"array": pa.int64()},
+ pa.int64())
+
+ assert pc.call_function(func_name, [pa.array(
+ [10, 20, None, 30, None, 40])]) == pa.array([4])
+
+
+def test_aggregate_udf_with_custom_state_multi_attr():
+ class State:
+ def __init__(self, non_null=0, null=0):
+ self._non_null = non_null
+ self._null = null
+
+ @property
+ def non_null(self):
+ return self._non_null
+
+ @non_null.setter
+ def non_null(self, value):
+ self._non_null = value
+
+ @property
+ def null(self):
+ return self._null
+
+ @null.setter
+ def null(self, value):
+ self._null = value
+
+ def __repr__(self):
+ if self._non_null is None:
+ return "no values stored"
+ else:
+ return "non_null: " + str(self.non_null) \
+ + ", null: " + str(self.null)
+
+ def init():
+ state = State(0, 0)
+ return state
Review Comment:
```suggestion
return State(0, 0)
```
##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -120,6 +128,218 @@ Status RegisterScalarFunction(PyObject* user_function,
ScalarUdfWrapperCallback
return Status::OK();
}
+// Scalar Aggregate Functions
+
+struct ScalarUdfAggregator : public compute::KernelState {
+ virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan&
batch) = 0;
+ virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&&
src) = 0;
+ virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
+};
+
+arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const
compute::ExecSpan& batch) {
+ return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status AggregateUdfMerge(compute::KernelContext* ctx,
compute::KernelState&& src,
+ compute::KernelState* dst) {
+ return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx,
std::move(src));
+}
+
+arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum*
out) {
+ return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
+}
+
+ScalarAggregateUdfContext::~ScalarAggregateUdfContext() {
+ if (_Py_IsFinalizing()) {
+ Py_DECREF(this->state);
+ }
+}
+
+struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator {
+
+ ScalarAggregateInitUdfWrapperCallback init_cb;
+ ScalarAggregateConsumeUdfWrapperCallback consume_cb;
+ ScalarAggregateMergeUdfWrapperCallback merge_cb;
+ ScalarAggregateFinalizeUdfWrapperCallback finalize_cb;
+ std::shared_ptr<OwnedRefNoGIL> init_function;
+ std::shared_ptr<OwnedRefNoGIL> consume_function;
+ std::shared_ptr<OwnedRefNoGIL> merge_function;
+ std::shared_ptr<OwnedRefNoGIL> finalize_function;
+ std::shared_ptr<DataType> output_type;
+
+
+ PythonScalarUdfAggregatorImpl(ScalarAggregateInitUdfWrapperCallback init_cb,
+ ScalarAggregateConsumeUdfWrapperCallback consume_cb,
+ ScalarAggregateMergeUdfWrapperCallback merge_cb,
+ ScalarAggregateFinalizeUdfWrapperCallback finalize_cb,
+ std::shared_ptr<OwnedRefNoGIL> init_function,
+ std::shared_ptr<OwnedRefNoGIL> consume_function,
+ std::shared_ptr<OwnedRefNoGIL> merge_function,
+ std::shared_ptr<OwnedRefNoGIL> finalize_function,
+ const std::shared_ptr<DataType>& output_type) : init_cb(init_cb),
+ consume_cb(consume_cb),
+ merge_cb(merge_cb),
+ finalize_cb(finalize_cb),
+ init_function(init_function),
+ consume_function(consume_function),
+ merge_function(merge_function),
+ finalize_function(finalize_function),
+ output_type(output_type) {
+ Init(init_cb, init_function);
+ }
+
+ ~PythonScalarUdfAggregatorImpl() {
+ if (_Py_IsFinalizing()) {
+ init_function->detach();
+ consume_function->detach();
+ merge_function->detach();
+ finalize_function->detach();
+ }
+ }
+
+ void Init(ScalarAggregateInitUdfWrapperCallback& init_cb ,
std::shared_ptr<OwnedRefNoGIL>& init_function) {
+ auto st = SafeCallIntoPython([&]() -> Status {
+ OwnedRef result(init_cb(init_function->obj()));
+ PyObject* init_res = result.obj();
Review Comment:
We probably need to do sanity checking (e.g. that the result is not null) on
things returned from user callbacks.
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
Review Comment:
I'm not sure that "set the state" is correct. Maybe "create the initial
state"?
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
Review Comment:
Why "scalar aggregate"? Is there a non-scalar aggregate?
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
+ and returns the state object.
+ consume_func : callable
+ A callable implementing the user-defined consume function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined.
+ To define a varargs function, pass a callable that takes
+ varargs. The last in_type will be the type of all varargs
+ arguments.
+
+ This function returns the updated state after consuming the
+ received data.
+ merge_func: callable
+ A callable implementing the user-defined merge function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, the second argument it takes is an state object.
+ This object holds the state with which the current state
+ must be merged with. The current state can be retrieved from
+ the context object which can be acessed by `context.state`.
+ The state doesn't need to be set in the Python side and it is
+ autonomously handled in the C++ backend. The updated state must
Review Comment:
I'm not sure I understand the sentence that starts with "The state doesn't
need to be set..."
##########
python/pyarrow/_compute.pyx:
##########
@@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name,
function_doc, in_types,
check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]>
&_scalar_udf_callback, c_options))
+
+
+def register_scalar_aggregate_function(init_func, consume_func, merge_func,
finalize_func,
+ function_name, function_doc, in_types,
out_type):
+ """
+ Register a user-defined scalar aggregate function.
+
+ A scalar aggregate function is a set of 4 functions which formulates
+ the operation pieces of an scalar aggregation. The base behavior in
+ terms of computation is very much similar to scalar functions.
+
+ Parameters
+ ----------
+ init_func : callable
+ A callable implementing the user-defined initialization function.
+ This function is used to set the state for the aggregate operation
+ and returns the state object.
+ consume_func : callable
+ A callable implementing the user-defined consume function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined.
+ To define a varargs function, pass a callable that takes
+ varargs. The last in_type will be the type of all varargs
+ arguments.
+
+ This function returns the updated state after consuming the
+ received data.
+ merge_func: callable
+ A callable implementing the user-defined merge function.
+ The first argument is the context argument of type
+ ScalarAggregateUdfContext.
+ Then, the second argument it takes is an state object.
+ This object holds the state with which the current state
+ must be merged with. The current state can be retrieved from
+ the context object which can be acessed by `context.state`.
+ The state doesn't need to be set in the Python side and it is
+ autonomously handled in the C++ backend. The updated state must
+ be returned from this function.
+ finalize_func: callable
+ A callable implementing the user-defined finalize function.
+ The first argument is the context argument of type
+ ScalarUdfContext.
+ Using the context argument the state can be extracted and return
+ type must be an array matching the `out_type`.
Review Comment:
In your C++ example the return type is a scalar?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]