mapleFU commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1715647340
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1309,168 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an `is_in` value set against an inequality guarantee.
+ ///
+ /// Simplifying an `is_in` predicate involves filtering out any values from
+ /// the value set that cannot possibly be found given the guarantee. For
+ /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the
guarantee
+ /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+ /// This can be done efficiently if the value set is sorted and unique by
+ /// binary searching the inequality bound and slicing the value set
+ /// accordingly.
+ ///
+ /// \pre `guarantee` is non-nullable
+ /// \pre `guarantee.bound` is a scalar
+ /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+ /// \return a simplified value set, or a bool if the simplification of the
value set
+ /// means the whole is_in expr can become a boolean literal.
+ template <typename ArrowType>
+ static Result<std::variant<std::shared_ptr<Array>, bool>>
SimplifyIsInValueSet(
+ const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ DCHECK(guarantee.bound.is_scalar());
+
+ if (value_set->length() == 0) return false;  // an empty value set can never match
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar_bound,
+ guarantee.bound.scalar()->CastTo(value_set->type()));  // bound is cast to the value set's type before comparing
+ auto bound = internal::UnboxScalar<ArrowType>::Unbox(*scalar_bound);
+ auto compare = [&bound, &value_set](size_t i) -> Comparison::type {  // orders value_set[i] relative to bound
+ DCHECK(value_set->IsValid(i));  // null entries are not expected at compared positions
+ auto value = checked_pointer_cast<ArrayType>(value_set)->GetView(i);
+ return value == bound ? Comparison::EQUAL
+ : value < bound ? Comparison::LESS
+ : Comparison::GREATER;
+ };
+
+ size_t lo = 0;
+ size_t hi = value_set->length();
+ while (lo + 1 < hi) {  // narrow the window until a single candidate index `lo` remains
+ size_t mid = (lo + hi) / 2;
+ Comparison::type cmp = compare(mid);
+ if (cmp & Comparison::LESS_EQUAL) {  // presumably set for LESS or EQUAL (bitmask enum) -- TODO confirm
+ lo = mid;  // value_set[mid] <= bound: mid becomes the new candidate
+ } else {
+ hi = mid;  // value_set[mid] > bound: discard mid and everything after it
+ }
+ }
+
+ Comparison::type cmp = compare(lo);
+ size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);  // assuming a sorted value set: first index whose value is not less than bound
+ bool found = cmp == Comparison::EQUAL;  // bound itself sits at `pivot`
+
+ switch (guarantee.cmp) {
+ case Comparison::EQUAL:
+ return found;  // result is simply whether bound is in the value set
+ case Comparison::LESS:
+ value_set = value_set->Slice(0, pivot);  // keep only the values before pivot
+ break;
+ case Comparison::LESS_EQUAL:
+ value_set = value_set->Slice(0, pivot + (found ? 1 : 0));  // additionally keep bound itself when present
+ break;
+ case Comparison::GREATER:
+ value_set = value_set->Slice(pivot + (found ? 1 : 0));  // skip bound itself when present
+ break;
+ case Comparison::GREATER_EQUAL:
+ value_set = value_set->Slice(pivot);  // keep pivot and everything after it
+ break;
+ case Comparison::NOT_EQUAL:
+ case Comparison::NA:
+ DCHECK(false);  // these guarantee kinds are not handled by this function
+ return Status::Invalid("Invalid comparison");
+ }
+
+ if (value_set->length() == 0) return false;  // nothing survived the slice: collapse to literal false
+ return value_set;
+ }
+
+ /// Simplify an `is_in` call against an inequality guarantee.
+ /// \pre `is_in_call` is a call to the `is_in` function
+ /// \post updates the simplification context with the simplified value set if
+ /// nullopt is returned, otherwise ensures that the call is removed
from
+ /// the simplification context
+ /// \return a boolean if the whole `is_in` call simplifies to a boolean
literal,
+ /// otherwise nullopt
+ static Result<std::optional<bool>> SimplifyIsIn(
+ const Inequality& guarantee,
+ const Expression::Call* is_in_call,
+ SimplificationContext& context) {
+ DCHECK_EQ(is_in_call->function_name, "is_in");
+
+ // Null-matching behavior is complex and reduces the chances of reduction
+ // of `is_in` calls to a single literal for every possible input, so we
+ // abort the simplification if nulls are possible in the input.
+ if (guarantee.nullable) return std::nullopt;
+
+ if (!guarantee.bound.is_scalar()) {
+ return Status::Invalid("Cannot simplify inequality on a non-scalar
bound");
+ }
+
+ const auto& lhs =
Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
+ if (!lhs.field_ref()) return std::nullopt;
+ if (*lhs.field_ref() != guarantee.target) return std::nullopt;
+
+ auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);
+ std::array<TypeHolder, 2> types{guarantee.bound.type().get(),
+ options->value_set.type().get()};
+ TypeHolder cmp_type;
+ if (types[0] == types[1]) cmp_type = types[0];
+ if (!cmp_type) cmp_type = internal::CommonNumeric(types.data(),
types.size());
+ if (!cmp_type) cmp_type = internal::CommonTemporal(types.data(),
types.size());
+ if (!cmp_type) cmp_type = internal::CommonBinary(types.data(),
types.size());
+ if (!cmp_type) return std::nullopt;
+
+ std::variant<std::shared_ptr<Array>, bool> result;
+
+#define CASE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: { \
+ ARROW_ASSIGN_OR_RAISE( \
+ std::shared_ptr<Array> value_set, \
+ GetIsInValueSetForSimplification(is_in_call, cmp_type, context)); \
+ ARROW_ASSIGN_OR_RAISE( \
+ result, \
+ SimplifyIsInValueSet<TYPE_CLASS##Type>(guarantee, value_set)); \
+ break; \
+ }
+
+ switch (cmp_type.id()) {
+ CASE(UInt8)
Review Comment:
Just a naive question: could decimal types be handled here as well? (Or are they
simply not supported yet?)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]