westonpace commented on code in PR #33775:
URL: https://github.com/apache/arrow/pull/33775#discussion_r1083118332


##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = 
RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.

Review Comment:
   Which argument? I assume the output is null if either argument is null? 
Can we be more explicit about that?



##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -113,7 +112,7 @@ class TestBaseUnaryRoundArithmetic : public ::testing::Test 
{
   // (Array, Array)
   void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
                      const std::shared_ptr<Array>& expected) {
-    ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+    ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr))

Review Comment:
   While this does compile and work, we try to add a `;` to the end of these 
kinds of macros anyway, for readability.
   ```suggestion
       ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
   ```



##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow 
DecodeOptionlessUncheckedArithmetic(
   };
 }
 
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+    const std::string& function_name) {
+  return [function_name](const SubstraitCall& call) -> 
Result<compute::Expression> {
+    ARROW_ASSIGN_OR_RAISE(
+        compute::RoundMode round_mode,
+        ParseOptionOrElse(
+            call, "rounding", kRoundModeParser,
+            {compute::RoundMode::DOWN, compute::RoundMode::UP,
+             compute::RoundMode::TOWARDS_ZERO, 
compute::RoundMode::TOWARDS_INFINITY,
+             compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+             compute::RoundMode::HALF_TOWARDS_ZERO,
+             compute::RoundMode::HALF_TOWARDS_INFINITY, 
compute::RoundMode::HALF_TO_EVEN,
+             compute::RoundMode::HALF_TO_ODD},
+            compute::RoundMode::HALF_TOWARDS_INFINITY));
+    ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+                          GetValueArgs(call, 0));
+    std::shared_ptr<compute::RoundBinaryOptions> options =

Review Comment:
   Do we want to optimize and call the unary round if the second value is a 
scalar? If not in this PR, can we create a follow-up GitHub issue so we don't 
lose track of it? Or maybe round_binary itself can fall back to unary rounding 
if the second argument is a scalar.



##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = 
RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded

Review Comment:
   ```suggestion
   /// \param[in] arg1 the value to be rounded
   ```



##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = 
RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to
+/// \param[in] options rounding options (rounding mode and number of digits), 
optional

Review Comment:
   ```suggestion
   /// \param[in] options rounding options (rounding mode), optional
   ```
   Or just remove the parenthetical entirely.



##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
 Result<Datum> Round(const Datum& arg, RoundOptions options = 
RoundOptions::Defaults(),
                     ExecContext* ctx = NULLPTR);
 
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to

Review Comment:
   Can this be negative?  Do we define elsewhere what that entails?



##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -18,7 +18,6 @@
 #include <algorithm>
 #include <cmath>
 #include <memory>
-#include <string>

Review Comment:
   Our guideline for includes is [`iwyu`](https://include-what-you-use.org/). 
We don't always follow it perfectly (the conformance tool doesn't like type_fwd 
files), but it is what we aim for. Please don't remove includes if they are 
used in the file (I still see many instances of `std::string`), even if the file 
compiles otherwise (transitive includes are potentially unstable).



##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -751,60 +877,25 @@ ArrayKernelExec 
GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
   }
 }
 
-struct ArithmeticFunction : ScalarFunction {
+struct RoundFunction : ScalarFunction {
   using ScalarFunction::ScalarFunction;
 
   Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const 
override {
     RETURN_NOT_OK(CheckArity(types->size()));
 
-    RETURN_NOT_OK(CheckDecimals(types));
-
     using arrow::compute::detail::DispatchExactImpl;
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
 
     EnsureDictionaryDecoded(types);
 
-    // Only promote types for binary functions
-    if (types->size() == 2) {
-      ReplaceNullWithOtherType(types);
-      TimeUnit::type finest_unit;
-      if (CommonTemporalResolution(types->data(), types->size(), 
&finest_unit)) {
-        ReplaceTemporalTypes(finest_unit, types);
-      } else {
-        if (TypeHolder type = CommonNumeric(*types)) {
-          ReplaceTypes(type, types);
-        }
-      }
-    }
-
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
     return arrow::compute::detail::NoMatchingKernel(this, *types);
   }
-
-  Status CheckDecimals(std::vector<TypeHolder>* types) const {
-    if (!HasDecimal(*types)) return Status::OK();
-
-    if (types->size() == 2) {
-      // "add_checked" -> "add"
-      const auto func_name = name();
-      const std::string op = func_name.substr(0, func_name.find("_"));
-      if (op == "add" || op == "subtract") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kAdd, types);
-      } else if (op == "multiply") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, types);
-      } else if (op == "divide") {
-        return CastBinaryDecimalArgs(DecimalPromotion::kDivide, types);
-      } else {
-        return Status::Invalid("Invalid decimal function: ", func_name);
-      }
-    }
-    return Status::OK();
-  }
 };
 
-/// An ArithmeticFunction that promotes only decimal arguments to double.
-struct ArithmeticDecimalToFloatingPointFunction : public ArithmeticFunction {
-  using ArithmeticFunction::ArithmeticFunction;
+/// An RoundFunction that promotes only decimal arguments to double.

Review Comment:
   ```suggestion
   /// A RoundFunction that promotes only decimal arguments to double.
   ```



##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -452,6 +468,127 @@ struct Round<ArrowType, kRoundMode, 
enable_if_decimal<ArrowType>> {
   }
 };
 
+template <typename ArrowType, RoundMode RndMode, typename Enable = void>
+struct RoundBinary {
+  using CType = typename TypeTraits<ArrowType>::CType;
+  using State = RoundOptionsWrapper<RoundBinaryOptions>;
+
+  explicit RoundBinary(const State& state, const DataType& out_ty) {}
+
+  template <typename T = ArrowType, typename CType0 = typename 
TypeTraits<T>::CType0,
+            typename CType1 = typename TypeTraits<T>::CType1>
+  enable_if_floating_value<CType> Call(KernelContext* ctx, CType0 arg0, CType1 
arg1,
+                                       Status* st) const {
+    // Do not process Inf or NaN because they will trigger the overflow error 
at end of
+    // function.

Review Comment:
   Do you have any tests with infinite or NaN inputs?



##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -965,6 +1086,97 @@ TYPED_TEST(TestUnaryRoundFloating, Round) {
   }
 }
 
+TYPED_TEST_SUITE(TestBinaryRoundIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestBinaryRoundSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundFloating, FloatingTypes);
+
+TYPED_TEST(TestBinaryRoundSigned, Round) {
+  // Test different rounding modes for integer rounding
+  std::string values("[0, 1, -13, -50, 115]");
+  for (const auto& round_mode : kRoundModes) {
+    this->SetRoundMode(round_mode);
+    this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(), 
values));
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[0.0, 0.0, -0.0, -100, 100]"},
+      {-1, "[0.0, 0.0, -10, -50, 120]"},
+      {0, values},
+      {1, values},
+      {2, values},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first,
+                         ArrayFromJSON(float64(), pair.second));
+  }
+}
+
+TYPED_TEST(TestBinaryRoundUnsigned, Round) {
+  // Test different rounding modes for integer rounding
+  std::string values("[0, 1, 13, 50, 115]");
+  for (const auto& round_mode : kRoundModes) {
+    this->SetRoundMode(round_mode);
+    this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(), 
values));
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[0, 0, 0, 100, 100]"},
+      {-1, "[0, 0, 10, 50, 120]"},
+      {0, values},
+      {1, values},
+      {2, values},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first,
+                         ArrayFromJSON(float64(), pair.second));
+  }
+}
+
+TYPED_TEST(TestBinaryRoundFloating, Round) {
+  this->SetNansEqual(true);
+
+  // Test different rounding modes
+  std::string values("[3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7]");
+  std::vector<std::pair<RoundMode, std::string>> rmode_and_expected{{
+      {RoundMode::DOWN, "[3, 3, 3, 4, -4, -4, -4]"},
+      {RoundMode::UP, "[4, 4, 4, 5, -3, -3, -3]"},
+      {RoundMode::TOWARDS_ZERO, "[3, 3, 3, 4, -3, -3, -3]"},
+      {RoundMode::TOWARDS_INFINITY, "[4, 4, 4, 5, -4, -4, -4]"},
+      {RoundMode::HALF_DOWN, "[3, 3, 4, 4, -3, -4, -4]"},
+      {RoundMode::HALF_UP, "[3, 4, 4, 5, -3, -3, -4]"},
+      {RoundMode::HALF_TOWARDS_ZERO, "[3, 3, 4, 4, -3, -3, -4]"},
+      {RoundMode::HALF_TOWARDS_INFINITY, "[3, 4, 4, 5, -3, -4, -4]"},
+      {RoundMode::HALF_TO_EVEN, "[3, 4, 4, 4, -3, -4, -4]"},
+      {RoundMode::HALF_TO_ODD, "[3, 3, 4, 5, -3, -3, -4]"},
+  }};
+  for (const auto& pair : rmode_and_expected) {
+    this->SetRoundMode(pair.first);
+    this->AssertBinaryOp(RoundBinary, "[]", "[]", "[]");
+    this->AssertBinaryOp(RoundBinary, "[null, 0, Inf, -Inf, NaN, -NaN]",
+                         "[0, 0, 0, 0, 0, 0]", "[null, 0, Inf, -Inf, NaN, 
-NaN]");
+    this->AssertBinaryOp(RoundBinary, values, 0, pair.second);
+  }
+
+  // Test different round N-digits for nearest rounding mode
+  values = "[320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045]";
+  std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+      {-2, "[300, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0]"},
+      {-1, "[320, 0.0, 0.0, 0.0, -0.0, -40, -0.0]"},
+      {0, "[320, 4, 3, 5, -3, -35, -3]"},
+      {1, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3]"},
+      {2, "[320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05]"},
+  }};
+  this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+  for (const auto& pair : ndigits_and_expected) {
+    this->AssertBinaryOp(RoundBinary, values, pair.first, pair.second);
+  }
+}

Review Comment:
   Can you add some unit tests that consider nulls in both the values and the 
num_digits arguments? Also, could you add a few tests with scalars (especially 
using a scalar for num_digits and an array for values, which should be 
equivalent to unary rounding)?



##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow 
DecodeOptionlessUncheckedArithmetic(
   };
 }
 
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+    const std::string& function_name) {
+  return [function_name](const SubstraitCall& call) -> 
Result<compute::Expression> {
+    ARROW_ASSIGN_OR_RAISE(
+        compute::RoundMode round_mode,
+        ParseOptionOrElse(
+            call, "rounding", kRoundModeParser,
+            {compute::RoundMode::DOWN, compute::RoundMode::UP,
+             compute::RoundMode::TOWARDS_ZERO, 
compute::RoundMode::TOWARDS_INFINITY,
+             compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+             compute::RoundMode::HALF_TOWARDS_ZERO,
+             compute::RoundMode::HALF_TOWARDS_INFINITY, 
compute::RoundMode::HALF_TO_EVEN,
+             compute::RoundMode::HALF_TO_ODD},
+            compute::RoundMode::HALF_TOWARDS_INFINITY));

Review Comment:
   It appears you are defaulting to `HALF_TOWARDS_INFINITY` but shouldn't the 
default be `HALF_TO_EVEN`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to