fsk119 commented on code in PR #26583:
URL: https://github.com/apache/flink/pull/26583#discussion_r2106437244


##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java:
##########
@@ -49,6 +59,75 @@ public static void adjustTypeForMapConstructor(
         }
     }
 
+    public static boolean throwValidationSignatureErrorOrReturnFalse(
+            SqlCallBinding callBinding, boolean throwOnFailure) {
+        if (throwOnFailure) {
+            throw callBinding.newValidationSignatureError();
+        } else {
+            return false;
+        }
+    }
+
+    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+    public static boolean throwExceptionOrReturnFalse(
+            Optional<RuntimeException> e, boolean throwOnFailure) {
+        if (e.isPresent()) {
+            if (throwOnFailure) {
+                throw e.get();
+            } else {
+                return false;
+            }
+        } else {
+            return true;
+        }
+    }
+
+    /**
+     * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR ...,
+     * other params)}, returning whether successful, and throwing if any columns are not found.
+     *
+     * @param callBinding The call binding
+     * @param descriptorStartPos The position of the first descriptor operand
+     * @param descriptorCount The number of descriptors following the first operand (e.g. the table)
+     * @return true if validation passes; throws if any columns are not found
+     */
+    public static boolean checkTableAndDescriptorOperands(
+            SqlCallBinding callBinding, int descriptorStartPos, int descriptorCount) {

Review Comment:
   Why don't we just pass the descriptor position as the method input? The
descriptor count is always 1 in all cases, so I don't think we need to
complicate this here.
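
A minimal sketch of the suggested simplification, assuming every caller passes exactly one descriptor (as ML_PREDICT does with `checkTableAndDescriptorOperands(callBinding, 2, 1)`). The operand classes and the `List<?>` parameter are hypothetical plain-Java stand-ins, not Calcite's `SqlCallBinding`/`SqlNode` API, so the snippet runs on its own:

```java
import java.util.List;

// Hypothetical stand-ins for Calcite operands; NOT the SqlCallBinding API.
final class DescriptorCheckSketch {

    /** Stand-in for the TABLE operand: only its column names. */
    static final class TableOperand {
        final List<String> columns;

        TableOperand(List<String> columns) {
            this.columns = columns;
        }
    }

    /** Stand-in for a DESCRIPTOR(...) operand: the column names it lists. */
    static final class DescriptorOperand {
        final List<String> columns;

        DescriptorOperand(List<String> columns) {
            this.columns = columns;
        }
    }

    /**
     * Takes a single descriptor position instead of (descriptorStartPos, descriptorCount).
     * Returns true if operand 0 is a table and every column listed in the descriptor at
     * descriptorPos exists in it; throws IllegalArgumentException otherwise.
     */
    static boolean checkTableAndDescriptorOperands(List<?> operands, int descriptorPos) {
        if (!(operands.get(0) instanceof TableOperand)) {
            throw new IllegalArgumentException("Operand 0 must be a TABLE.");
        }
        if (!(operands.get(descriptorPos) instanceof DescriptorOperand)) {
            throw new IllegalArgumentException(
                    "Operand " + descriptorPos + " must be a DESCRIPTOR.");
        }
        TableOperand table = (TableOperand) operands.get(0);
        DescriptorOperand descriptor = (DescriptorOperand) operands.get(descriptorPos);
        for (String column : descriptor.columns) {
            // Case-sensitive lookup for simplicity; the planner's name matching may differ.
            if (!table.columns.contains(column)) {
                throw new IllegalArgumentException("Column not found: " + column);
            }
        }
        return true;
    }
}
```

With a single `descriptorPos` parameter, ML_PREDICT would simply pass position 2, and there is no descriptor-count loop to keep in sync with the call sites.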



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -87,21 +111,25 @@ public List<String> paramNames() {
 
         @Override
         public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
-            // TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
-            // validator
-            return false;
+            if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2, 1)) {

Review Comment:
   nit: Do we need to validate this again? SqlMLTableFunction#validateCall has 
already validated it.



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
             return opName
                     + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
         }
+
+        private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
+            SqlValidator validator = callBinding.getValidator();
+
+            // Check second operand is SqlModelCall
+            if (!(callBinding.operand(1) instanceof SqlModelCall)) {
+                return Optional.of(
+                        new ValidationException("Second operand must be a model identifier."));
+            }
+
+            // Get descriptor columns
+            SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
+            List<SqlNode> descriptCols = descriptorCall.getOperandList();
+
+            // Get model input size
+            SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
+            RelDataType modelInputType = modelCall.getInputType(validator);
+
+            // Check sizes match
+            if (descriptCols.size() != modelInputType.getFieldCount()) {
+                return Optional.of(
+                        new ValidationException(
+                                String.format(
+                                        "Number of descriptor input columns (%d) does not match model input size (%d)",
+                                        descriptCols.size(), modelInputType.getFieldCount())));
+            }
+
+            // Check types match
+            final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
+            final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
+            for (int i = 0; i < descriptCols.size(); i++) {
+                SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
+                String descriptColName =
+                        columnName.isSimple()
+                                ? columnName.getSimple()
+                                : Util.last(columnName.names);

Review Comment:
   I think columns specified in descriptors are always simple identifiers; see
SqlDescriptorOperator#checkOperandTypes. BTW, if a column is not simple, I think
that means it refers to a nested column. In that case we should throw an
exception to notify the user that this is an unsupported feature.
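
A hypothetical sketch of rejecting nested descriptor columns instead of silently keeping `Util.last(columnName.names)`. Here a `List<String>` stands in for the identifier's name parts, `IllegalArgumentException` stands in for Flink's `ValidationException`, and the helper name is made up for illustration:

```java
import java.util.List;

// Hypothetical helper; List<String> models the name parts of a descriptor identifier.
final class DescriptorColumnSketch {

    /**
     * Returns the column name if the identifier is simple (one name part). A compound
     * name such as a.b is treated as a nested column reference and rejected, instead of
     * silently keeping only its last part.
     */
    static String requireSimpleColumnName(List<String> nameParts) {
        if (nameParts.size() != 1) {
            // Stand-in for throwing Flink's ValidationException.
            throw new IllegalArgumentException(
                    "Nested column '"
                            + String.join(".", nameParts)
                            + "' in DESCRIPTOR is not supported.");
        }
        return nameParts.get(0);
    }
}
```

Failing fast here gives the user an explicit "unsupported" error rather than a lookup against the wrong (last-part) column name.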



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
             return opName
                     + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
         }
+
+        private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
+            SqlValidator validator = callBinding.getValidator();
+
+            // Check second operand is SqlModelCall
+            if (!(callBinding.operand(1) instanceof SqlModelCall)) {
+                return Optional.of(
+                        new ValidationException("Second operand must be a model identifier."));
+            }
+
+            // Get descriptor columns
+            SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
+            List<SqlNode> descriptCols = descriptorCall.getOperandList();
+
+            // Get model input size
+            SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
+            RelDataType modelInputType = modelCall.getInputType(validator);
+
+            // Check sizes match
+            if (descriptCols.size() != modelInputType.getFieldCount()) {
+                return Optional.of(
+                        new ValidationException(
+                                String.format(
+                                        "Number of descriptor input columns (%d) does not match model input size (%d)",
+                                        descriptCols.size(), modelInputType.getFieldCount())));
+            }
+
+            // Check types match
+            final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
+            final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
+            for (int i = 0; i < descriptCols.size(); i++) {
+                SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
+                String descriptColName =
+                        columnName.isSimple()
+                                ? columnName.getSimple()
+                                : Util.last(columnName.names);
+                int index = matcher.indexOf(tableType.getFieldNames(), descriptColName);
+                RelDataType sourceType = tableType.getFieldList().get(index).getType();
+                RelDataType targetType = modelInputType.getFieldList().get(i).getType();
+
+                LogicalType sourceLogicalType = toLogicalType(sourceType);
+                LogicalType targetLogicalType = toLogicalType(targetType);
+
+                if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
+                    return Optional.of(
+                            new ValidationException(
+                                    String.format(
+                                            "Descriptor column type %s cannot be assigned to model input type %s at position %d",
+                                            sourceLogicalType, targetLogicalType, i)));
+                }
+            }
+
+            return Optional.empty();
+        }
+
+        private static Optional<RuntimeException> checkConfig(SqlCallBinding callBinding) {
+            if (callBinding.getOperandCount() < PARAM_NAMES.size()) {
+                return Optional.empty();
+            }
+
+            SqlNode configNode = callBinding.operand(3);
+            if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) {
+                return Optional.of(new ValidationException("Config param should be a MAP."));
+            }
+
+            // Map operands can only be SqlCharStringLiteral or cast of SqlCharStringLiteral
+            SqlCall mapCall = (SqlCall) configNode;
+            for (int i = 0; i < mapCall.operandCount(); i++) {
+                SqlNode operand = mapCall.operand(i);
+                if (operand instanceof SqlCharStringLiteral) {
+                    continue;
+                }
+                if (operand.getKind().equals(SqlKind.CAST)) {

Review Comment:
   I am not sure when this CAST branch applies. Could you add a test case for it?



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -61,9 +78,16 @@ public boolean argumentMustBeScalar(int ordinal) {
 
     @Override
     protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
-        // TODO: FLINK-37780 output type based on table schema and model output schema
-        // model output schema to be available after integrated with SqlExplicitModelCall
-        return opBinding.getOperandType(1);
+        final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+        final RelDataType inputRowType = opBinding.getOperandType(0);
+        final RelDataType modelOutputRowType = opBinding.getOperandType(1);
+
+        return typeFactory
+                .builder()
+                .kind(inputRowType.getStructKind())
+                .addAll(inputRowType.getFieldList())

Review Comment:
   Take a look at SystemOutputStrategy#inferType. 



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java:
##########
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
             return opName
                     + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";

Review Comment:
   Are you missing the closing `)` here?


