asuhan commented on a change in pull request #11793:
URL: https://github.com/apache/arrow/pull/11793#discussion_r759492451



##########
File path: cpp/src/arrow/compute/kernels/scalar_compare.cc
##########
@@ -439,6 +469,325 @@ struct ScalarMinMax {
   }
 };
 
+template <typename Type, typename Op>
+struct BinaryScalarMinMax {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const ElementWiseAggregateOptions& options = MinMaxState::Get(ctx);
+    if (std::all_of(batch.values.begin(), batch.values.end(),
+                    [](const Datum& d) { return d.is_scalar(); })) {
+      return ExecOnlyScalar(ctx, options, batch, out);
+    }
+    return ExecContainingArrays(ctx, options, batch, out);
+  }
+
+  static Status ExecOnlyScalar(KernelContext* ctx,
+                               const ElementWiseAggregateOptions& options,
+                               const ExecBatch& batch, Datum* out) {
+    if (batch.values.empty()) {
+      return Status::OK();
+    }
+    BaseBinaryScalar* output = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+    const size_t num_args = batch.values.size();
+
+    int64_t final_size = CalculateRowSize(options, batch, 0);
+    if (final_size < 0) {
+      output->is_valid = false;
+      return Status::OK();
+    }
+    util::string_view result = UnboxScalar<Type>::Unbox(*batch.values.front().scalar());
+    for (size_t i = 1; i < num_args; i++) {
+      const Scalar& scalar = *batch[i].scalar();
+      if (!scalar.is_valid && options.skip_nulls) {
+        continue;
+      }
+      if (scalar.is_valid) {
+        util::string_view value = UnboxScalar<Type>::Unbox(scalar);
+        result = result.empty() ? value : Op::CallBinary(result, value);
+      }
+    }
+    if (!result.empty()) {
+      ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
+      uint8_t* buf = output->value->mutable_data();
+      buf = std::copy(result.begin(), result.end(), buf);
+      output->is_valid = true;
+      DCHECK_GE(final_size, buf - output->value->mutable_data());
+    }
+    return Status::OK();
+  }
+
+  static Status ExecContainingArrays(KernelContext* ctx,
+                                     const ElementWiseAggregateOptions& options,
+                                     const ExecBatch& batch, Datum* out) {
+    // Presize data to avoid reallocations
+    int64_t final_size = 0;
+    for (int64_t i = 0; i < batch.length; i++) {
+      auto size = CalculateRowSize(options, batch, i);
+      if (size > 0) final_size += size;
+    }
+    BuilderType builder(ctx->memory_pool());
+    RETURN_NOT_OK(builder.Reserve(batch.length));
+    RETURN_NOT_OK(builder.ReserveData(final_size));
+
+    std::vector<util::string_view> valid_cols(batch.values.size());
+    for (size_t row = 0; row < static_cast<size_t>(batch.length); row++) {
+      size_t num_valid = 0;
+      for (size_t col = 0; col < batch.values.size(); col++) {
+        if (batch[col].is_scalar()) {
+          const auto& scalar = *batch[col].scalar();
+          if (scalar.is_valid) {
+            valid_cols[col] = UnboxScalar<Type>::Unbox(scalar);
+            num_valid++;
+          } else {
+            valid_cols[col] = util::string_view();
+          }
+        } else {
+          const ArrayData& array = *batch[col].array();
+          if (!array.MayHaveNulls() ||
+              BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
+            const offset_type* offsets = array.GetValues<offset_type>(1);
+            const uint8_t* data = array.GetValues<uint8_t>(2, /*absolute_offset=*/0);
+            const int64_t length = offsets[row + 1] - offsets[row];
+            valid_cols[col] = util::string_view(
+                reinterpret_cast<const char*>(data + offsets[row]), length);
+            num_valid++;
+          } else {
+            valid_cols[col] = util::string_view();
+          }
+        }
+      }
+
+      if (num_valid < batch.values.size() && !options.skip_nulls) {
+        // We had some nulls
+        builder.UnsafeAppendNull();
+        continue;
+      }
+      util::string_view result = valid_cols.front();
+      for (size_t col = 1; col < batch.values.size(); ++col) {
+        util::string_view value = valid_cols[col];
+        if (value.empty()) {
+          DCHECK(options.skip_nulls);
+          continue;
+        }
+        result = result.empty() ? value : Op::CallBinary(result, value);
+      }
+      if (result.empty()) {
+        builder.UnsafeAppendNull();
+      } else {
+        builder.UnsafeAppend(result);
+      }
+    }
+
+    std::shared_ptr<Array> string_array;
+    RETURN_NOT_OK(builder.Finish(&string_array));
+    *out = *string_array->data();
+    out->mutable_array()->type = batch[0].type();
+    DCHECK_EQ(batch.length, out->array()->length);
+    DCHECK_GE(final_size,
+              checked_cast<const ArrayType&>(*string_array).total_values_length());
+    return Status::OK();
+  }
+
+  // Compute the length of the output for the given position, or -1 if it
+  // would be null.
+  static int64_t CalculateRowSize(const ElementWiseAggregateOptions& options,
+                                  const ExecBatch& batch, const int64_t index) {
+    const auto num_args = batch.values.size();
+    int64_t final_size = 0;
+    for (size_t i = 0; i < num_args; i++) {
+      int64_t element_size = 0;
+      bool valid = true;
+      if (batch[i].is_scalar()) {
+        const Scalar& scalar = *batch[i].scalar();
+        valid = scalar.is_valid;
+        element_size = UnboxScalar<Type>::Unbox(scalar).size();
+      } else {
+        const ArrayData& array = *batch[i].array();
+        valid = !array.MayHaveNulls() ||
+                BitUtil::GetBit(array.buffers[0]->data(), array.offset + index);
+        const offset_type* offsets = array.GetValues<offset_type>(1);
+        element_size = offsets[index + 1] - offsets[index];
+      }
+      if (!valid) {
+        if (options.skip_nulls) {
+          continue;
+        }
+        return -1;
+      }
+      final_size = std::max(final_size, element_size);

Review comment:
       I forgot to add a comment here, but this is intended. My understanding is that we only need the size in order to avoid reallocation, so we can safely overestimate it. Computing the precise value would double the number of comparisons, which I was reluctant to do for strings. Should I do that instead?



