lidavidm commented on a change in pull request #12248:
URL: https://github.com/apache/arrow/pull/12248#discussion_r802636253
##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_test.cc
##########
@@ -2218,6 +2218,117 @@ TEST(Cast, ListToListOptionsPassthru) {
}
}
+// Casts struct<a: S, b: S> to struct<a: D, b: D> for every (S, D) pair drawn
+// from `value_types` (including S == D) and checks the child values survive.
+static void CheckStructToStruct(
+    const std::vector<std::shared_ptr<DataType>>& value_types) {
+  // Loop-invariant: every type pair uses the same field names.
+  const std::vector<std::string> field_names = {"a", "b"};
+  for (const auto& src_value_type : value_types) {
+    for (const auto& dest_value_type : value_types) {
+      // Without this, a failure inside the nested loops does not report which
+      // (source, destination) type pair triggered it.
+      SCOPED_TRACE(src_value_type->ToString() + " -> " +
+                   dest_value_type->ToString());
+      const auto a1 = ArrayFromJSON(src_value_type, "[1, 2, 3, 4, 5]");
+      const auto b1 = ArrayFromJSON(src_value_type, "[6, 7, 8, 9, 0]");
+      const auto a2 = ArrayFromJSON(dest_value_type, "[1, 2, 3, 4, 5]");
+      const auto b2 = ArrayFromJSON(dest_value_type, "[6, 7, 8, 9, 0]");
+      ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a1, b1}, field_names));
+      ASSERT_OK_AND_ASSIGN(auto dest, StructArray::Make({a2, b2}, field_names));
+
+      CheckCast(src, dest);
+    }
+  }
+}
+
+// Structs with identical field names and arity: every pairing of these child
+// types (including identity) must cast successfully.
+TEST(Cast, StructToSameSizedAndNamedStruct) {
+  const std::vector<std::shared_ptr<DataType>> child_types = {int32(), float32(),
+                                                              int64()};
+  CheckStructToStruct(child_types);
+}
+
+// The struct cast compares field names position by position, so casting
+// struct<a, b> to struct<c, d> must fail with a TypeError.
+TEST(Cast, StructToSameSizedButDifferentNamedStruct) {
+  const std::vector<std::string> field_names = {"a", "b"};
+  const auto a = ArrayFromJSON(int8(), "[1, 2]");
+  const auto b = ArrayFromJSON(int8(), "[3, 4]");
+  ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names));
+
+  // Only the target type matters for this check, so build it directly instead
+  // of materializing a throwaway destination array.
+  const auto dest_type = struct_({field("c", int8()), field("d", int8())});
+  auto options = CastOptions::Safe(dest_type);
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      TypeError,
+      ::testing::HasSubstr("Type error: struct field names do not match: "
+                           "struct<a: int8, b: int8> struct<c: int8, d: int8>"),
+      Cast(src, options));
+}
+
+// The struct cast requires equal field counts, so casting a two-field struct
+// to a three-field struct must fail with a TypeError.
+TEST(Cast, StructToDifferentSizeStruct) {
+  const std::vector<std::string> field_names = {"a", "b"};
+  const auto a = ArrayFromJSON(int8(), "[1, 2]");
+  const auto b = ArrayFromJSON(int8(), "[3, 4]");
+  ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names));
+
+  // Only the target type matters for this check, so build it directly instead
+  // of materializing a throwaway destination array.
+  const auto dest_type =
+      struct_({field("a", int8()), field("b", int8()), field("c", int8())});
+  auto options = CastOptions::Safe(dest_type);
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      TypeError,
+      ::testing::HasSubstr("Type error: struct field sizes do not match: "
+                           "struct<a: int8, b: int8> struct<a: int8, b: int8, c: int8>"),
+      Cast(src, options));
+}
+
+TEST(Cast, StructToSameSizedButDifferentNullabilityStruct) {
+ // OK to go from not-nullable to nullable...
+ std::vector<std::shared_ptr<Field>> fields1 = {
+ std::make_shared<Field>("a", int8(), false),
+ std::make_shared<Field>("b", int8(), false)};
+ std::shared_ptr<Array> a1, b1;
+ a1 = ArrayFromJSON(int8(), "[1, 2]");
+ b1 = ArrayFromJSON(int8(), "[3, 4]");
+ ASSERT_OK_AND_ASSIGN(auto src1, StructArray::Make({a1, b1}, fields1));
+
+ std::vector<std::shared_ptr<Field>> fields2 = {
+ std::make_shared<Field>("a", int8(), true),
+ std::make_shared<Field>("b", int8(), true)};
+ std::shared_ptr<Array> a2, b2;
+ a2 = ArrayFromJSON(int8(), "[1, 2]");
+ b2 = ArrayFromJSON(int8(), "[3, 4]");
+ ASSERT_OK_AND_ASSIGN(auto dest1, StructArray::Make({a2, b2}, fields2));
+
+ CheckCast(src1, dest1);
+
+ // But not the other way around
+ std::vector<std::shared_ptr<Field>> fields3 = {
+ std::make_shared<Field>("a", int8(), true),
+ std::make_shared<Field>("b", int8(), true)};
+ std::shared_ptr<Array> a3, b3;
+ a3 = ArrayFromJSON(int8(), "[1, null]");
+ b3 = ArrayFromJSON(int8(), "[3, 4]");
+ ASSERT_OK_AND_ASSIGN(auto src2, StructArray::Make({a3, b3}, fields3));
+
+ std::vector<std::shared_ptr<Field>> fields4 = {
+ std::make_shared<Field>("a", int8(), false),
+ std::make_shared<Field>("b", int8(), false)};
+ std::shared_ptr<Array> a4, b4;
+ a4 = ArrayFromJSON(int8(), "[1, 2]");
+ b4 = ArrayFromJSON(int8(), "[3, 4]");
+ ASSERT_OK_AND_ASSIGN(auto dest2, StructArray::Make({a4, b4}, fields4));
+ auto options = CastOptions::Safe(dest2->type());
Review comment:
FWIW, instead of constructing dest2, why not just use `arrow::struct_()`
to directly construct the type?
##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -150,6 +150,80 @@ void AddListCast(CastFunction* func) {
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}
+struct CastStruct {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+ const auto in_field_count =
+ checked_cast<const StructType&>(*batch[0].type()).num_fields();
+ const auto out_field_count =
+ checked_cast<const StructType&>(*out->type()).num_fields();
+
+ if (in_field_count != out_field_count) {
+ return Status::TypeError("struct field sizes do not match: ",
+ batch[0].type()->ToString(), " ",
out->type()->ToString());
+ }
+
+ for (int64_t i = 0; i < in_field_count; ++i) {
+ const auto in_field_name =
+ checked_cast<const StructType&>(*batch[0].type()).field(i)->name();
+ const auto out_field_name =
+ checked_cast<const StructType&>(*out->type()).field(i)->name();
Review comment:
Looks good to me now.
##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -150,6 +150,104 @@ void AddListCast(CastFunction* func) {
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}
+struct CastStruct {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+ const auto in_field_count =
+ checked_cast<const StructType&>(*batch[0].type()).num_fields();
+ const auto out_field_count =
+ checked_cast<const StructType&>(*out->type()).num_fields();
+
+ if (in_field_count != out_field_count) {
+ return Status::TypeError("struct field sizes do not match: ",
+ batch[0].type()->ToString(), " ",
out->type()->ToString());
+ }
+
+ for (int i = 0; i < in_field_count; ++i) {
+ const auto in_field_name =
+ checked_cast<const StructType&>(*batch[0].type()).field(i)->name();
+ const auto out_field_name =
+ checked_cast<const StructType&>(*out->type()).field(i)->name();
+ if (in_field_name != out_field_name) {
+ return Status::TypeError(
+ "struct field names do not match: ", batch[0].type()->ToString(),
" ",
+ out->type()->ToString());
+ }
+
+ const auto in_field_nullable =
+ checked_cast<const
StructType&>(*batch[0].type()).field(i)->nullable();
+ const auto out_field_nullable =
+ checked_cast<const StructType&>(*out->type()).field(i)->nullable();
+
+ if (in_field_nullable && !out_field_nullable) {
+ return Status::TypeError("cannot cast nullable struct to non-nullable
struct: ",
+ batch[0].type()->ToString(), " ",
+ out->type()->ToString());
+ }
+ }
+
+ for (int i = 0; i < in_field_count; ++i) {
+ const auto in_field_name =
+ checked_cast<const StructType&>(*batch[0].type()).field(i)->name();
+ const auto out_field_name =
+ checked_cast<const StructType&>(*out->type()).field(i)->name();
+ if (in_field_name != out_field_name) {
+ return Status::TypeError(
+ "struct field names do not match: ", batch[0].type()->ToString(),
" ",
+ out->type()->ToString());
+ }
+ }
+
+ if (out->kind() == Datum::SCALAR) {
+ const auto& in_scalar = checked_cast<const
StructScalar&>(*batch[0].scalar());
+ auto out_scalar = checked_cast<StructScalar*>(out->scalar().get());
+
+ DCHECK(!out_scalar->is_valid);
+ if (in_scalar.is_valid) {
+ for (int i = 0; i < in_field_count; i++) {
+ auto values = in_scalar.value[i];
+ auto target_type = out->type()->field(i)->type();
+ ARROW_ASSIGN_OR_RAISE(Datum cast_values,
+ Cast(values, target_type, options,
ctx->exec_context()));
+ DCHECK_EQ(Datum::SCALAR, cast_values.kind());
+ out_scalar->value.push_back(cast_values.scalar());
+ }
+ out_scalar->is_valid = true;
+ }
+ return Status::OK();
+ }
+
+ const ArrayData& in_array = *batch[0].array();
+ ArrayData* out_array = out->mutable_array();
+ out_array->buffers = in_array.buffers;
Review comment:
Now that I look at this, I think there's still a missing case here. We
should copy the offset over as well (and then we should not slice the child
arrays below), or if there's a top-level validity buffer, we should slice the
buffer before copying. I don't think the tests notice this right now because
there's no top-level validity buffer in any of the test cases.
##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -150,6 +150,104 @@ void AddListCast(CastFunction* func) {
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}
+struct CastStruct {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+ const auto in_field_count =
+ checked_cast<const StructType&>(*batch[0].type()).num_fields();
Review comment:
Nit: we could factor out the `checked_cast` of the input and output types here and below instead of repeating it for every field access.
##########
File path: cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
##########
@@ -150,6 +150,104 @@ void AddListCast(CastFunction* func) {
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}
+struct CastStruct {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+ const auto in_field_count =
+ checked_cast<const StructType&>(*batch[0].type()).num_fields();
+ const auto out_field_count =
+ checked_cast<const StructType&>(*out->type()).num_fields();
+
+ if (in_field_count != out_field_count) {
+ return Status::TypeError("struct field sizes do not match: ",
+ batch[0].type()->ToString(), " ",
out->type()->ToString());
+ }
+
+ for (int i = 0; i < in_field_count; ++i) {
+ const auto in_field_name =
+ checked_cast<const StructType&>(*batch[0].type()).field(i)->name();
+ const auto out_field_name =
+ checked_cast<const StructType&>(*out->type()).field(i)->name();
+ if (in_field_name != out_field_name) {
+ return Status::TypeError(
+ "struct field names do not match: ", batch[0].type()->ToString(),
" ",
+ out->type()->ToString());
+ }
+
+ const auto in_field_nullable =
+ checked_cast<const
StructType&>(*batch[0].type()).field(i)->nullable();
+ const auto out_field_nullable =
+ checked_cast<const StructType&>(*out->type()).field(i)->nullable();
+
+ if (in_field_nullable && !out_field_nullable) {
+ return Status::TypeError("cannot cast nullable struct to non-nullable
struct: ",
+ batch[0].type()->ToString(), " ",
+ out->type()->ToString());
+ }
+ }
+
+ for (int i = 0; i < in_field_count; ++i) {
+ const auto in_field_name =
+ checked_cast<const StructType&>(*batch[0].type()).field(i)->name();
+ const auto out_field_name =
+ checked_cast<const StructType&>(*out->type()).field(i)->name();
+ if (in_field_name != out_field_name) {
+ return Status::TypeError(
+ "struct field names do not match: ", batch[0].type()->ToString(),
" ",
+ out->type()->ToString());
+ }
+ }
+
+ if (out->kind() == Datum::SCALAR) {
+ const auto& in_scalar = checked_cast<const
StructScalar&>(*batch[0].scalar());
+ auto out_scalar = checked_cast<StructScalar*>(out->scalar().get());
+
+ DCHECK(!out_scalar->is_valid);
+ if (in_scalar.is_valid) {
+ for (int i = 0; i < in_field_count; i++) {
+ auto values = in_scalar.value[i];
+ auto target_type = out->type()->field(i)->type();
+ ARROW_ASSIGN_OR_RAISE(Datum cast_values,
+ Cast(values, target_type, options,
ctx->exec_context()));
+ DCHECK_EQ(Datum::SCALAR, cast_values.kind());
+ out_scalar->value.push_back(cast_values.scalar());
+ }
+ out_scalar->is_valid = true;
+ }
+ return Status::OK();
+ }
+
+ const ArrayData& in_array = *batch[0].array();
+ ArrayData* out_array = out->mutable_array();
+ out_array->buffers = in_array.buffers;
Review comment:
`StructArray::Make` does take a null bitmap, so we should be able to test
this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]