bkmgit commented on a change in pull request #11882: URL: https://github.com/apache/arrow/pull/11882#discussion_r771991046
########## File path: cpp/src/arrow/compute/kernels/codegen_internal.h ########## @@ -982,6 +1021,407 @@ template <typename OutType, typename ArgType, typename Op> using ScalarBinaryNotNullStatefulEqualTypes = ScalarBinaryNotNullStateful<OutType, ArgType, ArgType, Op>; +// A kernel exec generator for ternary functions that addresses both array and +// scalar inputs and dispatches input iteration and output writing to other +// templates +// +// This template executes the operator even on the data behind null values, +// therefore it is generally only suitable for operators that are safe to apply +// even on the null slot values. +// +// The "Op" functor should have the form +// +// struct Op { +// template <typename OutValue, typename Arg0Value, typename Arg1Value, typename +// Arg2Value> static OutValue Call(KernelContext* ctx, Arg0Value arg0, Arg1Value arg1, +// Arg2Value arg2, Status *st) { +// // implementation +// // NOTE: "status" should only be populated with errors, +// // leave it unmodified to indicate Status::OK() +// } +// }; +template <typename OutType, typename Arg0Type, typename Arg1Type, typename Arg2Type, + typename Op> +struct ScalarTernary { + using OutValue = typename GetOutputType<OutType>::T; + using Arg0Value = typename GetViewType<Arg0Type>::T; + using Arg1Value = typename GetViewType<Arg1Type>::T; + using Arg2Value = typename GetViewType<Arg2Type>::T; + + static Status ArrayArrayArray(KernelContext* ctx, const ArrayData& arg0, + const ArrayData& arg1, const ArrayData& arg2, + Datum* out) { + Status st = Status::OK(); + ArrayIterator<Arg0Type> arg0_it(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); + ArrayIterator<Arg2Type> arg2_it(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_it(), arg1_it(), arg2_it(), &st); + })); + return st; + } + + static Status ArrayArrayScalar(KernelContext* ctx, const ArrayData& arg0, + const 
ArrayData& arg1, const Scalar& arg2, Datum* out) { + Status st = Status::OK(); + ArrayIterator<Arg0Type> arg0_it(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); + auto arg2_val = UnboxScalar<Arg2Type>::Unbox(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_it(), arg1_it(), arg2_val, &st); + })); + return st; + } + + static Status ArrayScalarArray(KernelContext* ctx, const ArrayData& arg0, + const Scalar& arg1, const ArrayData& arg2, Datum* out) { + Status st = Status::OK(); + ArrayIterator<Arg0Type> arg0_it(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + ArrayIterator<Arg2Type> arg2_it(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_it(), arg1_val, arg2_it(), &st); + })); + return st; + } + + static Status ScalarArrayArray(KernelContext* ctx, const Scalar& arg0, + const ArrayData& arg1, const ArrayData& arg2, + Datum* out) { + Status st = Status::OK(); + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); + ArrayIterator<Arg2Type> arg2_it(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_val, arg1_it(), arg2_it(), &st); + })); + return st; + } + + static Status ArrayScalarScalar(KernelContext* ctx, const ArrayData& arg0, + const Scalar& arg1, const Scalar& arg2, Datum* out) { + Status st = Status::OK(); + ArrayIterator<Arg0Type> arg0_it(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + auto arg2_val = UnboxScalar<Arg2Type>::Unbox(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_it(), arg1_val, arg2_val, &st); + })); + return st; + } + + 
static Status ScalarScalarArray(KernelContext* ctx, const Scalar& arg0, + const Scalar& arg1, const ArrayData& arg2, Datum* out) { + Status st = Status::OK(); + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + ArrayIterator<Arg2Type> arg2_it(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_val, arg1_val, arg2_it(), &st); + })); + return st; + } + + static Status ScalarArrayScalar(KernelContext* ctx, const Scalar& arg0, + const ArrayData& arg1, const Scalar& arg2, Datum* out) { + Status st = Status::OK(); + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + ArrayIterator<Arg1Type> arg1_it(arg1); + auto arg2_val = UnboxScalar<Arg2Type>::Unbox(arg2); + RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue { + return Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_val, arg1_it(), arg2_val, &st); + })); + return st; + } + + static Status ScalarScalarScalar(KernelContext* ctx, const Scalar& arg0, + const Scalar& arg1, const Scalar& arg2, Datum* out) { + Status st = Status::OK(); + if (out->scalar()->is_valid) { + auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0); + auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1); + auto arg2_val = UnboxScalar<Arg2Type>::Unbox(arg2); + BoxScalar<OutType>::Box( + Op::template Call<OutValue, Arg0Value, Arg1Value, Arg2Value>( + ctx, arg0_val, arg1_val, arg2_val, &st), + out->scalar().get()); + } + return st; + } + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].kind() == Datum::ARRAY) { + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { + return ArrayArrayArray(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].array(), out); + } else { + return ArrayArrayScalar(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].scalar(), out); + } + } 
else { + if (batch[2].kind() == Datum::ARRAY) { + return ArrayScalarArray(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].array(), out); + } else { + return ArrayScalarScalar(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].scalar(), out); + } + } + } else { + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { + return ScalarArrayArray(ctx, *batch[0].scalar(), *batch[1].array(), + *batch[2].array(), out); + } else { + return ScalarArrayScalar(ctx, *batch[0].scalar(), *batch[1].array(), + *batch[2].scalar(), out); + } + } else { + if (batch[2].kind() == Datum::ARRAY) { + return ScalarScalarArray(ctx, *batch[0].scalar(), *batch[1].scalar(), + *batch[2].array(), out); + } else { + return ScalarScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), + *batch[2].scalar(), out); + } + } + } + } +}; + +// An alternative to ScalarTernary that applies a stateful scalar operation only +// to the value triples that are non-null in all three inputs +template <typename OutType, typename Arg0Type, typename Arg1Type, typename Arg2Type, + typename Op> +struct ScalarTernaryNotNullStateful { Review comment: Decimal support added. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org