fsk119 commented on code in PR #26583:
URL: https://github.com/apache/flink/pull/26583#discussion_r2106437244
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java:
##########
@@ -49,6 +59,75 @@ public static void adjustTypeForMapConstructor(
}
}
+ public static boolean throwValidationSignatureErrorOrReturnFalse(
+ SqlCallBinding callBinding, boolean throwOnFailure) {
+ if (throwOnFailure) {
+ throw callBinding.newValidationSignatureError();
+ } else {
+ return false;
+ }
+ }
+
+ @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+ public static boolean throwExceptionOrReturnFalse(
+ Optional<RuntimeException> e, boolean throwOnFailure) {
+ if (e.isPresent()) {
+ if (throwOnFailure) {
+ throw e.get();
+ } else {
+ return false;
+ }
+ } else {
+ return true;
+ }
+ }
+
+    /**
+     * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR,
+     * DESCRIPTOR ..., other params)}, returning whether successful, and throwing if
+     * any columns are not found.
+     *
+     * @param callBinding The call binding
+     * @param descriptorStartPos The position of the first descriptor operand
+     * @param descriptorCount The number of descriptors following the first operand
+     *     (e.g. the table)
+     * @return true if validation passes; throws if any columns are not found
+     */
+    public static boolean checkTableAndDescriptorOperands(
+            SqlCallBinding callBinding, int descriptorStartPos, int descriptorCount) {
Review Comment:
Why don't we just use the descriptor location as the method input? The descriptor count is always 1 in all cases, so I don't think we need to complicate the signature here.
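A minimal sketch of the simplified shape the reviewer suggests, using plain stand-in types instead of Calcite's `SqlCallBinding` (all names here are illustrative, not the actual Flink API): instead of a `(startPos, count)` pair, the single descriptor's position is passed directly.

```java
import java.util.List;

public class DescriptorCheckSketch {

    // Checks that the first operand is a table and that the operand at
    // descriptorPos is a descriptor; operand kinds are modeled as strings here.
    static boolean checkTableAndDescriptorOperand(List<String> operandKinds, int descriptorPos) {
        return "TABLE".equals(operandKinds.get(0))
                && "DESCRIPTOR".equals(operandKinds.get(descriptorPos));
    }

    public static void main(String[] args) {
        // ML_PREDICT-style call shape: (TABLE, MODEL, DESCRIPTOR, [MAP])
        System.out.println(checkTableAndDescriptorOperand(
                List.of("TABLE", "MODEL", "DESCRIPTOR"), 2));
    }
}
```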
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -87,21 +111,25 @@ public List<String> paramNames() {
     @Override
     public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
-        // TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
-        // validator
-        return false;
+        if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2, 1)) {
Review Comment:
nit: Do we need to validate this again? SqlMLTableFunction#validateCall has
already validated it.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
         return opName
                 + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
     }
+
+    private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
+        SqlValidator validator = callBinding.getValidator();
+
+        // Check second operand is SqlModelCall
+        if (!(callBinding.operand(1) instanceof SqlModelCall)) {
+            return Optional.of(
+                    new ValidationException("Second operand must be a model identifier."));
+        }
+
+        // Get descriptor columns
+        SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
+        List<SqlNode> descriptCols = descriptorCall.getOperandList();
+
+        // Get model input size
+        SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
+        RelDataType modelInputType = modelCall.getInputType(validator);
+
+        // Check sizes match
+        if (descriptCols.size() != modelInputType.getFieldCount()) {
+            return Optional.of(
+                    new ValidationException(
+                            String.format(
+                                    "Number of descriptor input columns (%d) does not match model input size (%d)",
+                                    descriptCols.size(), modelInputType.getFieldCount())));
+        }
+
+        // Check types match
+        final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
+        final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
+        for (int i = 0; i < descriptCols.size(); i++) {
+            SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
+            String descriptColName =
+                    columnName.isSimple()
+                            ? columnName.getSimple()
+                            : Util.last(columnName.names);
Review Comment:
I think the columns specified in descriptors are always simple, cc SqlDescriptorOperator#checkOperandTypes. BTW, if a column is not simple, I think it refers to a nested column. In that case we should throw an exception to notify the user that this is an unsupported feature.
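If nested references should be rejected rather than truncated via `Util.last`, the check could look roughly like this self-contained sketch (a plain `List<String>` stands in for `SqlIdentifier#names`; the real code would use `SqlIdentifier#isSimple`):

```java
import java.util.List;

public class SimpleColumnCheck {

    // A column reference is simple when it has exactly one name part; a dotted
    // reference like `a.b` denotes a nested column, which is rejected here.
    static String requireSimpleColumn(List<String> nameParts) {
        if (nameParts.size() != 1) {
            throw new IllegalArgumentException(
                    "Nested columns in DESCRIPTOR are not supported: "
                            + String.join(".", nameParts));
        }
        return nameParts.get(0);
    }

    public static void main(String[] args) {
        System.out.println(requireSimpleColumn(List.of("f0")));
    }
}
```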
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
         return opName
                 + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
     }
+
+    private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
+        SqlValidator validator = callBinding.getValidator();
+
+        // Check second operand is SqlModelCall
+        if (!(callBinding.operand(1) instanceof SqlModelCall)) {
+            return Optional.of(
+                    new ValidationException("Second operand must be a model identifier."));
+        }
+
+        // Get descriptor columns
+        SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
+        List<SqlNode> descriptCols = descriptorCall.getOperandList();
+
+        // Get model input size
+        SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
+        RelDataType modelInputType = modelCall.getInputType(validator);
+
+        // Check sizes match
+        if (descriptCols.size() != modelInputType.getFieldCount()) {
+            return Optional.of(
+                    new ValidationException(
+                            String.format(
+                                    "Number of descriptor input columns (%d) does not match model input size (%d)",
+                                    descriptCols.size(), modelInputType.getFieldCount())));
+        }
+
+        // Check types match
+        final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
+        final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
+        for (int i = 0; i < descriptCols.size(); i++) {
+            SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
+            String descriptColName =
+                    columnName.isSimple()
+                            ? columnName.getSimple()
+                            : Util.last(columnName.names);
+            int index = matcher.indexOf(tableType.getFieldNames(), descriptColName);
+            RelDataType sourceType = tableType.getFieldList().get(index).getType();
+            RelDataType targetType = modelInputType.getFieldList().get(i).getType();
+
+            LogicalType sourceLogicalType = toLogicalType(sourceType);
+            LogicalType targetLogicalType = toLogicalType(targetType);
+
+            if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
+                return Optional.of(
+                        new ValidationException(
+                                String.format(
+                                        "Descriptor column type %s cannot be assigned to model input type %s at position %d",
+                                        sourceLogicalType, targetLogicalType, i)));
+            }
+        }
+
+        return Optional.empty();
+    }
+
+    private static Optional<RuntimeException> checkConfig(SqlCallBinding callBinding) {
+        if (callBinding.getOperandCount() < PARAM_NAMES.size()) {
+            return Optional.empty();
+        }
+
+        SqlNode configNode = callBinding.operand(3);
+        if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) {
+            return Optional.of(new ValidationException("Config param should be a MAP."));
+        }
+
+        // Map operands can only be SqlCharStringLiteral or cast of SqlCharStringLiteral
+        SqlCall mapCall = (SqlCall) configNode;
+        for (int i = 0; i < mapCall.operandCount(); i++) {
+            SqlNode operand = mapCall.operand(i);
+            if (operand instanceof SqlCharStringLiteral) {
+                continue;
+            }
+            if (operand.getKind().equals(SqlKind.CAST)) {
Review Comment:
I am not sure how to use this. Could you add a test case for this?
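For context, here is a self-contained model of the rule this branch appears to enforce: a config map operand is acceptable when it is a character string literal, or a CAST whose operand is one. The tiny `Node` classes below are illustrative stand-ins for Calcite's `SqlNode` hierarchy, not the real API.

```java
public class ConfigOperandRule {

    static class Node {}

    static class StringLiteral extends Node {}

    static class Cast extends Node {
        final Node operand;

        Cast(Node operand) {
            this.operand = operand;
        }
    }

    // Mirrors the loop under review: plain string literals pass, and a CAST
    // passes only if the expression being cast is itself a string literal.
    static boolean isValidConfigOperand(Node operand) {
        if (operand instanceof StringLiteral) {
            return true;
        }
        if (operand instanceof Cast) {
            return ((Cast) operand).operand instanceof StringLiteral;
        }
        return false;
    }
}
```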
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -61,9 +78,16 @@ public boolean argumentMustBeScalar(int ordinal) {
     @Override
     protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
-        // TODO: FLINK-37780 output type based on table schema and model output schema
-        // model output schema to be available after integrated with SqlExplicitModelCall
-        return opBinding.getOperandType(1);
+        final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+        final RelDataType inputRowType = opBinding.getOperandType(0);
+        final RelDataType modelOutputRowType = opBinding.getOperandType(1);
+
+        return typeFactory
+                .builder()
+                .kind(inputRowType.getStructKind())
+                .addAll(inputRowType.getFieldList())
Review Comment:
Take a look at SystemOutputStrategy#inferType.
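One issue the reviewer's pointer likely concerns: naively appending the model output fields after the input fields can produce duplicate field names. A self-contained sketch of one plausible uniquification scheme follows; the numeric-suffix rule here is purely illustrative, not necessarily the exact one SystemOutputStrategy#inferType applies.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class UniqueOutputNames {

    // Appends output field names to the input row's names, making clashes
    // unique by adding a numeric suffix to the conflicting output name.
    static List<String> appendUnique(List<String> inputNames, List<String> outputNames) {
        List<String> result = new ArrayList<>(inputNames);
        Set<String> seen = new HashSet<>(inputNames);
        for (String name : outputNames) {
            String candidate = name;
            int suffix = 0;
            while (!seen.add(candidate)) {
                candidate = name + suffix++;
            }
            result.add(candidate);
        }
        return result;
    }
}
```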
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
         return opName
                 + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
Review Comment:
Did you miss the `)` here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]