westonpace commented on code in PR #34050:
URL: https://github.com/apache/arrow/pull/34050#discussion_r1116401947


##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -320,6 +320,38 @@ Result<compute::Expression> FromProto(const 
substrait::Expression& expr,
       return function_converter(substrait_call);
     }
 
+    case substrait::Expression::kCast: {
+      const auto& cast_exp = expr.cast();
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(cast_exp.input(), ext_set, 
conversion_options));
+
+      ARROW_ASSIGN_OR_RAISE(auto type_nullable,
+                            FromProto(cast_exp.type(), ext_set, 
conversion_options));
+
+      if (!type_nullable.second &&
+          conversion_options.strictness == 
ConversionStrictness::EXACT_ROUNDTRIP) {
+        return Status::Invalid("Substrait cast type must be of nullable type");
+      }
+
+      if (cast_exp.failure_behavior() ==
+          substrait::Expression_Cast_FailureBehavior::
+              
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION) {
+        return compute::call("cast", {std::move(input)},
+                             compute::CastOptions::Safe(type_nullable.first));

Review Comment:
   ```suggestion
                                compute::CastOptions::Safe(std::move(type_nullable.first)));
   ```
   minor nit



##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -320,6 +320,38 @@ Result<compute::Expression> FromProto(const 
substrait::Expression& expr,
       return function_converter(substrait_call);
     }
 
+    case substrait::Expression::kCast: {
+      const auto& cast_exp = expr.cast();
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(cast_exp.input(), ext_set, 
conversion_options));
+
+      ARROW_ASSIGN_OR_RAISE(auto type_nullable,
+                            FromProto(cast_exp.type(), ext_set, 
conversion_options));
+
+      if (!type_nullable.second &&
+          conversion_options.strictness == 
ConversionStrictness::EXACT_ROUNDTRIP) {
+        return Status::Invalid("Substrait cast type must be of nullable type");
+      }
+
+      if (cast_exp.failure_behavior() ==
+          substrait::Expression_Cast_FailureBehavior::
+              
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION) {
+        return compute::call("cast", {std::move(input)},
+                             compute::CastOptions::Safe(type_nullable.first));
+      } else if (cast_exp.failure_behavior() ==
+                 substrait::Expression_Cast_FailureBehavior::
+                     
Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL) {
+        return Status::NotImplemented(
+            "Unsupported cast failure behavior: "
+            "Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL");

Review Comment:
   ```suggestion
               "FAILURE_BEHAVIOR_RETURN_NULL");
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
   }
 }
 
+TEST(Substrait, CallCast) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_NULLABLE"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+  ASSERT_TRUE(expr.call());
+
+  ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_NULLABLE"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_UNSPECIFIED"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+              Raises(StatusCode::Invalid));

Review Comment:
   ```suggestion
     EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
                 Raises(StatusCode::Invalid, HasSubstr("FailureBehavior unspecified")));
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
   }
 }
 
+TEST(Substrait, CallCast) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_NULLABLE"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+  ASSERT_TRUE(expr.call());
+
+  ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_NULLABLE"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_UNSPECIFIED"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+              Raises(StatusCode::Invalid));
+}
+
+TEST(Substrait, CallCastNonNullableFails) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+  conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_REQUIRED"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
+              Raises(StatusCode::Invalid));

Review Comment:
   ```suggestion
     EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
                 Raises(StatusCode::Invalid, HasSubstr("must be of nullable type")));
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
   }
 }
 
+TEST(Substrait, CallCast) {

Review Comment:
   ```suggestion
   TEST(Substrait, Cast) {
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
   }
 }
 
+TEST(Substrait, CallCast) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {

Review Comment:
   You use camelCase here but snake_case below (i.e. `direct_reference` and
`struct_field`). I was a bit surprised that works, but protobuf's JSON parser
accepts both the lowerCamelCase JSON name and the original proto field name.
   
   However, let's pick one and be consistent.



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -774,6 +774,127 @@ TEST(Substrait, CallExtensionFunction) {
   }
 }
 
+TEST(Substrait, CallCast) {
+  ExtensionSet ext_set;
+  ConversionOptions conversion_options;
+
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Expression", R"({
+  "selection": {
+      "directReference": {
+        "structField": {
+          "field": 0
+        }
+      },
+    "expression": {
+      "cast": {
+        "type": {
+          "fp64": {
+            "nullability": "NULLABILITY_NULLABLE"
+          }
+        },
+        "input": {
+          "selection": {
+            "direct_reference": {
+              "struct_field": {
+                "field": 0
+              }
+            }
+          }
+        },
+        "failure_behavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
+      }
+    }
+  }
+})",
+                                                   
/*ignore_unknown_fields=*/false))
+
+  ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+  ASSERT_TRUE(expr.call());
+
+  ASSERT_THAT(expr.call()->arguments[0].call()->function_name, "cast");
+}
+
+TEST(Substrait, CallCastRequiresFailureBehavior) {

Review Comment:
   ```suggestion
   TEST(Substrait, CastRequiresFailureBehavior) {
   ```
   Minor nit



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to