felipecrv commented on code in PR #15083:
URL: https://github.com/apache/arrow/pull/15083#discussion_r1064883321
##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -108,19 +108,31 @@ Result<TypeHolder> ResolveGroupOutputType(KernelContext* ctx,
   return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
 }
 
-HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+HashAggregateKernel MakeKernel(std::shared_ptr<KernelSignature> signature,
+                               KernelInit init) {
   HashAggregateKernel kernel;
   kernel.init = std::move(init);
-  kernel.signature =
-      KernelSignature::Make({std::move(argument_type), InputType(Type::UINT32)},
-                            OutputType(ResolveGroupOutputType));
+  kernel.signature = std::move(signature);
   kernel.resize = HashAggregateResize;
   kernel.consume = HashAggregateConsume;
   kernel.merge = HashAggregateMerge;
   kernel.finalize = HashAggregateFinalize;
   return kernel;
 }
 
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+  return MakeKernel(
+      KernelSignature::Make({std::move(argument_type), InputType(Type::UINT32)},
+                            OutputType(ResolveGroupOutputType)),
+      std::move(init));
+}
+
+HashAggregateKernel MakeUnaryKernel(KernelInit init) {
+  return MakeKernel(KernelSignature::Make({InputType(Type::UINT32)},
+                                          OutputType(ResolveGroupOutputType)),
+                    std::move(init));
+}
+

Review Comment:
Very confusing, I agree. I decided to refer to the arity in the name (`MakeUnaryKernel`) because `Arity` is mentioned in the context where `Make*Kernel` is called:

```cpp
{
  auto func = std::make_shared<HashAggregateFunction>(
      "hash_count", Arity::Binary(), hash_count_doc, &default_count_options);
  DCHECK_OK(func->AddKernel(
      MakeKernel(InputType::Any(), HashAggregateInit<GroupedCountImpl>)));
  DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
  auto func = std::make_shared<HashAggregateFunction>(
      "hash_count_all", Arity::Unary(), hash_count_all_doc,
      &default_count_all_options);
  DCHECK_OK(func->AddKernel(MakeUnaryKernel(HashAggregateInit<GroupedCountAllImpl>)));
  auto status = registry->AddFunction(std::move(func));
  DCHECK_OK(status);
}
```

As I understood it, there is the arity of the kernels and the arity of the aggregations they implement: the former needs to be arity+1, because the uint32 group id column is passed as an extra argument when the aggregation is combined with group-by.
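Not part of the diff, just a small illustrative sketch of the two arities. It reuses the `MakeKernel`/`MakeUnaryKernel` helpers and the `HashAggregateInit<...>` functors shown above, so it only makes sense inside `hash_aggregate.cc` (e.g. in the kernel registration code), not as standalone code:

```cpp
// Illustrative only; relies on the helpers defined earlier in hash_aggregate.cc.

// One value column plus the uint32 group id column that MakeKernel() appends
// to the signature: matches a function registered with Arity::Binary().
HashAggregateKernel count_kernel =
    MakeKernel(InputType::Any(), HashAggregateInit<GroupedCountImpl>);

// No value column, only the uint32 group id column: matches a function
// registered with Arity::Unary(), hence MakeUnaryKernel().
HashAggregateKernel count_all_kernel =
    MakeUnaryKernel(HashAggregateInit<GroupedCountAllImpl>);
```

In other words, the "+1" argument in every hash aggregate kernel signature is the uint32 group id column.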
##########
cpp/src/arrow/compute/exec/plan_test.cc:
##########
@@ -1298,17 +1305,53 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
           })
           .AddToPlan(plan.get()));
 
-  ASSERT_THAT(
-      StartAndCollect(plan.get(), sink_gen),
-      Finishes(ResultWith(UnorderedElementsAreArray({
-          ExecBatchFromJSON(
-              {boolean(), boolean(), int64(), float64(), int64(), float64(), int64(),
-               float64(), float64()},
-              {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
-               ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY,
-               ArgShape::SCALAR},
-              R"([[false, true, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"),
-      }))));
+  auto exec_batch = ExecBatchFromJSON(
+      {boolean(), boolean(), int64(), int64(), float64(), int64(), float64(), int64(),
+       float64(), float64()},
+      {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
+       ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
+       ArgShape::ARRAY, ArgShape::SCALAR},
+      R"([[false, true, 6, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])");
+
+  ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+              Finishes(ResultWith(UnorderedElementsAreArray({
+                  std::move(exec_batch),
+              }))));
+}
+
+TEST(ExecPlanExecution, ScalarSourceStandaloneNullaryScalarAggSink) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
+
+  BatchesWithSchema scalar_data;
+  scalar_data.batches = {
+      ExecBatchFromJSON({int32(), boolean()}, {ArgShape::SCALAR, ArgShape::SCALAR},
+                        "[[5, null], [5, false], [5, false]]"),
+      ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [null, false], [7, true]]")};
+  scalar_data.schema = schema({
+      field("a", int32()),
+      field("b", boolean()),
+  });
+
+  auto sequence = Declaration::Sequence({
+      {"source", SourceNodeOptions{scalar_data.schema, scalar_data.gen(/*parallel=*/false,
+                                                                       /*slow=*/false)}},
+      {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                        {"count_all", "count(*)"},
+                    }}},
+      {"sink", SinkNodeOptions{&sink_gen}},
+  });
+
+  // index can't be tested as it's order-dependent
+  // mode/quantile can't be tested as they're technically vector kernels

Review Comment:
Fixed.