dhruv9vats commented on a change in pull request #12724: URL: https://github.com/apache/arrow/pull/12724#discussion_r835806658
########## File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc ########## @@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) { DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel))); } +struct CastStructSubset { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const CastOptions& options = CastState::Get(ctx); + const auto& in_type = checked_cast<const StructType&>(*batch[0].type()); + const auto& out_type = checked_cast<const StructType&>(*out->type()); + const int in_field_count = in_type.num_fields(); + const int out_field_count = out_type.num_fields(); + + std::vector<bool> fields_to_select(in_field_count, false); + + int out_field_index = 0; + for (int in_field_index = 0; + in_field_index < in_field_count && out_field_index < out_field_count; + ++in_field_index) { + const auto in_field = in_type.field(in_field_index); + const auto out_field = out_type.field(out_field_index); + if (in_field->name() == out_field->name()) { + if (in_field->nullable() && !out_field->nullable()) { + return Status::TypeError("cannot cast nullable struct to non-nullable struct: ", + in_type.ToString(), " ", out_type.ToString()); + } + fields_to_select[in_field_index] = true; + ++out_field_index; + } + } + + if (out_field_index < out_field_count - 1) { + return Status::TypeError( + "struct subfields names don't match or are in the wrong order: ", + in_type.ToString(), " ", out_type.ToString()); + } + + if (out->kind() == Datum::SCALAR) { + const auto& in_scalar = checked_cast<const StructScalar&>(*batch[0].scalar()); + auto out_scalar = checked_cast<StructScalar*>(out->scalar().get()); + + DCHECK(!out_scalar->is_valid); + if (in_scalar.is_valid) { + out_field_index = 0; + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { + if (fields_to_select[in_field_index]) { + auto values = in_scalar.value[in_field_index]; + auto target_type = out->type()->field(out_field_index++)->type(); + 
ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, options, + ctx->exec_context())); + DCHECK_EQ(Datum::SCALAR, cast_values.kind()); + out_scalar->value.push_back(cast_values.scalar()); + } + } + out_scalar->is_valid = true; + } + return Status::OK(); + } + + const ArrayData& in_array = *batch[0].array(); + ArrayData* out_array = out->mutable_array(); + + if (in_array.GetNullCount() > 0) { + auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool()); + const auto in_bitmap = in_array.buffers[0]->data(); + + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { + if (fields_to_select[in_field_index]) { + if (bit_util::GetBit(in_bitmap, in_array.offset + in_field_index)) { + ARROW_RETURN_NOT_OK(out_bitmap_builder.Append(true)); + } else { + ARROW_RETURN_NOT_OK(out_bitmap_builder.Append(false)); + } + } + } + ARROW_ASSIGN_OR_RAISE(out_array->buffers[0], out_bitmap_builder.Finish()); + } + + out_field_index = 0; + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { + if (fields_to_select[in_field_index]) { + auto values = + in_array.child_data[in_field_index]->Slice(in_array.offset, in_array.length); + auto target_type = out->type()->field(out_field_index++)->type(); + + ARROW_ASSIGN_OR_RAISE(Datum cast_values, + Cast(values, target_type, options, ctx->exec_context())); + + DCHECK_EQ(Datum::ARRAY, cast_values.kind()); + out_array->child_data.push_back(cast_values.array()); + } + } + + return Status::OK(); + } +}; + +void AddStructToStructSubsetCast(CastFunction* func) { + ScalarKernel kernel; + kernel.exec = CastStructSubset::Exec; + kernel.signature = + KernelSignature::Make({InputType(StructType::type_id)}, kOutputTargetType); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel))); +} Review comment: What would be the way to select this casting mechanism over the one that already exists for StructType? 
Will it include checking `out_type.num_fields()`? And selecting the current mechanism if `out_type.num_fields() < in_type.num_fields()`? If so, how could this be achieved? ########## File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc ########## @@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) { DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel))); } +struct CastStructSubset { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const CastOptions& options = CastState::Get(ctx); + const auto& in_type = checked_cast<const StructType&>(*batch[0].type()); + const auto& out_type = checked_cast<const StructType&>(*out->type()); + const int in_field_count = in_type.num_fields(); + const int out_field_count = out_type.num_fields(); + + std::vector<bool> fields_to_select(in_field_count, false); + + int out_field_index = 0; + for (int in_field_index = 0; + in_field_index < in_field_count && out_field_index < out_field_count; + ++in_field_index) { + const auto in_field = in_type.field(in_field_index); + const auto out_field = out_type.field(out_field_index); + if (in_field->name() == out_field->name()) { + if (in_field->nullable() && !out_field->nullable()) { + return Status::TypeError("cannot cast nullable struct to non-nullable struct: ", + in_type.ToString(), " ", out_type.ToString()); + } + fields_to_select[in_field_index] = true; + ++out_field_index; + } + } + + if (out_field_index < out_field_count - 1) { + return Status::TypeError( + "struct subfields names don't match or are in the wrong order: ", + in_type.ToString(), " ", out_type.ToString()); + } + + if (out->kind() == Datum::SCALAR) { + const auto& in_scalar = checked_cast<const StructScalar&>(*batch[0].scalar()); + auto out_scalar = checked_cast<StructScalar*>(out->scalar().get()); + + DCHECK(!out_scalar->is_valid); + if (in_scalar.is_valid) { + out_field_index = 0; + for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) { + if (fields_to_select[in_field_index]) { + auto values = in_scalar.value[in_field_index]; + auto target_type = out->type()->field(out_field_index++)->type(); + ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, options, + ctx->exec_context())); + DCHECK_EQ(Datum::SCALAR, cast_values.kind()); + out_scalar->value.push_back(cast_values.scalar()); + } + } + out_scalar->is_valid = true; + } + return Status::OK(); + } + + const ArrayData& in_array = *batch[0].array(); + ArrayData* out_array = out->mutable_array(); + + if (in_array.GetNullCount() > 0) { + auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool()); + const auto in_bitmap = in_array.buffers[0]->data(); + + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { + if (fields_to_select[in_field_index]) { + if (bit_util::GetBit(in_bitmap, in_array.offset + in_field_index)) { Review comment: Does the `in_array.offset + in_field_index` part in the line make sense? ########## File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc ########## @@ -252,6 +353,7 @@ std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() { // So is struct auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT); AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get()); + AddStructToStructSubsetCast(cast_struct.get()); Review comment: (Just a hacky way to temporarily call the current casting mechanism to test its output). Needs overhaul. 
########## File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc ########## @@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) { DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel))); } +struct CastStructSubset { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const CastOptions& options = CastState::Get(ctx); + const auto& in_type = checked_cast<const StructType&>(*batch[0].type()); + const auto& out_type = checked_cast<const StructType&>(*out->type()); + const int in_field_count = in_type.num_fields(); + const int out_field_count = out_type.num_fields(); + + std::vector<bool> fields_to_select(in_field_count, false); + + int out_field_index = 0; + for (int in_field_index = 0; + in_field_index < in_field_count && out_field_index < out_field_count; + ++in_field_index) { + const auto in_field = in_type.field(in_field_index); + const auto out_field = out_type.field(out_field_index); + if (in_field->name() == out_field->name()) { + if (in_field->nullable() && !out_field->nullable()) { + return Status::TypeError("cannot cast nullable struct to non-nullable struct: ", + in_type.ToString(), " ", out_type.ToString()); + } + fields_to_select[in_field_index] = true; + ++out_field_index; + } + } + + if (out_field_index < out_field_count - 1) { + return Status::TypeError( + "struct subfields names don't match or are in the wrong order: ", + in_type.ToString(), " ", out_type.ToString()); + } + + if (out->kind() == Datum::SCALAR) { + const auto& in_scalar = checked_cast<const StructScalar&>(*batch[0].scalar()); + auto out_scalar = checked_cast<StructScalar*>(out->scalar().get()); + + DCHECK(!out_scalar->is_valid); + if (in_scalar.is_valid) { + out_field_index = 0; + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { + if (fields_to_select[in_field_index]) { + auto values = in_scalar.value[in_field_index]; + auto target_type = out->type()->field(out_field_index++)->type(); + 
ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, options, + ctx->exec_context())); + DCHECK_EQ(Datum::SCALAR, cast_values.kind()); + out_scalar->value.push_back(cast_values.scalar()); + } + } + out_scalar->is_valid = true; + } + return Status::OK(); + } + + const ArrayData& in_array = *batch[0].array(); + ArrayData* out_array = out->mutable_array(); + + if (in_array.GetNullCount() > 0) { + auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool()); + const auto in_bitmap = in_array.buffers[0]->data(); + + for (int in_field_index = 0; in_field_index < in_field_count; in_field_index++) { Review comment: I have separated the null bitmap construction from the main loop to avoid unnecessary branching when there are no nulls. Is this okay? ########## File path: cpp/src/arrow/compute/kernels/scalar_cast_test.cc ########## @@ -2329,6 +2329,43 @@ TEST(Cast, StructToSameSizedButDifferentNullabilityStruct) { Cast(src2, options)); } +TEST(Cats, StructSubset) { Review comment: ```suggestion TEST(Cast, StructSubset) { ``` 🐱 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org