pitrou commented on a change in pull request #10557: URL: https://github.com/apache/arrow/pull/10557#discussion_r661643531
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, + Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs)); + if (datum_out.is_array()) { + std::shared_ptr<Array> result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); Review comment: Instead use `ValidateOutput` from `kernels/test_util.h`. It will also check that no data is left uninitialized. ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -21,6 +21,7 @@ #include <arrow/compute/kernels/test_util.h> #include <arrow/testing/gtest_util.h> #include <gtest/gtest.h> +#include "arrow/compute/registry.h" Review comment: Nit, but would be nice to normalize includes here. I think we normally use `#include "arrow/..."` for intra-Arrow inclusions. 
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, + Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs)); + if (datum_out.is_array()) { + std::shared_ptr<Array> result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + std::shared_ptr<Array> expected_ = expected.make_array(); + AssertArraysEqual(*expected_, *result, /*verbose=*/true); + + for (int64_t i = 0; i < result->length(); i++) { + // Check scalar + ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i)); + std::vector<Datum> inputs_scalar; + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + auto array = input.make_array(); + ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i)); + inputs_scalar.push_back(input_scalar); + } + } + ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(scalar_out.is_scalar()); + AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true); + + // Check slice + inputs_scalar.clear(); + auto expected_array = expected_->Slice(i); + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + inputs_scalar.push_back(input.make_array()->Slice(i)); + } + } + ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(array_out.is_array()); + AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true); + } + } else { + const std::shared_ptr<Scalar>& result = datum_out.scalar(); + const std::shared_ptr<Scalar>& expected_ = expected.scalar(); + AssertScalarsEqual(*expected_, *result, /*verbose=*/true); + } +} + +template <typename Type> +class 
TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1, + const std::string& value2) { + auto scalar_true = ScalarFromJSON(boolean(), "true"); + auto scalar_false = ScalarFromJSON(boolean(), "false"); + auto scalar_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto value_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, value1); + auto scalar2 = ScalarFromJSON(type, value2); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + std::stringstream builder; + builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']'; + auto values1 = ArrayFromJSON(type, builder.str()); + builder.str(""); + builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']'; + auto values2 = ArrayFromJSON(type, builder.str()); + // N.B. 
all-scalar cases are checked in CheckCaseWhen + // Only an else array + CheckVarArgs("case_when", {values1}, values1); + // No else clause, scalar cond, array values + CheckVarArgs("case_when", {scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_false, values1}, values_null); + CheckVarArgs("case_when", {scalar_null, values1}, values_null); + CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1); + CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1); + // No else clause, array cond, scalar values + builder.str(""); + builder << '[' << value1 << ", null, null, null]"; + CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, value_null}, values_null); + builder.str(""); + builder << '[' << value1 << ", null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, builder.str())); + // No else clause, array cond, array values + builder.str(""); + builder << "[null, null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, builder.str())); + // Else clauses/mixed scalar and array + builder.str(""); + builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1}, + ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, builder.str())); +} + +TYPED_TEST(TestCaseWhenNumeric, FixedSize) { + auto type = default_type_instance<TypeParam>(); + CheckCaseWhenCases(type, "10", "42"); +} + +TEST(TestCaseWhen, Null) { + auto scalar = ScalarFromJSON(null(), "null"); + auto array = ArrayFromJSON(null(), "[null, null, null, null]"); + CheckVarArgs("case_when", {array}, array); + CheckVarArgs("case_when", 
{scalar, array}, array); + CheckVarArgs("case_when", {scalar, array, array}, array); Review comment: Hmm, shouldn't a condition be a boolean? Here you're passing a null `scalar` as the condition. Also, I would expect a test with an array condition. ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, + Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs)); + if (datum_out.is_array()) { + std::shared_ptr<Array> result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + std::shared_ptr<Array> expected_ = expected.make_array(); + AssertArraysEqual(*expected_, *result, /*verbose=*/true); + + for (int64_t i = 0; i < result->length(); i++) { + // Check scalar + ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i)); + std::vector<Datum> inputs_scalar; + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + auto array = input.make_array(); + ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i)); + inputs_scalar.push_back(input_scalar); + } + } + ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(scalar_out.is_scalar()); + AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true); + + // Check slice + inputs_scalar.clear(); + auto expected_array = expected_->Slice(i); + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + inputs_scalar.push_back(input.make_array()->Slice(i)); + } + } + ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(array_out.is_array()); + AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true); + } + } else 
{ + const std::shared_ptr<Scalar>& result = datum_out.scalar(); + const std::shared_ptr<Scalar>& expected_ = expected.scalar(); + AssertScalarsEqual(*expected_, *result, /*verbose=*/true); + } +} + +template <typename Type> +class TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1, + const std::string& value2) { + auto scalar_true = ScalarFromJSON(boolean(), "true"); + auto scalar_false = ScalarFromJSON(boolean(), "false"); + auto scalar_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto value_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, value1); + auto scalar2 = ScalarFromJSON(type, value2); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + std::stringstream builder; + builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']'; + auto values1 = ArrayFromJSON(type, builder.str()); + builder.str(""); + builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']'; + auto values2 = ArrayFromJSON(type, builder.str()); + // N.B. 
all-scalar cases are checked in CheckCaseWhen + // Only an else array + CheckVarArgs("case_when", {values1}, values1); + // No else clause, scalar cond, array values + CheckVarArgs("case_when", {scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_false, values1}, values_null); + CheckVarArgs("case_when", {scalar_null, values1}, values_null); + CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1); + CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1); + // No else clause, array cond, scalar values + builder.str(""); + builder << '[' << value1 << ", null, null, null]"; + CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, value_null}, values_null); + builder.str(""); + builder << '[' << value1 << ", null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, builder.str())); + // No else clause, array cond, array values + builder.str(""); + builder << "[null, null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, builder.str())); + // Else clauses/mixed scalar and array + builder.str(""); + builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1}, + ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, builder.str())); +} + +TYPED_TEST(TestCaseWhenNumeric, FixedSize) { + auto type = default_type_instance<TypeParam>(); + CheckCaseWhenCases(type, "10", "42"); +} + +TEST(TestCaseWhen, Null) { + auto scalar = ScalarFromJSON(null(), "null"); + auto array = ArrayFromJSON(null(), "[null, null, null, null]"); + CheckVarArgs("case_when", {array}, array); + CheckVarArgs("case_when", 
{scalar, array}, array); + CheckVarArgs("case_when", {scalar, array, array}, array); +} + +TEST(TestCaseWhen, Boolean) { CheckCaseWhenCases(boolean(), "true", "false"); } + +TEST(TestCaseWhen, DayTimeInterval) { + CheckCaseWhenCases(day_time_interval(), "[10, 2]", "[2, 5]"); +} + +TEST(TestCaseWhen, Decimal) { + for (const auto& type : + std::vector<std::shared_ptr<DataType>>{decimal128(3, 2), decimal256(3, 2)}) { + CheckCaseWhenCases(type, "\"1.23\"", "\"4.56\""); + } +} + +TEST(TestCaseWhen, FixedSizeBinary) { + auto type = fixed_size_binary(3); + CheckCaseWhenCases(type, "\"aaa\"", "\"bbb\""); +} + +TEST(TestCaseWhen, DispatchBest) { + auto Check = [](std::vector<ValueDescr> original_values, + std::vector<ValueDescr> expected_equivalent_values) { + EXPECT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction("case_when")); + auto values = original_values; + ARROW_ASSIGN_OR_RAISE(auto actual_kernel, function->DispatchBest(&values)); + EXPECT_OK_AND_ASSIGN(auto expected_kernel, + function->DispatchBest(&expected_equivalent_values)); + EXPECT_EQ(actual_kernel, expected_kernel) + << " DispatchBest" << ValueDescr::ToString(original_values) << " => " + << actual_kernel->signature->ToString() << "\n" + << " DispatchBest" << ValueDescr::ToString(expected_equivalent_values) << " => " + << expected_kernel->signature->ToString(); + return Status::OK(); + }; + + ASSERT_OK(Check({int32()}, {int32()})); + ASSERT_OK(Check({boolean(), int32(), int32()}, {boolean(), int32(), int32()})); + ASSERT_OK(Check({null(), int32(), int32()}, {boolean(), int32(), int32()})); + ASSERT_OK(Check({boolean(), int32(), int8()}, {boolean(), int32(), int32()})); + ASSERT_OK(Check({boolean(), int32(), uint32()}, {boolean(), int64(), int64()})); + ASSERT_RAISES(Invalid, + Check({boolean(), utf8(), int32()}, {boolean(), int32(), int32()})); + ASSERT_RAISES(Invalid, + Check({int32(), int32(), int32()}, {boolean(), int32(), int32()})); +} Review comment: Somewhere, you should also test what 
happens when `case_when` is called with zero arguments. ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, + Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs)); + if (datum_out.is_array()) { + std::shared_ptr<Array> result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + std::shared_ptr<Array> expected_ = expected.make_array(); + AssertArraysEqual(*expected_, *result, /*verbose=*/true); + + for (int64_t i = 0; i < result->length(); i++) { + // Check scalar + ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i)); + std::vector<Datum> inputs_scalar; + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + auto array = input.make_array(); + ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i)); + inputs_scalar.push_back(input_scalar); + } + } + ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(scalar_out.is_scalar()); + AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true); + + // Check slice + inputs_scalar.clear(); + auto expected_array = expected_->Slice(i); + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + inputs_scalar.push_back(input.make_array()->Slice(i)); + } + } + ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(array_out.is_array()); + AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true); + } + } else { + const std::shared_ptr<Scalar>& result = datum_out.scalar(); + const std::shared_ptr<Scalar>& expected_ = expected.scalar(); + AssertScalarsEqual(*expected_, *result, /*verbose=*/true); + } 
+} + +template <typename Type> +class TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1, + const std::string& value2) { + auto scalar_true = ScalarFromJSON(boolean(), "true"); + auto scalar_false = ScalarFromJSON(boolean(), "false"); + auto scalar_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto value_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, value1); + auto scalar2 = ScalarFromJSON(type, value2); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + std::stringstream builder; + builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']'; + auto values1 = ArrayFromJSON(type, builder.str()); + builder.str(""); + builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']'; + auto values2 = ArrayFromJSON(type, builder.str()); + // N.B. all-scalar cases are checked in CheckCaseWhen + // Only an else array + CheckVarArgs("case_when", {values1}, values1); + // No else clause, scalar cond, array values + CheckVarArgs("case_when", {scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_false, values1}, values_null); + CheckVarArgs("case_when", {scalar_null, values1}, values_null); + CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1); + CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1); Review comment: Also `{scalar_false, values2, scalar_true, values1} -> values1`? 
########## File path: cpp/src/arrow/compute/api_scalar.h ########## @@ -589,6 +589,21 @@ ARROW_EXPORT Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for +/// each row, select the first value for which the corresponding condition is +/// true, or (if given) select the 'else' value, else emit null. Review comment: So a condition being null is the same as being false, right? ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, Review comment: I wonder if this should be moved to `kernels/test_util.{h,cc}` instead? ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. 
+template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; Review comment: If the scalar isn't valid, we should zero-initialize the destination memory area. 
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, Review comment: Instead of `uint8_t* out_values`, you may want this to take an `ArrayData* out`, since you'll need it for non-fixed-width types? 
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width 
* offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> +void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = 
CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, Review comment: `TypeError`? ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. 
+template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, 
scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> +void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& 
array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); Review comment: This doesn't seem to do anything special apart from duplicating `type` across `value_types`. Instead you can just reuse `type` below (call it `promoted_value_type`?). 
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc ########## @@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs, + Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs)); + if (datum_out.is_array()) { + std::shared_ptr<Array> result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + std::shared_ptr<Array> expected_ = expected.make_array(); + AssertArraysEqual(*expected_, *result, /*verbose=*/true); + + for (int64_t i = 0; i < result->length(); i++) { + // Check scalar + ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i)); + std::vector<Datum> inputs_scalar; + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + auto array = input.make_array(); + ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i)); + inputs_scalar.push_back(input_scalar); + } + } + ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(scalar_out.is_scalar()); + AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true); + + // Check slice + inputs_scalar.clear(); + auto expected_array = expected_->Slice(i); + for (const auto& input : inputs) { + if (input.is_scalar()) { + inputs_scalar.push_back(input); + } else { + inputs_scalar.push_back(input.make_array()->Slice(i)); + } + } + ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar)); + ASSERT_TRUE(array_out.is_array()); + AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true); + } + } else { + const std::shared_ptr<Scalar>& result = datum_out.scalar(); + const std::shared_ptr<Scalar>& expected_ = expected.scalar(); + AssertScalarsEqual(*expected_, *result, /*verbose=*/true); + } +} + +template <typename Type> +class 
TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1, + const std::string& value2) { + auto scalar_true = ScalarFromJSON(boolean(), "true"); + auto scalar_false = ScalarFromJSON(boolean(), "false"); + auto scalar_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto value_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, value1); + auto scalar2 = ScalarFromJSON(type, value2); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + std::stringstream builder; + builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']'; + auto values1 = ArrayFromJSON(type, builder.str()); + builder.str(""); + builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']'; + auto values2 = ArrayFromJSON(type, builder.str()); + // N.B. 
all-scalar cases are checked in CheckCaseWhen + // Only an else array + CheckVarArgs("case_when", {values1}, values1); + // No else clause, scalar cond, array values + CheckVarArgs("case_when", {scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_false, values1}, values_null); + CheckVarArgs("case_when", {scalar_null, values1}, values_null); + CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1); + CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1); + CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1); + // No else clause, array cond, scalar values + builder.str(""); + builder << '[' << value1 << ", null, null, null]"; + CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, value_null}, values_null); + builder.str(""); + builder << '[' << value1 << ", null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, builder.str())); + // No else clause, array cond, array values + builder.str(""); + builder << "[null, null, null, " << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, builder.str())); + // Else clauses/mixed scalar and array + builder.str(""); + builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']'; + CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1}, + ArrayFromJSON(type, builder.str())); + CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, builder.str())); +} + +TYPED_TEST(TestCaseWhenNumeric, FixedSize) { + auto type = default_type_instance<TypeParam>(); + CheckCaseWhenCases(type, "10", "42"); Review comment: Can we add a (perhaps hand-written) test with mixed values below? e.g. 
`("[true, false, null]", "[1, 2, 3]", "[false, null, true]", "[4, 5, 6]") -> "[1, null, 6]"` ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + 
checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> +void CopyValues(const Datum& values, uint8_t* 
out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); + } + + const DataType& common_values_type = *value_types.front().type; + auto next_type = value_types.cbegin(); + for (size_t i = 0; i < values->size(); i += 2) { + if (!common_values_type.Equals(next_type->type)) { + return Status::Invalid("Value arguments must be of same type, but argument ", i, + " was ", next_type->type->ToString(), " (expected ", + common_values_type.ToString(), ")"); + } + if (i == values->size() - 1) { + // ELSE + (*values)[i] = 
*next_type++; + } else { + (*values)[i + 1] = *next_type++; + } + } + + // We register a unary kernel for each value type and dispatch to it after validation. + if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments +Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (size_t i = 0; i < batch.values.size() - 1; i += 2) { + const Scalar& cond = *batch[i].scalar(); + if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) { + *out = batch[i + 1]; + return Status::OK(); + } + } + if (batch.values.size() % 2 == 0) { + // No ELSE + *out = MakeNullScalar(batch[1].type()); + } else { + *out = batch.values.back(); + } + return Status::OK(); +} + +// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type, +// given helper functions to copy data from a source array to a target array and to +// allocate a values buffer +template <typename Type> +Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ArrayData* output = out->mutable_array(); + const bool have_else_arg = batch.values.size() % 2 != 0; + // Check if we may need a validity bitmap + uint8_t* out_valid = nullptr; + + bool need_valid_bitmap = false; + if (!have_else_arg) { + // If we don't have an else arg -> need a bitmap since we may emit nulls + need_valid_bitmap = true; + } else if (batch.values.back().null_count() > 0) { + // If the 'else' array has a null count we need a validity bitmap + need_valid_bitmap = true; + } else { + // Otherwise if any value array has a null count we need a validity bitmap + for (size_t i = 1; i < batch.values.size(); i += 2) { + if (batch[i].null_count() > 0) { + need_valid_bitmap = true; + break; + } + } + } + if (need_valid_bitmap) { + ARROW_ASSIGN_OR_RAISE(output->buffers[0], 
ctx->AllocateBitmap(batch.length)); + out_valid = output->buffers[0]->mutable_data(); + } + + // Initialize values buffer + uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { + // Copy 'else' value into output Review comment: This seems a bit excessive? ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, 
enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, 
length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> +void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); + } + + const DataType& common_values_type = *value_types.front().type; + auto next_type = value_types.cbegin(); + for (size_t i = 0; i < values->size(); i += 2) { + if (!common_values_type.Equals(next_type->type)) { + return Status::Invalid("Value arguments must be of same type, but 
argument ", i, + " was ", next_type->type->ToString(), " (expected ", + common_values_type.ToString(), ")"); + } + if (i == values->size() - 1) { + // ELSE + (*values)[i] = *next_type++; + } else { + (*values)[i + 1] = *next_type++; + } + } + + // We register a unary kernel for each value type and dispatch to it after validation. + if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments +Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (size_t i = 0; i < batch.values.size() - 1; i += 2) { + const Scalar& cond = *batch[i].scalar(); + if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) { + *out = batch[i + 1]; + return Status::OK(); + } + } + if (batch.values.size() % 2 == 0) { + // No ELSE + *out = MakeNullScalar(batch[1].type()); + } else { + *out = batch.values.back(); + } + return Status::OK(); +} + +// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type, +// given helper functions to copy data from a source array to a target array and to +// allocate a values buffer +template <typename Type> +Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ArrayData* output = out->mutable_array(); + const bool have_else_arg = batch.values.size() % 2 != 0; + // Check if we may need a validity bitmap + uint8_t* out_valid = nullptr; + + bool need_valid_bitmap = false; + if (!have_else_arg) { + // If we don't have an else arg -> need a bitmap since we may emit nulls + need_valid_bitmap = true; + } else if (batch.values.back().null_count() > 0) { + // If the 'else' array has a null count we need a validity bitmap + need_valid_bitmap = true; + } else { + // Otherwise if any value array has a null count we need a validity bitmap + for (size_t i = 1; i < batch.values.size(); i += 2) 
{ + if (batch[i].null_count() > 0) { + need_valid_bitmap = true; + break; + } + } + } + if (need_valid_bitmap) { + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length)); + out_valid = output->buffers[0]->mutable_data(); + } + + // Initialize values buffer + uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { + // Copy 'else' value into output + CopyValues<Type>(batch.values.back(), out_valid, out_values, /*offset=*/0, + batch.length); + } else if (need_valid_bitmap) { + // There's no 'else' argument, so we should have an all-null validity bitmap + std::memset(out_valid, 0x00, output->buffers[0]->size()); + } + + // Allocate a temporary bitmap to determine which elements still need setting. + ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length)); + uint8_t* mask = mask_buffer->mutable_data(); + std::memset(mask, 0xFF, mask_buffer->size()); + // Then iterate through each argument in turn and set elements. + for (size_t i = 0; i < batch.values.size() - 1; i += 2) { + const Datum& cond_datum = batch[i]; + const Datum& values_datum = batch[i + 1]; + if (cond_datum.is_scalar()) { + const Scalar& cond_scalar = *cond_datum.scalar(); + const bool cond = + cond_scalar.is_valid && UnboxScalar<BooleanType>::Unbox(cond_scalar); + if (!cond) continue; + BitBlockCounter counter(mask, /*start_offset=*/0, batch.length); + int64_t offset = 0; + while (offset < batch.length) { + const auto block = counter.NextWord(); + if (block.AllSet()) { + CopyValues<Type>(values_datum, out_valid, out_values, offset, block.length); + } else if (block.popcount) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(mask, offset + j)) { + CopyValues<Type>(values_datum, out_valid, out_values, offset + j, + /*length=*/1); + } + } + } + offset += block.length; + } + break; + } + + const ArrayData& cond_array = *cond_datum.array(); + const uint8_t* cond_values = cond_array.buffers[1]->data(); + int64_t offset = 0; + 
// If no valid buffer, visit mask & value bitmap simultaneously + if (!cond_array.MayHaveNulls()) { Review comment: You can use `GetNullCount()` I think? ########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + 
const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> 
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); + } + + const DataType& common_values_type = *value_types.front().type; + auto next_type = value_types.cbegin(); + for (size_t i = 0; i < values->size(); i += 2) { + if (!common_values_type.Equals(next_type->type)) { Review comment: Given the definition of `ReplaceTypes`, this check doesn't seem necessary? cc @bkietz for a second opinion. 
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); Review comment: At some point, we should perhaps try to find out whether `std::copy` is as performant as `memcpy`. cc @bkietz for opinions. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org