edponce commented on code in PR #12660:
URL: https://github.com/apache/arrow/pull/12660#discussion_r850555033


##########
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc:
##########
@@ -1200,70 +1201,71 @@ template <>
 struct RoundOptionsWrapper<RoundToMultipleOptions>
     : public OptionsWrapper<RoundToMultipleOptions> {
   using OptionsType = RoundToMultipleOptions;
-  using State = RoundOptionsWrapper<OptionsType>;
   using OptionsWrapper::OptionsWrapper;
 
   static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                    const KernelInitArgs& args) {
-    std::unique_ptr<State> state;
-    if (auto options = static_cast<const OptionsType*>(args.options)) {
-      state = ::arrow::internal::make_unique<State>(*options);
-    } else {
+    auto options = static_cast<const OptionsType*>(args.options);
+    if (!options) {
       return Status::Invalid(
           "Attempted to initialize KernelState from null FunctionOptions");
     }
 
-    auto options = Get(*state);
-    const auto& type = *args.inputs[0].type;
-    if (!options.multiple || !options.multiple->is_valid) {
+    const auto& multiple = options->multiple;
+    if (!multiple || !multiple->is_valid) {
       return Status::Invalid("Rounding multiple must be non-null and valid");
     }
-    if (is_floating(type.id())) {
-      switch (options.multiple->type->id()) {
-        case Type::FLOAT: {
-          if (UnboxScalar<FloatType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DOUBLE: {
-          if (UnboxScalar<DoubleType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::HALF_FLOAT:
-          return Status::NotImplemented("Half-float values are not supported");
-        default:
-          return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                                 *options.multiple->type);
-      }
+
+    // Ensure the rounding multiple option matches the kernel's output type.
+    // The output type is not available here so we use the following rule:
+    // If `multiple` is neither a floating-point nor a decimal type, then
+    // cast to float64, else cast to the kernel's input type.
+    std::shared_ptr<Scalar> resolved_multiple;
+    const auto& to_type =
+        (!is_floating(multiple->type->id()) && !is_decimal(multiple->type->id()))
+            ? float64()
+            : args.inputs[0].type;
+    bool is_casted = false;
+    if (!multiple->type->Equals(to_type)) {
+      ARROW_ASSIGN_OR_RAISE(
+          auto casted_multiple,
+          Cast(Datum(multiple), to_type, CastOptions::Safe(), ctx->exec_context()));
+      resolved_multiple = casted_multiple.scalar();
+      is_casted = true;
     } else {
-      DCHECK(is_decimal(type.id()));
-      if (!type.Equals(*options.multiple->type)) {
-        return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                               *options.multiple->type);
-      }
-      switch (options.multiple->type->id()) {
-        case Type::DECIMAL128: {
-          if (UnboxScalar<Decimal128Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DECIMAL256: {
-          if (UnboxScalar<Decimal256Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        default:
-          // This shouldn't happen
-          return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                                 *options.multiple->type);
-      }
+      resolved_multiple = multiple;
     }
-    return std::move(state);
+
+    // NOTE: The positive value check can be simplified by using a comparison kernel.
+    bool is_negative = false;
+    switch (resolved_multiple->type->id()) {

Review Comment:
   BTW, I do like the visitor pattern and `xxx.value` idiom.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to