pitrou commented on code in PR #12660:
URL: https://github.com/apache/arrow/pull/12660#discussion_r850635511


##########
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc:
##########
@@ -1200,70 +1201,71 @@ template <>
 struct RoundOptionsWrapper<RoundToMultipleOptions>
     : public OptionsWrapper<RoundToMultipleOptions> {
   using OptionsType = RoundToMultipleOptions;
-  using State = RoundOptionsWrapper<OptionsType>;
   using OptionsWrapper::OptionsWrapper;
 
   static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                    const KernelInitArgs& args) 
{
-    std::unique_ptr<State> state;
-    if (auto options = static_cast<const OptionsType*>(args.options)) {
-      state = ::arrow::internal::make_unique<State>(*options);
-    } else {
+    auto options = static_cast<const OptionsType*>(args.options);
+    if (!options) {
       return Status::Invalid(
           "Attempted to initialize KernelState from null FunctionOptions");
     }
 
-    auto options = Get(*state);
-    const auto& type = *args.inputs[0].type;
-    if (!options.multiple || !options.multiple->is_valid) {
+    const auto& multiple = options->multiple;
+    if (!multiple || !multiple->is_valid) {
       return Status::Invalid("Rounding multiple must be non-null and valid");
     }
-    if (is_floating(type.id())) {
-      switch (options.multiple->type->id()) {
-        case Type::FLOAT: {
-          if (UnboxScalar<FloatType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DOUBLE: {
-          if (UnboxScalar<DoubleType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::HALF_FLOAT:
-          return Status::NotImplemented("Half-float values are not supported");
-        default:
-          return Status::Invalid("Rounding multiple must be a ", type, " 
scalar, not ",
-                                 *options.multiple->type);
-      }
+
+    // Ensure the rounding multiple option matches the kernel's output type.
+    // The output type is not available here so we use the following rule:
+    // If `multiple` is neither a floating-point nor a decimal type, then
+    // cast to float64, else cast to the kernel's input type.
+    std::shared_ptr<Scalar> resolved_multiple;
+    const auto& to_type =
+        (!is_floating(multiple->type->id()) && 
!is_decimal(multiple->type->id()))
+            ? float64()
+            : args.inputs[0].type;
+    bool is_casted = false;
+    if (!multiple->type->Equals(to_type)) {
+      ARROW_ASSIGN_OR_RAISE(
+          auto casted_multiple,
+          Cast(Datum(multiple), to_type, CastOptions::Safe(), 
ctx->exec_context()));
+      resolved_multiple = casted_multiple.scalar();
+      is_casted = true;
     } else {
-      DCHECK(is_decimal(type.id()));
-      if (!type.Equals(*options.multiple->type)) {
-        return Status::Invalid("Rounding multiple must be a ", type, " scalar, 
not ",
-                               *options.multiple->type);
-      }
-      switch (options.multiple->type->id()) {
-        case Type::DECIMAL128: {
-          if (UnboxScalar<Decimal128Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DECIMAL256: {
-          if (UnboxScalar<Decimal256Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        default:
-          // This shouldn't happen
-          return Status::Invalid("Rounding multiple must be a ", type, " 
scalar, not ",
-                                 *options.multiple->type);
-      }
+      resolved_multiple = multiple;
     }
-    return std::move(state);
+
+    // NOTE: The positive value check can be simplified by using a comparison 
kernel.
+    bool is_negative = false;
+    switch (resolved_multiple->type->id()) {

Review Comment:
   I think we want to avoid calling a kernel for a simple operation on a scalar, as the kernel execution overhead is large.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to