westonpace commented on code in PR #33775:
URL: https://github.com/apache/arrow/pull/33775#discussion_r1083118332

##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.

Review Comment:
   Which argument? I assume the output is null if either argument is null? Can we be more explicit?

##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -113,7 +112,7 @@ class TestBaseUnaryRoundArithmetic : public ::testing::Test {
   // (Array, Array)
   void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
                      const std::shared_ptr<Array>& expected) {
-    ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+    ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr))

Review Comment:
   While this does compile and work, we try to add a `;` to the end of these kinds of macros anyway for readability.

   ```suggestion
       ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
   ```

##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessUncheckedArithmetic(
   };
 }
 
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+    const std::string& function_name) {
+  return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
+    ARROW_ASSIGN_OR_RAISE(
+        compute::RoundMode round_mode,
+        ParseOptionOrElse(
+            call, "rounding", kRoundModeParser,
+            {compute::RoundMode::DOWN, compute::RoundMode::UP,
+             compute::RoundMode::TOWARDS_ZERO, compute::RoundMode::TOWARDS_INFINITY,
+             compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+             compute::RoundMode::HALF_TOWARDS_ZERO,
+             compute::RoundMode::HALF_TOWARDS_INFINITY, compute::RoundMode::HALF_TO_EVEN,
+             compute::RoundMode::HALF_TO_ODD},
+            compute::RoundMode::HALF_TOWARDS_INFINITY));
+
+    ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+                          GetValueArgs(call, 0));
+    std::shared_ptr<compute::RoundBinaryOptions> options =

Review Comment:
   Do we want to optimize and call the unary round if the second value is a scalar? If not in this PR, can we create a follow-up GitHub issue so we don't lose track of it? Or maybe round_binary itself can fall back to unary rounding if the second argument is a scalar.

##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded

Review Comment:
   ```suggestion
   /// \param[in] arg1 the value to be rounded
   ```

##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to
+/// \param[in] options rounding options (rounding mode and number of digits), optional

Review Comment:
   ```suggestion
   /// \param[in] options rounding options (rounding mode), optional
   ```
   Or just get rid of the parenthetical entirely.

##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to

Review Comment:
   Can this be negative? Do we define elsewhere what that entails?

##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -18,7 +18,6 @@
 #include <algorithm>
 #include <cmath>
 #include <memory>
-#include <string>

Review Comment:
   Our guideline for includes is [`iwyu`](https://include-what-you-use.org/). We don't always follow it perfectly (the conformance tool doesn't like type_fwd files), but it is what we aim for. Please don't remove includes that are used in the file (I still see many instances of `std::string`), even if the file compiles without them; transitive includes are potentially unstable.

##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -751,60 +877,25 @@ ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
   }
 }
 
-struct ArithmeticFunction : ScalarFunction {
+struct RoundFunction : ScalarFunction {
   using ScalarFunction::ScalarFunction;
 
   Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
     RETURN_NOT_OK(CheckArity(types->size()));
 
-    RETURN_NOT_OK(CheckDecimals(types));
-
     using arrow::compute::detail::DispatchExactImpl;
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
 
     EnsureDictionaryDecoded(types);
 
-    // Only promote types for binary functions
-    if (types->size() == 2) {
-      ReplaceNullWithOtherType(types);
-      TimeUnit::type finest_unit;
-      if (CommonTemporalResolution(types->data(), types->size(), &finest_unit)) {
-        ReplaceTemporalTypes(finest_unit, types);
-      } else {
-        if (TypeHolder type = CommonNumeric(*types)) {
-          ReplaceTypes(type, types);
-        }
-      }
-    }
-
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
     return arrow::compute::detail::NoMatchingKernel(this, *types);
   }
-
-  Status CheckDecimals(std::vector<TypeHolder>* types) const {
-    if (!HasDecimal(*types)) return Status::OK();
-
-    if (types->size() == 2) {
-      // "add_checked" -> "add"
-      const auto func_name = name();
-      const std::string op = func_name.substr(0, func_name.find("_"));
-      if (op == "add" || op == "subtract") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kAdd, types);
-      } else if (op == "multiply") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, types);
-      } else if (op == "divide") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kDivide, types);
-      } else {
-        return Status::Invalid("Invalid decimal function: ", func_name);
-      }
-    }
-    return Status::OK();
-  }
 };
 
-/// An ArithmeticFunction that promotes only decimal arguments to double.
-struct ArithmeticDecimalToFloatingPointFunction : public ArithmeticFunction {
-  using ArithmeticFunction::ArithmeticFunction;
+/// An RoundFunction that promotes only decimal arguments to double.

Review Comment:
   ```suggestion
   /// A RoundFunction that promotes only decimal arguments to double.
   ```

##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -452,6 +468,127 @@ struct Round<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
   }
 };
 
+template <typename ArrowType, RoundMode RndMode, typename Enable = void>
+struct RoundBinary {
+  using CType = typename TypeTraits<ArrowType>::CType;
+  using State = RoundOptionsWrapper<RoundBinaryOptions>;
+
+  explicit RoundBinary(const State& state, const DataType& out_ty) {}
+
+  template <typename T = ArrowType, typename CType0 = typename TypeTraits<T>::CType0,
+            typename CType1 = typename TypeTraits<T>::CType1>
+  enable_if_floating_value<CType> Call(KernelContext* ctx, CType0 arg0, CType1 arg1,
+                                       Status* st) const {
+    // Do not process Inf or NaN because they will trigger the overflow error at end of
+    // function.

Review Comment:
   Do you have any tests with infinite or NaN values?

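
For reference, here is a minimal standalone sketch of the kind of check meant here. `RoundHalfAwayFromZero` and `CheckNonFiniteInputs` are hypothetical names and not part of Arrow, and using `std::round` (ties away from zero) is an assumption standing in for the kernel's `HALF_TOWARDS_INFINITY` mode. The only point it tries to make is that Inf and NaN should pass through unchanged instead of reaching the overflow path.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>

// Hypothetical helper, not Arrow's kernel: rounds `value` to `ndigits` decimal
// places with ties going away from zero (the behaviour of std::round).
inline double RoundHalfAwayFromZero(double value, int32_t ndigits) {
  // Inf and NaN are returned as-is so they never reach the scaling step.
  if (!std::isfinite(value)) return value;

  const double pow10 = std::pow(10.0, std::abs(ndigits));
  const double scaled = ndigits >= 0 ? value * pow10 : value / pow10;
  // Assumption for this sketch: if scaling overflows, return the input
  // unchanged rather than reporting an error.
  if (!std::isfinite(scaled)) return value;

  const double rounded = std::round(scaled);
  return ndigits >= 0 ? rounded / pow10 : rounded * pow10;
}

// Checks of the kind the comment asks about: non-finite inputs round-trip.
inline void CheckNonFiniteInputs() {
  const double inf = std::numeric_limits<double>::infinity();
  const double nan = std::numeric_limits<double>::quiet_NaN();
  assert(RoundHalfAwayFromZero(inf, 0) == inf);
  assert(RoundHalfAwayFromZero(-inf, 2) == -inf);
  assert(std::isnan(RoundHalfAwayFromZero(nan, -1)));
}
```

In the real test suite the equivalent would presumably be an `ArrayFromJSON` input containing `Inf`, `-Inf` and `NaN` with `SetNansEqual(true)`, which is an assumption about how the fixture is used rather than something this sketch verifies.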
##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -965,6 +1086,97 @@ TYPED_TEST(TestUnaryRoundFloating, Round) {
   }
 }
 
+TYPED_TEST_SUITE(TestBinaryRoundIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestBinaryRoundSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundFloating, FloatingTypes);
+
+TYPED_TEST(TestBinaryRoundSigned, Round) {
+  // Test different rounding modes for integer rounding
+  std::string values("[0, 1, -13, -50, 115]");
+  for (const auto& round_mode : kRoundModes) {
+    this->SetRoundMode(round_mode);
+    this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(), values));
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[0.0, 0.0, -0.0, -100, 100]"},
+      {-1, "[0.0, 0.0, -10, -50, 120]"},
+      {0, values},
+      {1, values},
+      {2, values},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first,
+                         ArrayFromJSON(float64(), pair.second));
+  }
+}
+
+TYPED_TEST(TestBinaryRoundUnsigned, Round) {
+  // Test different rounding modes for integer rounding
+  std::string values("[0, 1, 13, 50, 115]");
+  for (const auto& round_mode : kRoundModes) {
+    this->SetRoundMode(round_mode);
+    this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(), values));
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[0, 0, 0, 100, 100]"},
+      {-1, "[0, 0, 10, 50, 120]"},
+      {0, values},
+      {1, values},
+      {2, values},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first,
+                         ArrayFromJSON(float64(), pair.second));
+  }
+}
+
+TYPED_TEST(TestBinaryRoundFloating, Round) {
+  this->SetNansEqual(true);
+
+  // Test different rounding modes
+  std::string values("[3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7]");
+  std::vector<std::pair<RoundMode, std::string>> rmode_and_expected{{
+      {RoundMode::DOWN, "[3, 3, 3, 4, -4, -4, -4]"},
+      {RoundMode::UP, "[4, 4, 4, 5, -3, -3, -3]"},
+      {RoundMode::TOWARDS_ZERO, "[3, 3, 3, 4, -3, -3, -3]"},
+      {RoundMode::TOWARDS_INFINITY, "[4, 4, 4, 5, -4, -4, -4]"},
+      {RoundMode::HALF_DOWN, "[3, 3, 4, 4, -3, -4, -4]"},
+      {RoundMode::HALF_UP, "[3, 4, 4, 5, -3, -3, -4]"},
+      {RoundMode::HALF_TOWARDS_ZERO, "[3, 3, 4, 4, -3, -3, -4]"},
+      {RoundMode::HALF_TOWARDS_INFINITY, "[3, 4, 4, 5, -3, -4, -4]"},
+      {RoundMode::HALF_TO_EVEN, "[3, 4, 4, 4, -3, -4, -4]"},
+      {RoundMode::HALF_TO_ODD, "[3, 3, 4, 5, -3, -3, -4]"},
+  }};
+  for (const auto& pair : rmode_and_expected) {
+    this->SetRoundMode(pair.first);
+    this->AssertBinaryOp(RoundBinary, "[]", "[]", "[]");
+    this->AssertBinaryOp(RoundBinary, "[null, 0, Inf, -Inf, NaN, -NaN]",
+                         "[0, 0, 0, 0, 0, 0]", "[null, 0, Inf, -Inf, NaN, -NaN]");
+    this->AssertBinaryOp(RoundBinary, values, 0, pair.second);
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  values = "[320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045]";
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[300, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0]"},
+      {-1, "[320, 0.0, 0.0, 0.0, -0.0, -40, -0.0]"},
+      {0, "[320, 4, 3, 5, -3, -35, -3]"},
+      {1, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3]"},
+      {2, "[320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05]"},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first, pair.second);
+  }
+}

Review Comment:
   Can you add some unit tests that consider nulls in both the values and the num_digits arguments? Also maybe a few tests with scalars (especially using a scalar for num_digits and an array for values, which should be equivalent to unary rounding).

##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessUncheckedArithmetic(
   };
 }
 
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+    const std::string& function_name) {
+  return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
+    ARROW_ASSIGN_OR_RAISE(
+        compute::RoundMode round_mode,
+        ParseOptionOrElse(
+            call, "rounding", kRoundModeParser,
+            {compute::RoundMode::DOWN, compute::RoundMode::UP,
+             compute::RoundMode::TOWARDS_ZERO, compute::RoundMode::TOWARDS_INFINITY,
+             compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+             compute::RoundMode::HALF_TOWARDS_ZERO,
+             compute::RoundMode::HALF_TOWARDS_INFINITY, compute::RoundMode::HALF_TO_EVEN,
+             compute::RoundMode::HALF_TO_ODD},
+            compute::RoundMode::HALF_TOWARDS_INFINITY));

Review Comment:
   It appears you are defaulting to `HALF_TOWARDS_INFINITY`, but shouldn't the default be `HALF_TO_EVEN`?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org