dhruv9vats commented on a change in pull request #12724:
URL: https://github.com/apache/arrow/pull/12724#discussion_r835806658



##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) {
   DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel)));
 }
 
+struct CastStructSubset {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const CastOptions& options = CastState::Get(ctx);
+    const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
+    const auto& out_type = checked_cast<const StructType&>(*out->type());
+    const int in_field_count = in_type.num_fields();
+    const int out_field_count = out_type.num_fields();
+
+    std::vector<bool> fields_to_select(in_field_count, false);
+
+    int out_field_index = 0;
+    for (int in_field_index = 0;
+         in_field_index < in_field_count && out_field_index < out_field_count;
+         ++in_field_index) {
+      const auto in_field = in_type.field(in_field_index);
+      const auto out_field = out_type.field(out_field_index);
+      if (in_field->name() == out_field->name()) {
+        if (in_field->nullable() && !out_field->nullable()) {
+          return Status::TypeError("cannot cast nullable struct to 
non-nullable struct: ",
+                                   in_type.ToString(), " ", 
out_type.ToString());
+        }
+        fields_to_select[in_field_index] = true;
+        ++out_field_index;
+      }
+    }
+
+    if (out_field_index < out_field_count - 1) {
+      return Status::TypeError(
+          "struct subfields names don't match or are in the wrong order: ",
+          in_type.ToString(), " ", out_type.ToString());
+    }
+
+    if (out->kind() == Datum::SCALAR) {
+      const auto& in_scalar = checked_cast<const 
StructScalar&>(*batch[0].scalar());
+      auto out_scalar = checked_cast<StructScalar*>(out->scalar().get());
+
+      DCHECK(!out_scalar->is_valid);
+      if (in_scalar.is_valid) {
+        out_field_index = 0;
+        for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+          if (fields_to_select[in_field_index]) {
+            auto values = in_scalar.value[in_field_index];
+            auto target_type = out->type()->field(out_field_index++)->type();
+            ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, 
options,
+                                                          
ctx->exec_context()));
+            DCHECK_EQ(Datum::SCALAR, cast_values.kind());
+            out_scalar->value.push_back(cast_values.scalar());
+          }
+        }
+        out_scalar->is_valid = true;
+      }
+      return Status::OK();
+    }
+
+    const ArrayData& in_array = *batch[0].array();
+    ArrayData* out_array = out->mutable_array();
+
+    if (in_array.GetNullCount() > 0) {
+      auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool());
+      const auto in_bitmap = in_array.buffers[0]->data();
+
+      for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+        if (fields_to_select[in_field_index]) {
+          if (bit_util::GetBit(in_bitmap, in_array.offset + in_field_index)) {
+            ARROW_RETURN_NOT_OK(out_bitmap_builder.Append(true));
+          } else {
+            ARROW_RETURN_NOT_OK(out_bitmap_builder.Append(false));
+          }
+        }
+      }
+      ARROW_ASSIGN_OR_RAISE(out_array->buffers[0], 
out_bitmap_builder.Finish());
+    }
+
+    out_field_index = 0;
+    for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+      if (fields_to_select[in_field_index]) {
+        auto values =
+            in_array.child_data[in_field_index]->Slice(in_array.offset, 
in_array.length);
+        auto target_type = out->type()->field(out_field_index++)->type();
+
+        ARROW_ASSIGN_OR_RAISE(Datum cast_values,
+                              Cast(values, target_type, options, 
ctx->exec_context()));
+
+        DCHECK_EQ(Datum::ARRAY, cast_values.kind());
+        out_array->child_data.push_back(cast_values.array());
+      }
+    }
+
+    return Status::OK();
+  }
+};
+
+void AddStructToStructSubsetCast(CastFunction* func) {
+  ScalarKernel kernel;
+  kernel.exec = CastStructSubset::Exec;
+  kernel.signature =
+      KernelSignature::Make({InputType(StructType::type_id)}, 
kOutputTargetType);
+  kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+  DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel)));
+}

Review comment:
       How should this casting mechanism be selected over the one that 
already exists for StructType? Would the dispatch involve checking 
`out_type.num_fields()`, and choosing this new mechanism when 
`out_type.num_fields() < in_type.num_fields()`? If so, how could that be 
achieved?

##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) {
   DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel)));
 }
 
+struct CastStructSubset {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const CastOptions& options = CastState::Get(ctx);
+    const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
+    const auto& out_type = checked_cast<const StructType&>(*out->type());
+    const int in_field_count = in_type.num_fields();
+    const int out_field_count = out_type.num_fields();
+
+    std::vector<bool> fields_to_select(in_field_count, false);
+
+    int out_field_index = 0;
+    for (int in_field_index = 0;
+         in_field_index < in_field_count && out_field_index < out_field_count;
+         ++in_field_index) {
+      const auto in_field = in_type.field(in_field_index);
+      const auto out_field = out_type.field(out_field_index);
+      if (in_field->name() == out_field->name()) {
+        if (in_field->nullable() && !out_field->nullable()) {
+          return Status::TypeError("cannot cast nullable struct to 
non-nullable struct: ",
+                                   in_type.ToString(), " ", 
out_type.ToString());
+        }
+        fields_to_select[in_field_index] = true;
+        ++out_field_index;
+      }
+    }
+
+    if (out_field_index < out_field_count - 1) {
+      return Status::TypeError(
+          "struct subfields names don't match or are in the wrong order: ",
+          in_type.ToString(), " ", out_type.ToString());
+    }
+
+    if (out->kind() == Datum::SCALAR) {
+      const auto& in_scalar = checked_cast<const 
StructScalar&>(*batch[0].scalar());
+      auto out_scalar = checked_cast<StructScalar*>(out->scalar().get());
+
+      DCHECK(!out_scalar->is_valid);
+      if (in_scalar.is_valid) {
+        out_field_index = 0;
+        for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+          if (fields_to_select[in_field_index]) {
+            auto values = in_scalar.value[in_field_index];
+            auto target_type = out->type()->field(out_field_index++)->type();
+            ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, 
options,
+                                                          
ctx->exec_context()));
+            DCHECK_EQ(Datum::SCALAR, cast_values.kind());
+            out_scalar->value.push_back(cast_values.scalar());
+          }
+        }
+        out_scalar->is_valid = true;
+      }
+      return Status::OK();
+    }
+
+    const ArrayData& in_array = *batch[0].array();
+    ArrayData* out_array = out->mutable_array();
+
+    if (in_array.GetNullCount() > 0) {
+      auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool());
+      const auto in_bitmap = in_array.buffers[0]->data();
+
+      for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+        if (fields_to_select[in_field_index]) {
+          if (bit_util::GetBit(in_bitmap, in_array.offset + in_field_index)) {

Review comment:
       Does the `in_array.offset + in_field_index` index on this line make sense?

##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -252,6 +353,7 @@ std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() 
{
   // So is struct
   auto cast_struct = std::make_shared<CastFunction>("cast_struct", 
Type::STRUCT);
   AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get());
+  AddStructToStructSubsetCast(cast_struct.get());

Review comment:
       (This is just a hacky, temporary way to invoke the new casting mechanism 
so that its output can be tested.) It needs an overhaul.

##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -228,6 +228,107 @@ void AddStructToStructCast(CastFunction* func) {
   DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel)));
 }
 
+struct CastStructSubset {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const CastOptions& options = CastState::Get(ctx);
+    const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
+    const auto& out_type = checked_cast<const StructType&>(*out->type());
+    const int in_field_count = in_type.num_fields();
+    const int out_field_count = out_type.num_fields();
+
+    std::vector<bool> fields_to_select(in_field_count, false);
+
+    int out_field_index = 0;
+    for (int in_field_index = 0;
+         in_field_index < in_field_count && out_field_index < out_field_count;
+         ++in_field_index) {
+      const auto in_field = in_type.field(in_field_index);
+      const auto out_field = out_type.field(out_field_index);
+      if (in_field->name() == out_field->name()) {
+        if (in_field->nullable() && !out_field->nullable()) {
+          return Status::TypeError("cannot cast nullable struct to 
non-nullable struct: ",
+                                   in_type.ToString(), " ", 
out_type.ToString());
+        }
+        fields_to_select[in_field_index] = true;
+        ++out_field_index;
+      }
+    }
+
+    if (out_field_index < out_field_count - 1) {
+      return Status::TypeError(
+          "struct subfields names don't match or are in the wrong order: ",
+          in_type.ToString(), " ", out_type.ToString());
+    }
+
+    if (out->kind() == Datum::SCALAR) {
+      const auto& in_scalar = checked_cast<const 
StructScalar&>(*batch[0].scalar());
+      auto out_scalar = checked_cast<StructScalar*>(out->scalar().get());
+
+      DCHECK(!out_scalar->is_valid);
+      if (in_scalar.is_valid) {
+        out_field_index = 0;
+        for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {
+          if (fields_to_select[in_field_index]) {
+            auto values = in_scalar.value[in_field_index];
+            auto target_type = out->type()->field(out_field_index++)->type();
+            ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, 
options,
+                                                          
ctx->exec_context()));
+            DCHECK_EQ(Datum::SCALAR, cast_values.kind());
+            out_scalar->value.push_back(cast_values.scalar());
+          }
+        }
+        out_scalar->is_valid = true;
+      }
+      return Status::OK();
+    }
+
+    const ArrayData& in_array = *batch[0].array();
+    ArrayData* out_array = out->mutable_array();
+
+    if (in_array.GetNullCount() > 0) {
+      auto out_bitmap_builder = TypedBufferBuilder<bool>(ctx->memory_pool());
+      const auto in_bitmap = in_array.buffers[0]->data();
+
+      for (int in_field_index = 0; in_field_index < in_field_count; 
in_field_index++) {

Review comment:
       I have separated the null bitmap construction from the main loop to avoid 
unnecessary branching when there are no nulls. Is this okay?

##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_test.cc
##########
@@ -2329,6 +2329,43 @@ TEST(Cast, 
StructToSameSizedButDifferentNullabilityStruct) {
       Cast(src2, options));
 }
 
+TEST(Cats, StructSubset) {

Review comment:
       ```suggestion
   TEST(Cast, StructSubset) {
   ```
   🐱 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to