westonpace commented on code in PR #34050:
URL: https://github.com/apache/arrow/pull/34050#discussion_r1116401947
##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -320,6 +320,38 @@ Result<compute::Expression> FromProto(const
substrait::Expression& expr,
return function_converter(substrait_call);
}
+ case substrait::Expression::kCast: {
+ const auto& cast_exp = expr.cast();
+ ARROW_ASSIGN_OR_RAISE(auto input,
+ FromProto(cast_exp.input(), ext_set,
conversion_options));
+
+ ARROW_ASSIGN_OR_RAISE(auto type_nullable,
+ FromProto(cast_exp.type(), ext_set,
conversion_options));
+
+ if (!type_nullable.second &&
+ conversion_options.strictness ==
ConversionStrictness::EXACT_ROUNDTRIP) {
+ return Status::Invalid("Substrait cast type must be of nullable type");
+ }
+
+ if (cast_exp.failure_behavior() ==
+ substrait::Expression_Cast_FailureBehavior::
+
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION) {
+ return compute::call("cast", {std::move(input)},
+ compute::CastOptions::Safe(type_nullable.first));
Review Comment:
```suggestion
compute::CastOptions::Safe(std::move(type_nullable.first)));
```
Minor nit: `type_nullable.first` is not used again after this call, so it can be moved instead of copied.
##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -320,6 +320,38 @@ Result<compute::Expression> FromProto(const
substrait::Expression& expr,
return function_converter(substrait_call);
}
+ case substrait::Expression::kCast: {
+ const auto& cast_exp = expr.cast();
+ ARROW_ASSIGN_OR_RAISE(auto input,
+ FromProto(cast_exp.input(), ext_set,
conversion_options));
+
+ ARROW_ASSIGN_OR_RAISE(auto type_nullable,
+ FromProto(cast_exp.type(), ext_set,
conversion_options));
+
+ if (!type_nullable.second &&
+ conversion_options.strictness ==
ConversionStrictness::EXACT_ROUNDTRIP) {
+ return Status::Invalid("Substrait cast type must be of nullable type");
+ }
+
+ if (cast_exp.failure_behavior() ==
+ substrait::Expression_Cast_FailureBehavior::
+
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION) {
+ return compute::call("cast", {std::move(input)},
+ compute::CastOptions::Safe(type_nullable.first));
+ } else if (cast_exp.failure_behavior() ==
+ substrait::Expression_Cast_FailureBehavior::
+
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL) {
+ return Status::NotImplemented(
+ "Unsupported cast failure behavior: "
+ "Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL");
Review Comment:
```suggestion
"FAILURE_BEHAVIOR_RETURN_NULL");
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
}
}
+TEST(Substrait, CallCast) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+ ASSERT_TRUE(expr.call());
+
+ ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_UNSPECIFIED"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+ Raises(StatusCode::Invalid));
Review Comment:
```suggestion
EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
Raises(StatusCode::Invalid, HasSubstr("FailureBehavior unspecified")));
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
}
}
+TEST(Substrait, CallCast) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+ ASSERT_TRUE(expr.call());
+
+ ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_UNSPECIFIED"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+ Raises(StatusCode::Invalid));
+}
+
+TEST(Substrait, CallCastNonNullableFails) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+ conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+ Raises(StatusCode::Invalid));
Review Comment:
```suggestion
EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
Raises(StatusCode::Invalid, HasSubstr("must be of nullable type")));
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
}
}
+TEST(Substrait, CallCast) {
Review Comment:
```suggestion
TEST(Substrait, Cast) {
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
}
}
+TEST(Substrait, CallCast) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
Review Comment:
You use camelCase here (`directReference`, `structField`) but snake_case below
(`direct_reference`, `struct_field`). I was a bit surprised that it works, but
it seems protobuf supports both.
However, let's pick one and be consistent.
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
}
}
+TEST(Substrait, CallCast) {
+ ExtensionSet ext_set;
+ ConversionOptions conversion_options;
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "cast": {
+ "type": {
+ "fp64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "input": {
+ "selection": {
+ "direct_reference": {
+ "struct_field": {
+ "field": 0
+ }
+ }
+ }
+ },
+ "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+ }
+ }
+ }
+})",
+
/*ignore_unknown_fields=*/false))
+
+ ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+ ASSERT_TRUE(expr.call());
+
+ ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {
Review Comment:
```suggestion
TEST(Substrait, CastRequiresFailureBehavior) {
```
Minor nit
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]