This is an automated email from the ASF dual-hosted git repository. kszucs pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new c822e46 ARROW-3087: [C++] Implement Compare filter kernel c822e46 is described below commit c822e466b6e88f67059fea2da4a9fb5e5bfa201a Author: François Saint-Jacques <fsaintjacq...@gmail.com> AuthorDate: Mon Apr 15 13:44:02 2019 +0200 ARROW-3087: [C++] Implement Compare filter kernel This is the first step in supporting basic expressions of the form `column_x > k`. Author: François Saint-Jacques <fsaintjacq...@gmail.com> Closes #3963 from fsaintjacques/ARROW-3087-filter-kernel and squashes the following commits: 8a745d826 <François Saint-Jacques> Use ArrayData instead of Array e018e8f9d <François Saint-Jacques> Various comments ada18ec96 <François Saint-Jacques> Use GenerateBitmapUnrolled 35695b516 <François Saint-Jacques> Remove extra includes e7e4a8e3f <François Saint-Jacques> Rename comparator.h to compare.h 63cd9ca80 <François Saint-Jacques> Address comments 0d0ebe2b7 <François Saint-Jacques> Address comments and warnings b3a2520c6 <François Saint-Jacques> Make lint happy f8e16f0cc <François Saint-Jacques> Address comments c8315fa9a <François Saint-Jacques> ARROW-3087: Implement Compare filter kernel f6b4be1f4 <François Saint-Jacques> Remove type and length DCHECK in PrimitiveAllocatingBinaryKernel d54c8141b <François Saint-Jacques> Add CType support for Date/Time types in TypeTraits 5b592b44f <François Saint-Jacques> Add integer casting for JSON bool array in ArrayFromJSON c9331de8b <François Saint-Jacques> Move compute benchmark utils in header --- cpp/src/arrow/CMakeLists.txt | 2 + cpp/src/arrow/compute/benchmark-util.h | 59 ++++++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 4 + .../arrow/compute/kernels/aggregate-benchmark.cc | 47 +---- cpp/src/arrow/compute/kernels/compare.cc | 175 ++++++++++++++++ cpp/src/arrow/compute/kernels/compare.h | 116 +++++++++++ cpp/src/arrow/compute/kernels/filter-benchmark.cc | 57 ++++++ cpp/src/arrow/compute/kernels/filter-test.cc | 228 
+++++++++++++++++++++ cpp/src/arrow/compute/kernels/filter.cc | 41 ++++ cpp/src/arrow/compute/kernels/filter.h | 67 ++++++ .../arrow/compute/kernels/util-internal-test.cc | 15 ++ cpp/src/arrow/compute/kernels/util-internal.cc | 22 +- cpp/src/arrow/compute/kernels/util-internal.h | 8 + cpp/src/arrow/ipc/json-simple-test.cc | 6 +- cpp/src/arrow/ipc/json-simple.cc | 3 + cpp/src/arrow/type_traits.h | 9 +- 16 files changed, 816 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c045704..6dac9ee 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -154,7 +154,9 @@ if(ARROW_COMPUTE) compute/kernels/aggregate.cc compute/kernels/boolean.cc compute/kernels/cast.cc + compute/kernels/compare.cc compute/kernels/count.cc + compute/kernels/filter.cc compute/kernels/hash.cc compute/kernels/mean.cc compute/kernels/sum.cc diff --git a/cpp/src/arrow/compute/benchmark-util.h b/cpp/src/arrow/compute/benchmark-util.h new file mode 100644 index 0000000..1678f8d --- /dev/null +++ b/cpp/src/arrow/compute/benchmark-util.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <vector> + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/cpu-info.h" + +namespace arrow { +namespace compute { + +using internal::CpuInfo; +static CpuInfo* cpu_info = CpuInfo::GetInstance(); + +static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE); +static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE); +static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE); +static const int64_t kCantFitInL3Size = kL3Size * 4; + +template <typename Func> +struct BenchmarkArgsType; + +template <typename Values> +struct BenchmarkArgsType<benchmark::internal::Benchmark* ( + benchmark::internal::Benchmark::*)(const std::vector<Values>&)> { + using type = Values; +}; + +void BenchmarkSetArgs(benchmark::internal::Benchmark* bench) { + // Benchmark changed its parameter type between releases from + // int to int64_t. As it doesn't have version macros, we need + // to apply C++ template magic. + using ArgsType = + typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type; + bench->Unit(benchmark::kMicrosecond); + + for (auto size : {kL1Size, kL2Size, kL3Size, kCantFitInL3Size}) + for (auto nulls : std::vector<ArgsType>({0, 1, 10, 50})) + bench->Args({static_cast<ArgsType>(size), nulls}); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index abdc092..6c386c9 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -26,3 +26,7 @@ add_arrow_test(util-internal-test PREFIX "arrow-compute") # Aggregates add_arrow_test(aggregate-test PREFIX "arrow-compute") add_arrow_benchmark(aggregate-benchmark PREFIX "arrow-compute") + +# Filters +add_arrow_test(filter-test PREFIX "arrow-compute") +add_arrow_benchmark(filter-benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/aggregate-benchmark.cc 
b/cpp/src/arrow/compute/kernels/aggregate-benchmark.cc index b533b91..e81f879 100644 --- a/cpp/src/arrow/compute/kernels/aggregate-benchmark.cc +++ b/cpp/src/arrow/compute/kernels/aggregate-benchmark.cc @@ -29,8 +29,8 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/util/bit-util.h" -#include "arrow/util/cpu-info.h" +#include "arrow/compute/benchmark-util.h" #include "arrow/compute/context.h" #include "arrow/compute/kernel.h" #include "arrow/compute/kernels/sum.h" @@ -43,13 +43,6 @@ namespace compute { #include <iostream> #include <random> -using internal::CpuInfo; -static CpuInfo* cpu_info = CpuInfo::GetInstance(); - -static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE); -static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE); -static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE); - namespace BitUtil = arrow::BitUtil; using arrow::internal::BitmapReader; @@ -308,35 +301,13 @@ void BenchSum(benchmark::State& state) { state.SetBytesProcessed(state.iterations() * array_size * sizeof(T)); } -template <typename Func> -struct BenchmarkArgsType; - -template <typename Values> -struct BenchmarkArgsType<benchmark::internal::Benchmark* ( - benchmark::internal::Benchmark::*)(const std::vector<Values>&)> { - using type = Values; -}; - -static void SetArgs(benchmark::internal::Benchmark* bench) { - // Benchmark changed its parameter type between releases from - // int to int64_t. As it doesn't have version macros, we need - // to apply C++ template magic. 
- using ArgsType = - typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type; - bench->Unit(benchmark::kMicrosecond); - - for (auto size : {kL1Size, kL2Size, kL3Size, kL3Size * 4}) - for (auto nulls : std::vector<ArgsType>({0, 1, 10, 50})) - bench->Args({static_cast<ArgsType>(size), nulls}); -} - -BENCHMARK_TEMPLATE(BenchSum, SumNoNulls<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumNoNullsUnrolled<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumSentinel<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumSentinelUnrolled<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumBitmapNaive<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumBitmapReader<int64_t>)->Apply(SetArgs); -BENCHMARK_TEMPLATE(BenchSum, SumBitmapVectorizeUnroll<int64_t>)->Apply(SetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumNoNulls<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumNoNullsUnrolled<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumSentinel<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumSentinelUnrolled<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumBitmapNaive<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumBitmapReader<int64_t>)->Apply(BenchmarkSetArgs); +BENCHMARK_TEMPLATE(BenchSum, SumBitmapVectorizeUnroll<int64_t>)->Apply(BenchmarkSetArgs); static void BenchSumKernel(benchmark::State& state) { const int64_t array_size = state.range(0) / sizeof(int64_t); @@ -357,7 +328,7 @@ static void BenchSumKernel(benchmark::State& state) { state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); } -BENCHMARK(BenchSumKernel)->Apply(SetArgs); +BENCHMARK(BenchSumKernel)->Apply(BenchmarkSetArgs); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc new file mode 100644 index 0000000..10ef766 --- /dev/null +++ 
b/cpp/src/arrow/compute/kernels/compare.cc @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/compare.h" + +#include "arrow/compute/context.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/filter.h" +#include "arrow/compute/kernels/util-internal.h" +#include "arrow/util/bit-util.h" +#include "arrow/util/logging.h" + +namespace arrow { + +namespace compute { + +class FunctionContext; +struct Datum; + +template <typename ArrowType, CompareOperator Op, + typename ArrayType = typename TypeTraits<ArrowType>::ArrayType, + typename ScalarType = typename TypeTraits<ArrowType>::ScalarType, + typename T = typename TypeTraits<ArrowType>::CType> +static Status CompareArrayScalar(const ArrayData& input, const ScalarType& scalar, + uint8_t* bitmap) { + const T right = scalar.value; + const T* values = input.GetValues<T>(1); + + size_t i = 0; + internal::GenerateBitsUnrolled(bitmap, 0, input.length, [values, right, &i]() -> bool { + return Comparator<T, Op>::Compare(values[i++], right); + }); + + return Status::OK(); +} + +template <typename ArrowType, CompareOperator Op> +class CompareFunction final : public FilterFunction { + using ArrayType = typename 
TypeTraits<ArrowType>::ArrayType; + using ScalarType = typename TypeTraits<ArrowType>::ScalarType; + + public: + explicit CompareFunction(FunctionContext* ctx) : ctx_(ctx) {} + + Status Filter(const ArrayData& input, const Scalar& scalar, ArrayData* output) const { + // Caller must cast + DCHECK(input.type->Equals(scalar.type)); + // Output must be a boolean array + DCHECK(output->type->Equals(boolean())); + // Output must be of same length + DCHECK_EQ(output->length, input.length); + + // Scalar is null, all comparisons are null. + if (!scalar.is_valid) { + return detail::SetAllNulls(ctx_, input, output); + } + + // Copy null_bitmap + RETURN_NOT_OK(detail::PropagateNulls(ctx_, input, output)); + + uint8_t* bitmap_result = output->buffers[1]->mutable_data(); + return CompareArrayScalar<ArrowType, Op>( + input, static_cast<const ScalarType&>(scalar), bitmap_result); + } + + private: + FunctionContext* ctx_; +}; + +template <typename ArrowType, CompareOperator Op> +static inline std::shared_ptr<FilterFunction> MakeCompareFunctionTypeOp( + FunctionContext* ctx) { + return std::make_shared<CompareFunction<ArrowType, Op>>(ctx); +} + +template <typename ArrowType> +static inline std::shared_ptr<FilterFunction> MakeCompareFilterFunctionType( + FunctionContext* ctx, struct CompareOptions options) { + switch (options.op) { + case CompareOperator::EQUAL: + return MakeCompareFunctionTypeOp<ArrowType, CompareOperator::EQUAL>(ctx); + case CompareOperator::NOT_EQUAL: + return MakeCompareFunctionTypeOp<ArrowType, CompareOperator::NOT_EQUAL>(ctx); + case CompareOperator::GREATER: + return MakeCompareFunctionTypeOp<ArrowType, CompareOperator::GREATER>(ctx); + case CompareOperator::GREATER_EQUAL: + return MakeCompareFunctionTypeOp<ArrowType, CompareOperator::GREATER_EQUAL>(ctx); + case CompareOperator::LOWER: + return MakeCompareFunctionTypeOp<ArrowType, CompareOperator::LOWER>(ctx); + case CompareOperator::LOWER_EQUAL: + return MakeCompareFunctionTypeOp<ArrowType, 
CompareOperator::LOWER_EQUAL>(ctx); + } + + return nullptr; +} + +std::shared_ptr<FilterFunction> MakeCompareFilterFunction(FunctionContext* ctx, + const DataType& type, + struct CompareOptions options) { + switch (type.id()) { + case UInt8Type::type_id: + return MakeCompareFilterFunctionType<UInt8Type>(ctx, options); + case Int8Type::type_id: + return MakeCompareFilterFunctionType<Int8Type>(ctx, options); + case UInt16Type::type_id: + return MakeCompareFilterFunctionType<UInt16Type>(ctx, options); + case Int16Type::type_id: + return MakeCompareFilterFunctionType<Int16Type>(ctx, options); + case UInt32Type::type_id: + return MakeCompareFilterFunctionType<UInt32Type>(ctx, options); + case Int32Type::type_id: + return MakeCompareFilterFunctionType<Int32Type>(ctx, options); + case UInt64Type::type_id: + return MakeCompareFilterFunctionType<UInt64Type>(ctx, options); + case Int64Type::type_id: + return MakeCompareFilterFunctionType<Int64Type>(ctx, options); + case FloatType::type_id: + return MakeCompareFilterFunctionType<FloatType>(ctx, options); + case DoubleType::type_id: + return MakeCompareFilterFunctionType<DoubleType>(ctx, options); + case Date32Type::type_id: + return MakeCompareFilterFunctionType<Date32Type>(ctx, options); + case Date64Type::type_id: + return MakeCompareFilterFunctionType<Date64Type>(ctx, options); + case TimestampType::type_id: + return MakeCompareFilterFunctionType<TimestampType>(ctx, options); + case Time32Type::type_id: + return MakeCompareFilterFunctionType<Time32Type>(ctx, options); + case Time64Type::type_id: + return MakeCompareFilterFunctionType<Time64Type>(ctx, options); + default: + return nullptr; + } +} + +ARROW_EXPORT +Status Compare(FunctionContext* context, const Datum& left, const Datum& right, + struct CompareOptions options, Datum* out) { + DCHECK(out); + + DCHECK_EQ(left.kind(), Datum::ARRAY); + DCHECK_EQ(right.kind(), Datum::SCALAR); + DCHECK(left.type()->Equals(right.type())); + + auto array = left.make_array(); + auto 
type = array->type(); + + auto fn = MakeCompareFilterFunction(context, *type, options); + if (fn == nullptr) { + return Status::NotImplemented("Compare not implemented for type ", type->ToString()); + } + + FilterBinaryKernel filter_kernel(fn); + detail::PrimitiveAllocatingBinaryKernel kernel(&filter_kernel); + out->value = ArrayData::Make(filter_kernel.out_type(), array->length()); + + return kernel.Call(context, left, right, out); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare.h b/cpp/src/arrow/compute/kernels/compare.h new file mode 100644 index 0000000..7f6b299 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/compare.h @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> + +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; +class DataType; +struct Scalar; +class Status; + +namespace compute { + +struct Datum; +class FilterFunction; +class FunctionContext; + +enum CompareOperator { + EQUAL, + NOT_EQUAL, + GREATER, + GREATER_EQUAL, + LOWER, + LOWER_EQUAL, +}; + +template <typename T, CompareOperator Op> +struct Comparator; + +template <typename T> +struct Comparator<T, CompareOperator::EQUAL> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs == rhs; } +}; + +template <typename T> +struct Comparator<T, CompareOperator::NOT_EQUAL> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs != rhs; } +}; + +template <typename T> +struct Comparator<T, CompareOperator::GREATER> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs > rhs; } +}; + +template <typename T> +struct Comparator<T, CompareOperator::GREATER_EQUAL> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs >= rhs; } +}; + +template <typename T> +struct Comparator<T, CompareOperator::LOWER> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs < rhs; } +}; + +template <typename T> +struct Comparator<T, CompareOperator::LOWER_EQUAL> { + constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs <= rhs; } +}; + +struct CompareOptions { + explicit CompareOptions(CompareOperator op) : op(op) {} + + enum CompareOperator op; +}; + +/// \brief Return a Compare FilterFunction +/// +/// \param[in] context FunctionContext passing context information +/// \param[in] type required to specialize the kernel +/// \param[in] options required to specify the compare operator +/// +/// \since 0.13.0 +/// \note API not yet finalized +ARROW_EXPORT +std::shared_ptr<FilterFunction> MakeCompareFilterFunction(FunctionContext* context, + const DataType& type, + struct CompareOptions options); + +/// \brief Compare a numeric 
array with a scalar. +/// +/// \param[in] context the FunctionContext +/// \param[in] left datum to compare, must be an Array +/// \param[in] right datum to compare, must be a Scalar of the same type as +/// left Datum. +/// \param[in] options compare options +/// \param[out] out resulting datum +/// +/// Note: on floating-point arrays, this uses IEEE 754 comparison semantics. +/// +/// \since 0.13.0 +/// \note API not yet finalized +ARROW_EXPORT +Status Compare(FunctionContext* context, const Datum& left, const Datum& right, + struct CompareOptions options, Datum* out); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc new file mode 100644 index 0000000..3826e26 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "benchmark/benchmark.h" + +#include <vector> + +#include "arrow/compute/benchmark-util.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +static void BenchCompareKernel(benchmark::State& state) { + const int64_t memory_size = state.range(0) / 4; + const int64_t array_size = memory_size / sizeof(int64_t); + const double null_percent = static_cast<double>(state.range(1)) / 100.0; + auto rand = random::RandomArrayGenerator(0x94378165); + auto array = std::static_pointer_cast<NumericArray<Int64Type>>( + rand.Int64(array_size, -100, 100, null_percent)); + + CompareOptions ge(GREATER_EQUAL); + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Compare(&ctx, Datum(array), Datum(int64_t(0)), ge, &out)); + benchmark::DoNotOptimize(out); + } + + state.counters["size"] = static_cast<double>(memory_size); + state.counters["null_percent"] = static_cast<double>(state.range(1)); + state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); +} + +BENCHMARK(BenchCompareKernel)->Apply(BenchmarkSetArgs); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc new file mode 100644 index 0000000..1e35ee9 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/filter-test.cc @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> + +#include <gtest/gtest.h> + +#include "arrow/array.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/kernels/filter.h" +#include "arrow/compute/test-util.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" + +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +TEST(TestComparatorOperator, BasicOperator) { + using T = int32_t; + std::vector<T> vals{0, 1, 2, 3, 4, 5, 6}; + + for (int32_t i : vals) { + for (int32_t j : vals) { + EXPECT_EQ((Comparator<T, EQUAL>::Compare(i, j)), i == j); + EXPECT_EQ((Comparator<T, NOT_EQUAL>::Compare(i, j)), i != j); + EXPECT_EQ((Comparator<T, GREATER>::Compare(i, j)), i > j); + EXPECT_EQ((Comparator<T, GREATER_EQUAL>::Compare(i, j)), i >= j); + EXPECT_EQ((Comparator<T, LOWER>::Compare(i, j)), i < j); + EXPECT_EQ((Comparator<T, LOWER_EQUAL>::Compare(i, j)), i <= j); + } + } +} + +template <typename ArrowType> +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs, const Datum& expected) { + Datum result; + + ASSERT_OK(Compare(ctx, lhs, rhs, options, &result)); + AssertArraysEqual(*expected.make_array(), *result.make_array()); +} + +template <typename ArrowType> +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const char* 
lhs_str, const Datum& rhs, + const char* expected_str) { + auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str); + auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str); + ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected); +} + +template <typename T> +static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { + switch (op) { + case EQUAL: + return lhs == rhs; + case NOT_EQUAL: + return lhs != rhs; + case GREATER: + return lhs > rhs; + case GREATER_EQUAL: + return lhs >= rhs; + case LOWER: + return lhs < rhs; + case LOWER_EQUAL: + return lhs <= rhs; + default: + return false; + } +} + +template <typename ArrowType> +static Datum SimpleCompare(CompareOptions options, const Datum& lhs, const Datum& rhs) { + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ScalarType = typename TypeTraits<ArrowType>::ScalarType; + using T = typename TypeTraits<ArrowType>::CType; + + auto array = std::static_pointer_cast<ArrayType>(lhs.make_array()); + T value = std::static_pointer_cast<ScalarType>(rhs.scalar())->value; + + std::vector<bool> bitmap(array->length()); + for (int64_t i = 0; i < array->length(); i++) { + bitmap[i] = SlowCompare<T>(options.op, array->Value(i), value); + } + + std::shared_ptr<Array> result; + + if (array->null_count() == 0) { + ArrayFromVector<BooleanType>(bitmap, &result); + } else { + std::vector<bool> null_bitmap(array->length()); + auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), + array->length()); + for (int64_t i = 0; i < array->length(); i++, reader.Next()) + null_bitmap[i] = reader.IsSet(); + ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result); + } + + return Datum(result); +} + +template <typename ArrowType> +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs) { + Datum result; + Datum expected = SimpleCompare<ArrowType>(options, lhs, rhs); + + 
ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected); +} + +template <typename ArrowType> +class TestNumericCompareKernel : public ComputeFixture, public TestBase {}; + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { + using ScalarType = typename TypeTraits<TypeParam>::ScalarType; + using CType = typename TypeTraits<TypeParam>::CType; + + Datum one(std::make_shared<ScalarType>(CType(1))); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions neq(CompareOperator::NOT_EQUAL); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); + ValidateCompare<TypeParam>(&this->ctx_, neq, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions gt(CompareOperator::GREATER); + ValidateCompare<TypeParam>(&this->ctx_, gt, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, gt, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + 
ValidateCompare<TypeParam>(&this->ctx_, gt, "[null,0,1,1]", one, "[null,0,0,0]"); + + CompareOptions gte(CompareOperator::GREATER_EQUAL); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + ValidateCompare<TypeParam>(&this->ctx_, gte, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions lt(CompareOperator::LOWER); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lt, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions lte(CompareOperator::LOWER_EQUAL); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[]", one, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[null]", one, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare<TypeParam>(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { + /* Ensure that null scalar broadcast to all null results. 
*/ + using ScalarType = typename TypeTraits<TypeParam>::ScalarType; + using CType = typename TypeTraits<TypeParam>::CType; + + Datum null(std::make_shared<ScalarType>(CType(0), false)); + EXPECT_FALSE(null.scalar()->is_valid); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[]", null, "[]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[null]", null, "[null]"); + ValidateCompare<TypeParam>(&this->ctx_, eq, "[1,2,3]", null, "[null, null, null]"); +} + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) { + using ScalarType = typename TypeTraits<TypeParam>::ScalarType; + using CType = typename TypeTraits<TypeParam>::CType; + + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto length_adjust : {-2, -1, 0, 1, 2}) { + int64_t length = (1UL << i) + length_adjust; + auto array = Datum(rand.Numeric<TypeParam>(length, 0, 100, null_probability)); + auto zero = Datum(std::make_shared<ScalarType>(CType(50))); + auto options = CompareOptions(GREATER); + ValidateCompare<TypeParam>(&this->ctx_, options, array, zero); + } + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc new file mode 100644 index 0000000..d7fbf54 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/filter.h" + +#include "arrow/array.h" +#include "arrow/compute/kernel.h" + +namespace arrow { + +namespace compute { + +std::shared_ptr<DataType> FilterBinaryKernel::out_type() const { + return filter_function_->out_type(); +} + +Status FilterBinaryKernel::Call(FunctionContext* ctx, const Datum& left, + const Datum& right, Datum* out) { + auto array = left.array(); + auto scalar = right.scalar(); + auto result = out->array(); + + return filter_function_->Filter(*array, *scalar, result.get()); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h new file mode 100644 index 0000000..becd2d5 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/filter.h @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/compute/kernel.h" + +namespace arrow { + +class Array; +struct Scalar; +class Status; + +namespace compute { + +class FunctionContext; +struct Datum; + +/// FilterFunction is an interface for Filters +/// +/// A filter takes an array and emits a selection vector. The selection vector +/// is given in the form of a bitmask as a BooleanArray result. +class ARROW_EXPORT FilterFunction { + public: + /// Filter an array with a scalar argument. + virtual Status Filter(const ArrayData& input, const Scalar& scalar, + ArrayData* output) const = 0; + + /// By default, FilterFunction emits a result bitmap. + virtual std::shared_ptr<DataType> out_type() const { return boolean(); } + + virtual ~FilterFunction() {} +}; + +/// \brief BinaryKernel bound to a filter function +class ARROW_EXPORT FilterBinaryKernel : public BinaryKernel { + public: + explicit FilterBinaryKernel(std::shared_ptr<FilterFunction>& filter) + : filter_function_(filter) {} + + Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, + Datum* out) override; + + std::shared_ptr<DataType> out_type() const override; + + private: + std::shared_ptr<FilterFunction> filter_function_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util-internal-test.cc b/cpp/src/arrow/compute/kernels/util-internal-test.cc index 8ce3569..5140543 100644 --- a/cpp/src/arrow/compute/kernels/util-internal-test.cc +++ b/cpp/src/arrow/compute/kernels/util-internal-test.cc @@ -96,6 +96,21 @@ TEST(PropagateNulls, OffsetAndHasNulls) { Each(0)); } +TEST(SetAllNulls, Basic) { + const int64_t length = 16; + ArrayData input(boolean(), length); + FunctionContext ctx(default_memory_pool()); + ArrayData output; + + ASSERT_OK(SetAllNulls(&ctx, input, &output)); + ASSERT_THAT(output.null_count, Eq(length)); + + const auto& 
output_buffer = *output.buffers[0]; + ASSERT_THAT(std::vector<uint8_t>(output_buffer.data(), + output_buffer.data() + output_buffer.size()), + Each(0)); +} + TEST(AssignNullIntersection, ZeroCopyWhenZeroNullsOnOneInput) { ArrayData some_nulls(boolean(), /* length= */ 16, kUnknownNullCount); constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0}; diff --git a/cpp/src/arrow/compute/kernels/util-internal.cc b/cpp/src/arrow/compute/kernels/util-internal.cc index 174b688..2f94407 100644 --- a/cpp/src/arrow/compute/kernels/util-internal.cc +++ b/cpp/src/arrow/compute/kernels/util-internal.cc @@ -236,6 +236,26 @@ Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* o return Status::OK(); } +Status SetAllNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output) { + const int64_t length = input.length; + if (output->buffers.size() == 0) { + // Ensure we can assign a buffer + output->buffers.resize(1); + } + + // Handle validity bitmap + if (output->buffers[0] == nullptr) { + std::shared_ptr<Buffer> buffer; + RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(length), &buffer)); + output->buffers[0] = std::move(buffer); + } + + output->null_count = length; + BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), 0, length, false); + + return Status::OK(); +} + Status AssignNullIntersection(FunctionContext* ctx, const ArrayData& left, const ArrayData& right, ArrayData* output) { if (output->buffers.size() == 0) { @@ -288,9 +308,7 @@ Status PrimitiveAllocatingBinaryKernel::Call(FunctionContext* ctx, const Datum& const Datum& right, Datum* out) { std::vector<std::shared_ptr<Buffer>> data_buffers; DCHECK_EQ(left.kind(), Datum::ARRAY); - DCHECK_EQ(right.kind(), Datum::ARRAY); const ArrayData& left_data = *left.array(); - DCHECK_EQ(left_data.length, right.array()->length); DCHECK_EQ(out->kind(), Datum::ARRAY); diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h index 
4cb7a24..25a670c 100644 --- a/cpp/src/arrow/compute/kernels/util-internal.h +++ b/cpp/src/arrow/compute/kernels/util-internal.h @@ -72,6 +72,14 @@ Status InvokeBinaryArrayKernel(FunctionContext* ctx, BinaryKernel* kernel, ARROW_EXPORT Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output); +/// \brief Set the validity bitmap in output so that all values are null. +/// +/// \param[in] ctx the kernel FunctionContext +/// \param[in] input the input array; its length gives the number of values marked null +/// \param[out] output the output array; a validity bitmap is allocated when it has none. +ARROW_EXPORT +Status SetAllNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output); + /// \brief Assign validity bitmap to output, taking the intersection of left and right /// null bitmaps if necessary, but zero-copy otherwise. /// diff --git a/cpp/src/arrow/ipc/json-simple-test.cc b/cpp/src/arrow/ipc/json-simple-test.cc index 238061f..1bb04a3 100644 --- a/cpp/src/arrow/ipc/json-simple-test.cc +++ b/cpp/src/arrow/ipc/json-simple-test.cc @@ -232,13 +232,17 @@ TEST(TestBoolean, Basics) { AssertJSONArray<BooleanType, bool>(type, "[false, true, false]", {false, true, false}); AssertJSONArray<BooleanType, bool>(type, "[false, true, null]", {true, true, false}, {false, true, false}); + // Supports integer literal casting + AssertJSONArray<BooleanType, bool>(type, "[0, 1, 0]", {false, true, false}); + AssertJSONArray<BooleanType, bool>(type, "[0, 1, null]", {true, true, false}, + {false, true, false}); } TEST(TestBoolean, Errors) { std::shared_ptr<DataType> type = boolean(); std::shared_ptr<Array> array; - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array)); + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0.0]", &array)); ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"true\"]", &array)); } diff --git a/cpp/src/arrow/ipc/json-simple.cc b/cpp/src/arrow/ipc/json-simple.cc index 047788c..2861bd7 100644 --- a/cpp/src/arrow/ipc/json-simple.cc +++ b/cpp/src/arrow/ipc/json-simple.cc @@ -137,6 +137,9 @@ 
class BooleanConverter final : public ConcreteConverter<BooleanConverter> { if (json_obj.IsBool()) { return builder_->Append(json_obj.GetBool()); } + if (json_obj.IsInt()) { + return builder_->Append(json_obj.GetInt() != 0); + } return JSONTypeError("boolean", json_obj.GetType()); } diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 263916f..a8d6214 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -76,9 +76,9 @@ struct CTypeTraits<bool> : public TypeTraits<BooleanType> { using BuilderType = ArrowBuilderType; \ using ScalarType = ArrowScalarType; \ using TensorType = ArrowTensorType; \ - using CType = CType_; \ + using CType = ArrowType_::c_type; \ static constexpr int64_t bytes_required(int64_t elements) { \ - return elements * static_cast<int64_t>(sizeof(CType_)); \ + return elements * static_cast<int64_t>(sizeof(CType)); \ } \ constexpr static bool is_parameter_free = true; \ static inline std::shared_ptr<DataType> type_singleton() { return SingletonFn(); } \ @@ -114,6 +114,7 @@ struct TypeTraits<Date64Type> { using ArrayType = Date64Array; using BuilderType = Date64Builder; using ScalarType = Date64Scalar; + using CType = Date64Type::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int64_t)); @@ -127,6 +128,7 @@ struct TypeTraits<Date32Type> { using ArrayType = Date32Array; using BuilderType = Date32Builder; using ScalarType = Date32Scalar; + using CType = Date32Type::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int32_t)); @@ -140,6 +142,7 @@ struct TypeTraits<TimestampType> { using ArrayType = TimestampArray; using BuilderType = TimestampBuilder; using ScalarType = TimestampScalar; + using CType = TimestampType::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int64_t)); @@ -152,6 +155,7 @@ struct 
TypeTraits<Time32Type> { using ArrayType = Time32Array; using BuilderType = Time32Builder; using ScalarType = Time32Scalar; + using CType = Time32Type::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int32_t)); @@ -164,6 +168,7 @@ struct TypeTraits<Time64Type> { using ArrayType = Time64Array; using BuilderType = Time64Builder; using ScalarType = Time64Scalar; + using CType = Time64Type::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int64_t));