WillAyd commented on code in PR #45272: URL: https://github.com/apache/arrow/pull/45272#discussion_r2205251653
########## cpp/src/arrow/compare_internal.h: ########## @@ -0,0 +1,966 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> + +#include "arrow/array/array_dict.h" +#include "arrow/array/data.h" +#include "arrow/array/diff.h" +#include "arrow/compare.h" +#include "arrow/scalar.h" +#include "arrow/type_traits.h" +#include "arrow/util/binary_view_util.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/float16.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/memory_internal.h" +#include "arrow/util/ree_util.h" +#include "arrow/visit_scalar_inline.h" +#include "arrow/visit_type_inline.h" + +namespace arrow { + +using internal::BitmapEquals; +using internal::BitmapReader; +using internal::BitmapUInt64Reader; +using internal::checked_cast; +using internal::OptionalBitmapEquals; +using util::Float16; + +// TODO also handle HALF_FLOAT NaNs + +template <bool Approximate, bool NansEqual, bool SignedZerosEqual> +struct FloatingEqualityFlags { + static constexpr bool approximate = Approximate; + static constexpr bool nans_equal = NansEqual; + static constexpr bool signed_zeros_equal = 
SignedZerosEqual; +}; + +template <typename T, typename Flags> +struct FloatingEquality { + explicit FloatingEquality(const EqualOptions& options) + : epsilon(static_cast<T>(options.atol())) {} + + bool operator()(T x, T y) const { + if (x == y) { + return Flags::signed_zeros_equal || (std::signbit(x) == std::signbit(y)); + } + if (Flags::nans_equal && std::isnan(x) && std::isnan(y)) { + return true; + } + if (Flags::approximate && (fabs(x - y) <= epsilon)) { + return true; + } + return false; + } + + const T epsilon; +}; + +// For half-float equality. +template <typename Flags> +struct FloatingEquality<uint16_t, Flags> { + explicit FloatingEquality(const EqualOptions& options) + : epsilon(static_cast<float>(options.atol())) {} + + bool operator()(uint16_t x, uint16_t y) const { + Float16 f_x = Float16::FromBits(x); + Float16 f_y = Float16::FromBits(y); + if (x == y) { + return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit()); + } + if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) { + return true; + } + if (Flags::approximate && (fabs(f_x.ToFloat() - f_y.ToFloat()) <= epsilon)) { + return true; + } + return false; + } + + const float epsilon; +}; + +template <typename T, typename Visitor> +struct FloatingEqualityDispatcher { + const EqualOptions& options; + bool floating_approximate; + Visitor&& visit; + + template <bool Approximate, bool NansEqual> + void DispatchL3() { + if (options.signed_zeros_equal()) { + visit(FloatingEquality<T, FloatingEqualityFlags<Approximate, NansEqual, true>>{ + options}); + } else { + visit(FloatingEquality<T, FloatingEqualityFlags<Approximate, NansEqual, false>>{ + options}); + } + } + + template <bool Approximate> + void DispatchL2() { + if (options.nans_equal()) { + DispatchL3<Approximate, true>(); + } else { + DispatchL3<Approximate, false>(); + } + } + + void Dispatch() { + if (floating_approximate) { + DispatchL2<true>(); + } else { + DispatchL2<false>(); + } + } +}; + +// Call `visit(equality_func)` where 
`equality_func` has the signature `bool(T, T)` +// and returns true if the two values compare equal. +template <typename T, typename Visitor> +void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate, + Visitor&& visit) { + FloatingEqualityDispatcher<T, Visitor>{options, floating_approximate, + std::forward<Visitor>(visit)} + .Dispatch(); +} + +inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) { + if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) { + return false; + } + for (const auto& child : type.fields()) { + if (!IdentityImpliesEqualityNansNotEqual(*child->type())) { + return false; + } + } + return true; +} + +inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) { + if (options.nans_equal()) { + return true; + } + return IdentityImpliesEqualityNansNotEqual(type); +} + +bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, + int64_t left_start_idx, int64_t left_end_idx, + int64_t right_start_idx, const EqualOptions& options, + bool floating_approximate); + +class RangeDataEqualsImpl { + public: + // PRE-CONDITIONS: + // - the types are equal + // - the ranges are in bounds + // - the ArrayData arguments have the same length + RangeDataEqualsImpl(const EqualOptions& options, bool floating_approximate, + const ArraySpan& left, const ArraySpan& right, + int64_t left_start_idx, int64_t right_start_idx, + int64_t range_length) + : options_(options), + floating_approximate_(floating_approximate), + left_(left), + right_(right), + left_start_idx_(left_start_idx), + right_start_idx_(right_start_idx), + range_length_(range_length), + result_(false) {} + + bool Compare() { + // Compare null bitmaps + if (left_start_idx_ == 0 && right_start_idx_ == 0 && range_length_ == left_.length && + range_length_ == right_.length) { + // If we're comparing entire arrays, we can first compare the cached null counts + if (left_.GetNullCount() != right_.GetNullCount()) { + return 
false; + } + } + if (!OptionalBitmapEquals(left_.buffers[0].data, left_.offset + left_start_idx_, + right_.buffers[0].data, right_.offset + right_start_idx_, + range_length_)) { + return false; + } + // Compare values + return CompareWithType(*left_.type); + } + + bool CompareWithType(const DataType& type) { + result_ = true; + if (range_length_ != 0) { + ARROW_CHECK_OK(VisitTypeInline(type, this)); + } + return result_; + } + + Status Visit(const NullType&) { return Status::OK(); } + + template <typename TypeClass> + enable_if_primitive_ctype<TypeClass, Status> Visit(const TypeClass& type) { + return ComparePrimitive(type); + } + + template <typename TypeClass> + enable_if_t<is_temporal_type<TypeClass>::value, Status> Visit(const TypeClass& type) { + return ComparePrimitive(type); + } + + Status Visit(const BooleanType&) { + const uint8_t* left_bits = left_.GetValues<uint8_t>(1, 0); + const uint8_t* right_bits = right_.GetValues<uint8_t>(1, 0); + auto compare_runs = [&](int64_t i, int64_t length) -> bool { + if (length <= 8) { + // Avoid the BitmapUInt64Reader overhead for very small runs + for (int64_t j = i; j < i + length; ++j) { + if (bit_util::GetBit(left_bits, left_start_idx_ + left_.offset + j) != + bit_util::GetBit(right_bits, right_start_idx_ + right_.offset + j)) { + return false; + } + } + return true; + } else if (length <= 1024) { + BitmapUInt64Reader left_reader(left_bits, left_start_idx_ + left_.offset + i, + length); + BitmapUInt64Reader right_reader(right_bits, right_start_idx_ + right_.offset + i, + length); + while (left_reader.position() < length) { + if (left_reader.NextWord() != right_reader.NextWord()) { + return false; + } + } + DCHECK_EQ(right_reader.position(), length); + } else { + // BitmapEquals is the fastest method on large runs + return BitmapEquals(left_bits, left_start_idx_ + left_.offset + i, right_bits, + right_start_idx_ + right_.offset + i, length); + } + return true; + }; + VisitValidRuns(compare_runs); + return Status::OK(); 
+ } + + Status Visit(const FloatType& type) { return CompareFloating(type); } + + Status Visit(const DoubleType& type) { return CompareFloating(type); } + + Status Visit(const HalfFloatType& type) { return CompareFloating(type); } + + // Also matches StringType + Status Visit(const BinaryType& type) { return CompareBinary(type); } + + // Also matches StringViewType + Status Visit(const BinaryViewType& type) { + auto* left_values = left_.GetValues<BinaryViewType::c_type>(1) + left_start_idx_; + auto* right_values = right_.GetValues<BinaryViewType::c_type>(1) + right_start_idx_; + + // TODO: the ToArrayData() is wasteful but EqualBinaryView requires an argument + // with a ->data member, so forwarding the raw BufferSpan is not an option at the + // moment + const auto left_buffers = left_.ToArrayData()->buffers.data() + 2; Review Comment: OK so in the cast case we start with string data like: ```cpp "foobarbarbarbar" ... ``` where "foo" is the first word and "barbarbarbarbarbarbar" is the second. When casting to a string view, the second element becomes a reference element that looks like: ```cpp {size = 21, pref = { "barb"}, buffer_index = 0, offset = 3} ``` But is that the correct value? Should the offset not be 0, since the first value, "foo", is elided from the variadic buffer and stored entirely inline? The line where that is being set is here - maybe we need to accumulate how many bytes have been inlined for a given buffer and subtract that from the offset? https://github.com/apache/arrow/blob/2bfbfc82cab3998c3b751e515664893b9648d8e5/cpp/src/arrow/compute/kernels/scalar_cast_string.cc#L434 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org