wesm commented on a change in pull request #7382: URL: https://github.com/apache/arrow/pull/7382#discussion_r437580946
########## File path: cpp/src/arrow/compute/kernels/vector_take.cc ########## @@ -38,44 +55,715 @@ std::unique_ptr<KernelState> InitTake(KernelContext*, const KernelInitArgs& args return std::unique_ptr<KernelState>(new TakeState{*take_options}); } -template <typename ValueType, typename IndexType> -struct TakeFunctor { - using ValueArrayType = typename TypeTraits<ValueType>::ArrayType; - using IndexArrayType = typename TypeTraits<IndexType>::ArrayType; - using IS = ArrayIndexSequence<IndexType>; +namespace {} // namespace + +// ---------------------------------------------------------------------- +// Implement optimized take for primitive types from boolean to 1/2/4/8-byte +// C-type based types. Use common implementation for every byte width and only +// generate code for unsigned integer indices, since after boundschecking to +// check for negative numbers the indices we can safely reinterpret_cast signed +// integers as unsigned. + +struct PrimitiveTakeArgs { + const uint8_t* values; + const uint8_t* values_bitmap = nullptr; + int values_bit_width; + int64_t values_length; + int64_t values_offset; + int64_t values_null_count; + const uint8_t* indices; + const uint8_t* indices_bitmap = nullptr; + int indices_bit_width; + int64_t indices_length; + int64_t indices_offset; + int64_t indices_null_count; +}; + +// Reduce code size by dealing with the unboxing of the kernel inputs once +// rather than duplicating compiled code to do all these in each kernel. +PrimitiveTakeArgs GetPrimitiveTakeArgs(const ExecBatch& batch) { + PrimitiveTakeArgs args; + + const ArrayData& arg0 = *batch[0].array(); + const ArrayData& arg1 = *batch[1].array(); + + // Values + args.values_bit_width = static_cast<const FixedWidthType&>(*arg0.type).bit_width(); + args.values = arg0.buffers[1]->data(); + if (args.values_bit_width > 1) { + args.values += arg0.offset * args.values_bit_width / 8; + } + args.values_length = arg0.length; + args.values_offset = arg0.offset; + args.values_null_count = arg0.GetNullCount(); + if (arg0.buffers[0]) { + args.values_bitmap = arg0.buffers[0]->data(); + } + + // Indices + args.indices_bit_width = static_cast<const FixedWidthType&>(*arg1.type).bit_width(); + args.indices = arg1.buffers[1]->data() + arg1.offset * args.indices_bit_width / 8; + args.indices_length = arg1.length; + args.indices_offset = arg1.offset; + args.indices_null_count = arg1.GetNullCount(); + if (arg1.buffers[0]) { + args.indices_bitmap = arg1.buffers[0]->data(); + } + + return args; +} + +/// \brief The Take implementation for primitive (fixed-width) types does not +/// use the logical Arrow type but rather then physical C type. This way we +/// only generate one take function for each byte width. +/// +/// This function assumes that the indices have been boundschecked. +template <typename IndexCType, typename ValueCType> +struct PrimitiveTakeImpl { + static void Exec(const PrimitiveTakeArgs& args, Datum* out_datum) { + auto values = reinterpret_cast<const ValueCType*>(args.values); + auto values_bitmap = args.values_bitmap; + auto values_offset = args.values_offset; + + auto indices = reinterpret_cast<const IndexCType*>(args.indices); + auto indices_bitmap = args.indices_bitmap; + auto indices_offset = args.indices_offset; + + ArrayData* out_arr = out_datum->mutable_array(); + auto out = out_arr->GetMutableValues<ValueCType>(1); + auto out_bitmap = out_arr->buffers[0]->mutable_data(); + auto out_offset = out_arr->offset; + + // If either the values or indices have nulls, we preemptively zero out the + // out validity bitmap so that we don't have to use ClearBit in each + // iteration for nulls. + if (args.values_null_count > 0 || args.indices_null_count > 0) { + BitUtil::SetBitsTo(out_bitmap, out_offset, args.indices_length, false); + } + + OptionalBitBlockCounter indices_bit_counter(indices_bitmap, indices_offset, + args.indices_length); + int64_t position = 0; + int64_t valid_count = 0; + while (true) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (block.length == 0) { + // All indices processed. + break; + } + if (args.values_null_count == 0) { + // Values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither values nor index nulls + BitUtil::SetBitsTo(out_bitmap, out_offset + position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + out[position] = values[indices[position]]; + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_bitmap, indices_offset + position)) { + // index is not null + BitUtil::SetBit(out_bitmap, out_offset + position); + out[position] = values[indices[position]]; + } + ++position; + } + } + } else { + // Values have nulls, so we must do random access into the values bitmap + if (block.popcount == block.length) { + // Faster path: indices are not null but values may be + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(values_bitmap, values_offset + indices[position])) { + // value is not null + out[position] = values[indices[position]]; + BitUtil::SetBit(out_bitmap, out_offset + position); + ++valid_count; + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_bitmap, indices_offset + position)) { + // index is not null + if (BitUtil::GetBit(values_bitmap, values_offset + indices[position])) { + // value is not null + out[position] = values[indices[position]]; + BitUtil::SetBit(out_bitmap, out_offset + position); + ++valid_count; + } + } + ++position; + } + } + } + } + out_arr->null_count = out_arr->length - valid_count; + } +}; + +template <typename IndexCType> +struct BooleanTakeImpl { Review comment: TODO: I will add some random data unit tests for boolean type, which are only sparsely tested in the test suite ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org