lidavidm commented on a change in pull request #10390:
URL: https://github.com/apache/arrow/pull/10390#discussion_r641560736



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
##########
@@ -516,6 +565,190 @@ std::shared_ptr<ScalarFunction> 
MakeUnarySignedArithmeticFunctionNotNull(
   return func;
 }
 
+// Options state for the ScalarMinMax kernels: Exec obtains the
+// ElementWiseAggregateOptions (e.g. skip_nulls) via MinMaxState::Get(ctx).
+using MinMaxState = OptionsWrapper<ElementWiseAggregateOptions>;
+
+// Implement a variadic scalar min/max kernel.
+template <typename OutType, typename Op>
+struct ScalarMinMax {
+  using OutValue = typename GetOutputType<OutType>::T;
+
+  static void ExecScalar(const ExecBatch& batch,
+                         const ElementWiseAggregateOptions& options, Scalar* 
out) {
+    // All arguments are scalar
+    OutValue value{};
+    bool valid = false;
+    for (const auto& arg : batch.values) {
+      // Ignore non-scalar arguments so we can use it in the 
mixed-scalar-and-array case
+      if (!arg.is_scalar()) continue;
+      const auto& scalar = *arg.scalar();
+      if (!scalar.is_valid) {
+        if (options.skip_nulls) continue;
+        out->is_valid = false;
+        return;
+      }
+      if (!valid) {
+        value = UnboxScalar<OutType>::Unbox(scalar);
+        valid = true;
+      } else {
+        value = Op::Call(value, UnboxScalar<OutType>::Unbox(scalar));
+      }
+    }
+    out->is_valid = valid;
+    if (valid) {
+      BoxScalar<OutType>::Box(value, out);
+    }
+  }
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const ElementWiseAggregateOptions& options = MinMaxState::Get(ctx);
+    const auto descrs = batch.GetDescriptors();
+    const size_t scalar_count =
+        static_cast<size_t>(std::count_if(batch.values.begin(), 
batch.values.end(),
+                                          [](const Datum& d) { return 
d.is_scalar(); }));
+    if (scalar_count == batch.values.size()) {
+      ExecScalar(batch, options, out->scalar().get());
+      return Status::OK();
+    }
+
+    ArrayData* output = out->mutable_array();
+
+    // At least one array, two or more arguments
+    ArrayDataVector arrays;
+    for (const auto& arg : batch.values) {
+      if (!arg.is_array()) continue;
+      arrays.push_back(arg.array());
+    }
+
+    if (scalar_count > 0) {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> temp_scalar,
+                            MakeScalar(out->type(), 0));
+      ExecScalar(batch, options, temp_scalar.get());
+      if (temp_scalar->is_valid) {
+        // Promote to output array
+        ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*temp_scalar, 
batch.length,
+                                                              
ctx->memory_pool()));
+        arrays.push_back(array->data());
+      } else if (!options.skip_nulls) {
+        // Abort early
+        ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*temp_scalar, 
batch.length,
+                                                              
ctx->memory_pool()));
+        *output = *array->data();
+        return Status::OK();
+      }
+    }
+
+    // Exactly one array to consider (output = input)
+    if (arrays.size() == 1) {
+      *output = *arrays[0];
+      return Status::OK();
+    }
+
+    // Two or more arrays to consider
+    if (scalar_count > 0) {
+      // We allocated the last array from a scalar: recycle it as the output

Review comment:
       Good catch, thanks.

##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
##########
@@ -1161,5 +1255,242 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) {
   }
 }
 
+// Element-wise minimum over a variable number of numeric arguments, covering
+// all-scalar, all-array, and mixed scalar/array inputs.
+TYPED_TEST(TestVarArgsArithmeticNumeric, Minimum) {
+  // With the fixture's initial options nulls are skipped: the result is null
+  // only when no non-null input remains (including the empty argument list).
+  this->AssertNullScalar(Minimum, {});
+  this->AssertNullScalar(Minimum, {this->scalar("null"), this->scalar("null")});
+
+  // Scalar-only inputs.
+  this->Assert(Minimum, this->scalar("0"), {this->scalar("0")});
+  this->Assert(Minimum, this->scalar("0"),
+               {this->scalar("2"), this->scalar("0"), this->scalar("1")});
+  this->Assert(
+      Minimum, this->scalar("0"),
+      {this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")});
+  this->Assert(Minimum, this->scalar("1"),
+               {this->scalar("null"), this->scalar("null"), this->scalar("1"),
+                this->scalar("null")});
+
+  // A single array input passes through unchanged.
+  this->Assert(Minimum, (this->array("[]")), {this->array("[]")});
+  this->Assert(Minimum, this->array("[1, 2, 3, null]"), {this->array("[1, 2, 3, null]")});
+
+  // One array plus scalar(s): each scalar is compared against every element.
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, 2, 3, 4]"), this->scalar("2")});
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, null, 3, 4]"), this->scalar("2")});
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+
+  // Two arrays: compared position by position.
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")});
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+  this->Assert(Minimum, this->array("[1, 2, 2, 2]"),
+               {this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")});
+
+  // An output position is null only when every input is null there.
+  this->Assert(Minimum, this->array("[1, 2, null, 6]"),
+               {this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")});
+  this->Assert(Minimum, this->array("[1, 2, null, 6]"),
+               {this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")});
+  this->Assert(Minimum, this->array("[1, 2, 3, 4]"),
+               {this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")});
+  this->Assert(Minimum, this->array("[1, 2, 3, 4]"),
+               {this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")});
+
+  // Scalar given before the array.
+  this->Assert(Minimum, this->array("[1, 1, 1, 1]"),
+               {this->scalar("1"), this->array("[1, 2, 3, 4]")});
+  this->Assert(Minimum, this->array("[1, 1, 1, 1]"),
+               {this->scalar("1"), this->array("[null, null, null, null]")});
+  this->Assert(Minimum, this->array("[1, 1, 1, 1]"),
+               {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+  this->Assert(Minimum, this->array("[null, null, null, null]"),
+               {this->scalar("null"), this->array("[null, null, null, null]")});
+
+  // Test null handling: with skip_nulls = false a null input makes the result
+  // null at that position (everywhere, when the null input is a scalar).
+  this->element_wise_aggregate_options_.skip_nulls = false;
+  this->AssertNullScalar(Minimum, {this->scalar("null"), this->scalar("null")});
+  this->AssertNullScalar(Minimum, {this->scalar("0"), this->scalar("null")});
+
+  this->Assert(Minimum, this->array("[1, null, 2, 2]"),
+               {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+  this->Assert(Minimum, this->array("[null, null, null, null]"),
+               {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+  this->Assert(Minimum, this->array("[1, null, 2, 2]"),
+               {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+
+  this->Assert(Minimum, this->array("[null, null, null, null]"),
+               {this->scalar("1"), this->array("[null, null, null, null]")});
+  this->Assert(Minimum, this->array("[null, null, null, null]"),
+               {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+}
+
+TYPED_TEST(TestVarArgsArithmeticFloating, Minimum) {
+  this->SetNansEqual();
+  this->Assert(Maximum, this->scalar("0"), {this->scalar("0"), 
this->scalar("NaN")});
+  this->Assert(Maximum, this->scalar("0"), {this->scalar("NaN"), 
this->scalar("0")});
+  this->Assert(Maximum, this->scalar("Inf"), {this->scalar("Inf"), 
this->scalar("NaN")});
+  this->Assert(Maximum, this->scalar("Inf"), {this->scalar("NaN"), 
this->scalar("Inf")});
+  this->Assert(Maximum, this->scalar("-Inf"),
+               {this->scalar("-Inf"), this->scalar("NaN")});
+  this->Assert(Maximum, this->scalar("-Inf"),
+               {this->scalar("NaN"), this->scalar("-Inf")});
+  this->Assert(Maximum, this->scalar("NaN"), {this->scalar("NaN"), 
this->scalar("null")});
+  this->Assert(Minimum, this->scalar("0"), {this->scalar("0"), 
this->scalar("Inf")});
+  this->Assert(Minimum, this->scalar("-Inf"), {this->scalar("0"), 
this->scalar("-Inf")});

Review comment:
       Done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to