twalthr commented on code in PR #26924:
URL: https://github.com/apache/flink/pull/26924#discussion_r2316405801


##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelSemantics.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.types.DataType;
+
+/**
+ * Provides call information about the model that has been passed to a model argument.
+ *
+ * <p>This class is only available for model arguments (i.e. arguments of a {@link
+ * ProcessTableFunction} that are annotated with {@code @ArgumentHint(MODEL)}).
+ */
+@PublicEvolving
+public interface ModelSemantics {
+
+    /**
+     * Input data type expected by the passed model. Extracting type from PTF class definition is
+     * not supported yet.
+     */
+    DataType inputDataType();
+
+    /**
+     * Output data type produced by the passed model. Extracting type from PTF class definition is

Review Comment:
   ```suggestion
        * Output data type produced by the passed model. 
   ```



##########
flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java:
##########
@@ -651,6 +652,12 @@ public Optional<TableSemantics> getTableSemantics(int pos) {
             return Optional.of(semantics);
         }
 
+        @Override
+        public Optional<ModelSemantics> getModelSemantics(int pos) {
+            // TODO: Add ModelReferenceExpression checks and TableApiModelSemantics

Review Comment:
   We should not leave TODO in the code base, sometimes they stay there forever.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelTypeUtils.java:
##########
@@ -0,0 +1,260 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.DataTypes.Field;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.Signature.Argument;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
+import org.apache.flink.types.ColumnList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+@Internal
+public class ModelTypeUtils {
+
+    public static final InputTypeStrategy ML_PREDICT_INPUT_TYPE_STRATEGY =

Review Comment:
   Move this into org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelSemantics.java:
##########
@@ -0,0 +1,44 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.types.DataType;
+
+/**
+ * Provides call information about the model that has been passed to a model argument.
+ *
+ * <p>This class is only available for model arguments (i.e. arguments of a {@link
+ * ProcessTableFunction} that are annotated with {@code @ArgumentHint(MODEL)}).
+ */
+@PublicEvolving
+public interface ModelSemantics {
+
+    /**
+     * Input data type expected by the passed model. Extracting type from PTF class definition is

Review Comment:
   ```suggestion
        * Input data type expected by the passed model. 
   ```



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelTypeUtils.java:
##########
@@ -0,0 +1,260 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.DataTypes.Field;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.Signature.Argument;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
+import org.apache.flink.types.ColumnList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+@Internal
+public class ModelTypeUtils {
+
+    public static final InputTypeStrategy ML_PREDICT_INPUT_TYPE_STRATEGY =
+            new InputTypeStrategy() {
+                @Override
+                public ArgumentCount getArgumentCount() {
+                    return ConstantArgumentCount.between(3, 4);
+                }
+
+                @Override
+                public Optional<List<DataType>> inferInputTypes(
+                        CallContext callContext, boolean throwOnFailure) {
+                    return ModelTypeUtils.inferMLPredictInputTypes(callContext, throwOnFailure);
+                }
+
+                @Override
+                public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
+                    return List.of(
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR")),
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR"),
+                                    Argument.of("CONFIG", "MAP")));
+                }
+            };
+
+    private static Optional<List<DataType>> inferMLPredictInputTypes(
+            CallContext callContext, boolean throwOnFailure) {
+
+        // Check that first argument is a table
+        TableSemantics tableSemantics = callContext.getTableSemantics(0).orElse(null);
+        if (tableSemantics == null) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "First argument must be a table for ML_PREDICT function.");

Review Comment:
   An input type strategy is optional if static arguments have been declared. You can assume that these checks have already been done. An input type strategy might only be useful if you want to do additional validation, like validateTableAndDescriptorArguments below.
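To make the point concrete: once argument count and kinds are covered by declared static arguments, what remains is the check that static arguments cannot express. Below is a minimal, framework-free sketch of only that remaining check. It assumes plain `List<String>` column names instead of `TableSemantics` and `ColumnList`; the class name `DescriptorColumnValidation` is hypothetical, and `IllegalArgumentException` stands in for Flink's `ValidationException`.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical sketch: validates only what static arguments cannot express, namely that every
 * descriptor column exists in the input table. Argument count and kinds (table, model,
 * descriptor) are assumed to be checked already via the declared static arguments.
 */
final class DescriptorColumnValidation {

    private DescriptorColumnValidation() {}

    static boolean validate(
            List<String> tableFieldNames,
            List<String> descriptorColumnNames,
            boolean throwOnFailure) {
        // LinkedHashSet keeps the table's column order, so the error message is deterministic
        final Set<String> available = new LinkedHashSet<>(tableFieldNames);
        for (String column : descriptorColumnNames) {
            if (!available.contains(column)) {
                if (throwOnFailure) {
                    // IllegalArgumentException stands in for Flink's ValidationException
                    throw new IllegalArgumentException(
                            String.format(
                                    "Descriptor column '%s' not found in table columns. "
                                            + "Available columns: %s.",
                                    column, String.join(", ", available)));
                }
                return false;
            }
        }
        return true;
    }
}
```

Using an order-preserving `LinkedHashSet` instead of `Collectors.toSet()` keeps the "Available columns" part of the message in table order, which may make it easier to assert on in tests.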



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelTypeUtils.java:
##########
@@ -0,0 +1,260 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.DataTypes.Field;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.Signature.Argument;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
+import org.apache.flink.types.ColumnList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+@Internal
+public class ModelTypeUtils {
+
+    public static final InputTypeStrategy ML_PREDICT_INPUT_TYPE_STRATEGY =
+            new InputTypeStrategy() {
+                @Override
+                public ArgumentCount getArgumentCount() {
+                    return ConstantArgumentCount.between(3, 4);
+                }
+
+                @Override
+                public Optional<List<DataType>> inferInputTypes(
+                        CallContext callContext, boolean throwOnFailure) {
+                    return ModelTypeUtils.inferMLPredictInputTypes(callContext, throwOnFailure);
+                }
+
+                @Override
+                public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
+                    return List.of(
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR")),
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR"),
+                                    Argument.of("CONFIG", "MAP")));
+                }
+            };
+
+    private static Optional<List<DataType>> inferMLPredictInputTypes(
+            CallContext callContext, boolean throwOnFailure) {
+
+        // Check that first argument is a table
+        TableSemantics tableSemantics = callContext.getTableSemantics(0).orElse(null);
+        if (tableSemantics == null) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "First argument must be a table for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        // Check that second argument is a model
+        ModelSemantics modelSemantics = callContext.getModelSemantics(1).orElse(null);
+        if (modelSemantics == null) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "Second argument must be a model for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        // Check that third argument is a descriptor with column names
+        Optional<ColumnList> descriptorColumns = callContext.getArgumentValue(2, ColumnList.class);
+        if (descriptorColumns.isEmpty()) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "Third argument must be a descriptor with simple column names for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        if (!validateTableAndDescriptorArguments(
+                tableSemantics, descriptorColumns.get(), throwOnFailure)) {
+            return Optional.empty();
+        }
+
+        if (!validateModelDescriptorCompatibility(
+                tableSemantics, modelSemantics, descriptorColumns.get(), throwOnFailure)) {
+            return Optional.empty();
+        }
+
+        // Config map validation is done in StreamPhysicalMLPredictTableFunctionRule since
+        // we are not able to get map literal here.
+        return Optional.of(callContext.getArgumentDataTypes());
+    }
+
+    private static boolean validateTableAndDescriptorArguments(
+            TableSemantics tableSemantics, ColumnList descriptorColumns, boolean throwOnFailure) {
+
+        // Check that descriptor column names exist in table columns
+        List<Field> tableFields = DataType.getFields(tableSemantics.dataType());
+        Set<String> tableFieldNames =
+                tableFields.stream().map(Field::getName).collect(Collectors.toSet());
+        List<String> descriptorColumnNames = descriptorColumns.getNames();
+
+        for (String descriptorColumnName : descriptorColumnNames) {
+            if (!tableFieldNames.contains(descriptorColumnName)) {
+                if (throwOnFailure) {
+                    throw new ValidationException(
+                            String.format(
+                                    "Descriptor column '%s' not found in table columns. "
+                                            + "Available columns: %s.",
+                                    descriptorColumnName, String.join(", ", tableFieldNames)));
+                } else {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    private static boolean validateModelDescriptorCompatibility(
+            TableSemantics tableSemantics,
+            ModelSemantics modelSemantics,
+            ColumnList descriptorColumns,
+            boolean throwOnFailure) {
+
+        // Check descriptor columns match model input size and types
+        DataType modelInputDataType = modelSemantics.inputDataType();
+        LogicalType modelInputLogicalType = modelInputDataType.getLogicalType();
+
+        if (!modelInputLogicalType.is(LogicalTypeRoot.ROW)) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "Model input type must be a row type for ML_PREDICT function.");
+            } else {
+                return false;
+            }
+        }
+
+        List<Field> modelInputFields = DataType.getFields(modelInputDataType);
+        List<String> descriptorColumnNames = descriptorColumns.getNames();
+
+        // Check size compatibility
+        if (descriptorColumnNames.size() != modelInputFields.size()) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        String.format(
+                                "Number of descriptor columns (%d) does not match model input size (%d).",
+                                descriptorColumnNames.size(), modelInputFields.size()));
+            } else {
+                return false;
+            }
+        }
+
+        // Check type compatibility for each descriptor column
+        List<Field> tableFields = DataType.getFields(tableSemantics.dataType());
+        for (int i = 0; i < descriptorColumnNames.size(); i++) {
+            String descriptorColumnName = descriptorColumnNames.get(i);
+
+            // Find the descriptor column's type in the table
+            Field tableField =
+                    tableFields.stream()
+                            .filter(field -> field.getName().equals(descriptorColumnName))
+                            .findFirst()
+                            .orElseThrow(
+                                    () ->
+                                            new IllegalStateException(
+                                                    "Column should exist")); // Should not happen
+            // due to earlier check
+
+            LogicalType tableColumnType = tableField.getDataType().getLogicalType();
+            LogicalType modelInputColumnType =
+                    modelInputFields.get(i).getDataType().getLogicalType();
+
+            if (!LogicalTypeCasts.supportsImplicitCast(tableColumnType, modelInputColumnType)) {
+                if (throwOnFailure) {
+                    throw new ValidationException(
+                            String.format(
+                                    "Descriptor column '%s' type %s cannot be assigned to model input type %s at position %d.",
+                                    descriptorColumnName,
+                                    tableColumnType,
+                                    modelInputColumnType,
+                                    i));
+                } else {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    public static final TypeStrategy ML_PREDICT_OUTPUT_TYPE_STRATEGY =

Review Comment:
   Move this to org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies.
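The positional rule in `validateModelDescriptorCompatibility` above (equal column counts, and descriptor column *i* implicitly castable to model input field *i*) can be sketched without Flink types. This is a sketch under stated assumptions: type names are plain strings, the class name `ModelInputCompatibility` is hypothetical, and `implicitlyCastable` is a heavily reduced stand-in for `LogicalTypeCasts.supportsImplicitCast` that only models identity and integer widening.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: descriptor column i must be implicitly castable to model input field i. */
final class ModelInputCompatibility {

    // Heavily reduced stand-in for LogicalTypeCasts.supportsImplicitCast:
    // only identity and integer widening are modeled here.
    private static final List<String> INTEGER_WIDENING =
            List.of("TINYINT", "SMALLINT", "INT", "BIGINT");

    private ModelInputCompatibility() {}

    static boolean implicitlyCastable(String from, String to) {
        if (from.equals(to)) {
            return true;
        }
        final int fromRank = INTEGER_WIDENING.indexOf(from);
        final int toRank = INTEGER_WIDENING.indexOf(to);
        return fromRank >= 0 && toRank >= 0 && fromRank <= toRank;
    }

    /**
     * @param descriptorColumns column names from the DESCRIPTOR argument, in order
     * @param tableColumnTypes table column name to type name
     * @param modelInputTypes type names of the model's input row fields, in order
     */
    static void check(
            List<String> descriptorColumns,
            Map<String, String> tableColumnTypes,
            List<String> modelInputTypes) {
        if (descriptorColumns.size() != modelInputTypes.size()) {
            throw new IllegalArgumentException(
                    String.format(
                            "Number of descriptor columns (%d) does not match model input size (%d).",
                            descriptorColumns.size(), modelInputTypes.size()));
        }
        for (int i = 0; i < descriptorColumns.size(); i++) {
            final String column = descriptorColumns.get(i);
            final String tableType = tableColumnTypes.get(column);
            if (tableType == null) {
                // Should not happen if column existence was validated first
                throw new IllegalStateException("Column should exist: " + column);
            }
            final String modelType = modelInputTypes.get(i);
            if (!implicitlyCastable(tableType, modelType)) {
                throw new IllegalArgumentException(
                        String.format(
                                "Descriptor column '%s' type %s cannot be assigned to model input type %s at position %d.",
                                column, tableType, modelType, i));
            }
        }
    }
}
```

With this reduced cast table, an `INT` column feeding a `BIGINT` model input passes, while a `BIGINT` column feeding an `INT` input is rejected, mirroring the widening-only direction of implicit casts.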



##########
flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/TableSemanticsMock.java:
##########
@@ -0,0 +1,79 @@
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.table.connector.ChangelogMode;
+import org.apache.flink.table.functions.TableSemantics;
+import org.apache.flink.table.types.DataType;
+
+import javax.annotation.Nullable;
+
+import java.util.Optional;
+
+/** Mock implementation of {@link TableSemantics} for testing purposes. */
+public class TableSemanticsMock implements TableSemantics {

Review Comment:
   Move this next to CallContextMock into the utils package.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java:
##########
@@ -144,6 +145,45 @@ public static StaticArgument table(
         return new StaticArgument(name, dataType, null, isOptional, enrichTableTraits(traits));
     }
 
+    /**
+     * Declares a model argument such as {@code f(m => myModel)} or {@code f(m => MODEL myModel)}.
+     *
+     * <p>By only providing a conversion class, the argument supports a "polymorphic" behavior. In
+     * other words: it accepts models with arbitrary schemas or types. For this case, a class
+     * satisfying the model's conversion requirements must be used.

Review Comment:
   ```suggestion
        * <p>By using this method, the argument supports a "polymorphic" behavior. In
        * other words: it accepts models with arbitrary schemas or types.
   ```
   A conversion class does not exist for models yet.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java:
##########
@@ -144,6 +145,45 @@ public static StaticArgument table(
         return new StaticArgument(name, dataType, null, isOptional, enrichTableTraits(traits));
     }
 
+    /**
+     * Declares a model argument such as {@code f(m => myModel)} or {@code f(m => MODEL myModel)}.
+     *
+     * <p>By only providing a conversion class, the argument supports a "polymorphic" behavior. In
+     * other words: it accepts models with arbitrary schemas or types. For this case, a class
+     * satisfying the model's conversion requirements must be used.
+     *
+     * @param name name for the assignment operator e.g. {@code f(myArg => myModel)}
+     * @param isOptional whether the argument is optional
+     * @param traits set of {@link StaticArgumentTrait} requiring {@link StaticArgumentTrait#MODEL}
+     */
+    public static StaticArgument model(
+            String name, boolean isOptional, EnumSet<StaticArgumentTrait> traits) {
+        final EnumSet<StaticArgumentTrait> enrichedTraits = EnumSet.copyOf(traits);
+        enrichedTraits.add(StaticArgumentTrait.MODEL);
+        return new StaticArgument(name, null, null, isOptional, enrichedTraits);
+    }
+
+    /**
+     * Declares a model argument such as {@code f(m => myModel)} or {@code f(m => MODEL myModel)}.
+     *
+     * <p>By providing a concrete data type, the argument only accepts models with corresponding
+     * schema or type structure. The data type must be appropriate for the specific model type.
+     *
+     * @param name name for the assignment operator e.g. {@code f(myArg => myModel)}
+     * @param dataType explicit type to which the argument is cast if necessary

Review Comment:
   Does this refer to the model's input or output schema?
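Separately from the schema question, the trait handling in the `model(...)` factory above copies the caller's `EnumSet` and then adds `MODEL`. A minimal sketch of that pattern, using a hypothetical `ArgumentTraitSketch` enum in place of `StaticArgumentTrait` (only the `MODEL` constant corresponds to the diff) and a hypothetical helper class:

```java
import java.util.EnumSet;

/** Hypothetical stand-in for StaticArgumentTrait; only MODEL corresponds to the diff. */
enum ArgumentTraitSketch {
    SCALAR,
    TABLE,
    MODEL
}

final class ModelArgumentTraits {

    private ModelArgumentTraits() {}

    /** Copies the caller's traits and adds MODEL, leaving the input set untouched. */
    static EnumSet<ArgumentTraitSketch> enrichWithModel(EnumSet<ArgumentTraitSketch> traits) {
        // The EnumSet overload of copyOf also accepts an empty set,
        // unlike copyOf(Collection), which rejects empty non-EnumSet collections.
        final EnumSet<ArgumentTraitSketch> enriched = EnumSet.copyOf(traits);
        enriched.add(ArgumentTraitSketch.MODEL);
        return enriched;
    }
}
```

Because `traits` is statically typed as `EnumSet`, the `copyOf(EnumSet)` overload is chosen, so an empty input set works and the caller's set is not mutated.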



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java:
##########
@@ -1327,7 +1326,7 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
     public static final SqlFunction SESSION = new SqlSessionTableFunction();
 
     // MODEL TABLE FUNCTIONS
-    public static final SqlFunction ML_PREDICT = new SqlMLPredictTableFunction();
+    // public static final SqlFunction ML_PREDICT = new SqlMLPredictTableFunction();

Review Comment:
   Remove this commented-out line?



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java:
##########
@@ -147,7 +159,9 @@ private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
         checkPassThroughColumns(declaredArgs);
 
         final List<StaticArgument> newStaticArgs = new ArrayList<>(declaredArgs);
-        newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);

Review Comment:
   This is not entirely correct. A user-defined PTF can implement a TypeInference and avoid system args, but this is kind of a second-level API.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java:
##########
@@ -78,6 +79,18 @@ default Optional<TableSemantics> getTableSemantics(int pos) {
         return Optional.empty();
     }
 
+    /**
+     * Returns information about the model that has been passed to a model argument.
+     *
+     * <p>This method applies only to {@link ProcessTableFunction}s.
+     *
+     * <p>Semantics are only available for model arguments that are annotated with
+     * {@code @ArgumentHint(MODEL)}.

Review Comment:
   ```suggestion
   
   ```
   We haven't exposed this to users yet.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelSemantics.java:
##########
@@ -0,0 +1,44 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.types.DataType;
+
+/**
+ * Provides call information about the model that has been passed to a model argument.
+ *
+ * <p>This class is only available for model arguments (i.e. arguments of a {@link
+ * ProcessTableFunction} that are annotated with {@code @ArgumentHint(MODEL)}).
+ */
+@PublicEvolving
+public interface ModelSemantics {
+
+    /**
+     * Input data type expected by the passed model. Extracting type from PTF class definition is

Review Comment:
   Drop the last sentence. It is rather confusing.



##########
flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/ModelSemanticsMock.java:
##########
@@ -0,0 +1,44 @@
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.table.functions.ModelSemantics;
+import org.apache.flink.table.types.DataType;
+
+/** Mock implementation of {@link ModelSemantics} for testing purposes. */
+public class ModelSemanticsMock implements ModelSemantics {

Review Comment:
   Move this next to CallContextMock into the utils package.



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java:
##########
@@ -335,6 +335,8 @@ public RelNode visit(FunctionQueryOperation functionTable) {
                                             inputStack.add(relBuilder.build());
                                             return tableArgCall;
                                         }
+                                        // TODO: Check ModelReferenceExpression and construct
+                                        // RexModelArgCall

Review Comment:
   Are you planning to fix this TODO?



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelTypeUtils.java:
##########
@@ -0,0 +1,260 @@
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.DataTypes.Field;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.Signature.Argument;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
+import org.apache.flink.types.ColumnList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+@Internal
+public class ModelTypeUtils {
+
+    public static final InputTypeStrategy ML_PREDICT_INPUT_TYPE_STRATEGY =
+            new InputTypeStrategy() {
+                @Override
+                public ArgumentCount getArgumentCount() {
+                    return ConstantArgumentCount.between(3, 4);
+                }
+
+                @Override
+                public Optional<List<DataType>> inferInputTypes(
+                        CallContext callContext, boolean throwOnFailure) {
+                    return ModelTypeUtils.inferMLPredictInputTypes(callContext, throwOnFailure);
+                }
+
+                @Override
+                public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
+                    return List.of(
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR")),
+                            Signature.of(
+                                    Argument.of("TABLE", "ROW"),
+                                    Argument.of("MODEL", "MODEL"),
+                                    Argument.of("ARGS", "DESCRIPTOR"),
+                                    Argument.of("CONFIG", "MAP")));
+                }
+            };
+
+    private static Optional<List<DataType>> inferMLPredictInputTypes(
+            CallContext callContext, boolean throwOnFailure) {
+
+        // Check that first argument is a table
+        TableSemantics tableSemantics = callContext.getTableSemantics(0).orElse(null);
+        if (tableSemantics == null) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "First argument must be a table for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        // Check that second argument is a model
+        ModelSemantics modelSemantics = callContext.getModelSemantics(1).orElse(null);
+        if (modelSemantics == null) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "Second argument must be a model for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        // Check that third argument is a descriptor with column names
+        Optional<ColumnList> descriptorColumns = callContext.getArgumentValue(2, ColumnList.class);
+        if (descriptorColumns.isEmpty()) {
+            if (throwOnFailure) {
+                throw new ValidationException(
+                        "Third argument must be a descriptor with simple column names for ML_PREDICT function.");
+            } else {
+                return Optional.empty();
+            }
+        }
+
+        if (!validateTableAndDescriptorArguments(
+                tableSemantics, descriptorColumns.get(), throwOnFailure)) {
+            return Optional.empty();
+        }
+
+        if (!validateModelDescriptorCompatibility(
+                tableSemantics, modelSemantics, descriptorColumns.get(), throwOnFailure)) {
+            return Optional.empty();
+        }
+
+        // Config map validation is done in StreamPhysicalMLPredictTableFunctionRule since
+        // we are not able to get map literal here.
+        return Optional.of(callContext.getArgumentDataTypes());
+    }
+
+    private static boolean validateTableAndDescriptorArguments(
+            TableSemantics tableSemantics, ColumnList descriptorColumns, boolean throwOnFailure) {
+
+        // Check that descriptor column names exist in table columns
+        List<Field> tableFields = DataType.getFields(tableSemantics.dataType());
+        Set<String> tableFieldNames =
+                tableFields.stream().map(Field::getName).collect(Collectors.toSet());
+        List<String> descriptorColumnNames = descriptorColumns.getNames();
+
+        for (String descriptorColumnName : descriptorColumnNames) {
+            if (!tableFieldNames.contains(descriptorColumnName)) {
+                if (throwOnFailure) {
+                    throw new ValidationException(
+                            String.format(
+                                    "Descriptor column '%s' not found in table columns. "
+                                            + "Available columns: %s.",
+                                    descriptorColumnName, String.join(", ", tableFieldNames)));
+                } else {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    private static boolean validateModelDescriptorCompatibility(
+            TableSemantics tableSemantics,
+            ModelSemantics modelSemantics,
+            ColumnList descriptorColumns,
+            boolean throwOnFailure) {
+
+        // Check descriptor columns match model input size and types
+        DataType modelInputDataType = modelSemantics.inputDataType();
+        LogicalType modelInputLogicalType = modelInputDataType.getLogicalType();
+
+        if (!modelInputLogicalType.is(LogicalTypeRoot.ROW)) {

Review Comment:
   Isn't this always the case? Can't the code be simplified here?
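
If the model's input type is indeed always a ROW, as the question above suggests, the `is(LogicalTypeRoot.ROW)` branch is dead and the compatibility check reduces to comparing the descriptor columns with the model's input fields by count and by type. A minimal, self-contained sketch of that reduced check follows. Flink types are replaced by type-name strings, `DescriptorModelCheck` and `validate` are hypothetical names, and plain equality stands in for the implicit-cast rules the PR presumably applies through `LogicalTypeCasts`.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Hypothetical sketch: checks DESCRIPTOR columns against the model's input fields. */
final class DescriptorModelCheck {

    private DescriptorModelCheck() {}

    /**
     * Returns an error message if the descriptor columns do not match the model input, or an
     * empty Optional if they are compatible.
     *
     * @param tableColumnTypes column name to type name for the input table
     * @param descriptorColumns column names listed in DESCRIPTOR(...), in order
     * @param modelInputTypes type names of the model's input fields, in order; assumed to be the
     *     fields of a ROW, so no separate ROW check is performed
     */
    static Optional<String> validate(
            Map<String, String> tableColumnTypes,
            List<String> descriptorColumns,
            List<String> modelInputTypes) {
        // Size check: one descriptor column per model input field.
        if (descriptorColumns.size() != modelInputTypes.size()) {
            return Optional.of(
                    String.format(
                            "Expected %d descriptor columns for the model input, but got %d.",
                            modelInputTypes.size(), descriptorColumns.size()));
        }
        // Type check, position by position. Exact name equality stands in for an
        // implicit-cast check; column existence is assumed to be validated earlier.
        for (int i = 0; i < descriptorColumns.size(); i++) {
            String column = descriptorColumns.get(i);
            String actual = tableColumnTypes.get(column);
            String expected = modelInputTypes.get(i);
            if (!expected.equals(actual)) {
                return Optional.of(
                        String.format(
                                "Descriptor column '%s' has type %s, but model input field %d expects %s.",
                                column, actual, i, expected));
            }
        }
        return Optional.empty();
    }
}
```

Returning an `Optional<String>` instead of threading `throwOnFailure` through every helper would leave the throw-or-return decision to a single caller.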



##########
flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java:
##########
@@ -192,6 +194,22 @@ void testMissingUid() {
                                         + "For example: myFunction(..., uid => 'my-id')"));
     }
 
+    @Test
+    void testNoSystemArgsAllowedForScalarPtf() {
+        util.addTemporarySystemFunction("f", NoSystemArgsScalarFunction.class);
+        assertThatThrownBy(() -> util.verifyRelPlan("SELECT * FROM f(i => 1);"))
+                .satisfies(
+                        anyCauseMatches("Disabling uid/time attributes is not supported for PTF."));
+    }
+
+    @Test
+    void testNoSystemArgsAllowedForTablePtf() {
+        util.addTemporarySystemFunction("f", NoSystemArgsTableFunction.class);
+        assertThatThrownBy(() -> util.verifyRelPlan("SELECT * FROM f(r => TABLE t, i => 1);"))
+                .satisfies(
+                        anyCauseMatches("Disabling uid/time attributes is not supported for PTF."));

Review Comment:
   ```suggestion
                           anyCauseMatches("Disabling system arguments is not supported for user-defined PTF yet."));
   ```



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinition.java:
##########
@@ -298,6 +298,11 @@ public Builder staticArguments(StaticArgument... staticArguments) {
             return this;
         }
 
+        public Builder allowSystemArguments(boolean allowSystemArguments) {

Review Comment:
   Let's call this `disableSystemArguments`, both here and in TypeInference. That way, it can default to false.
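
The negative name lets the default fall out naturally: an unset Java `boolean` field is `false`, so builders that never call the setter keep system arguments enabled. A hypothetical sketch, in which `InferenceSketch` and its nested builder merely stand in for `TypeInference` and `BuiltInFunctionDefinition.Builder` and are not Flink's actual classes:

```java
/** Hypothetical stand-in for an inference object that carries the flag. */
final class InferenceSketch {
    private final boolean disableSystemArguments;

    private InferenceSketch(boolean disableSystemArguments) {
        this.disableSystemArguments = disableSystemArguments;
    }

    boolean disableSystemArguments() {
        return disableSystemArguments;
    }

    static Builder newBuilder() {
        return new Builder();
    }

    static final class Builder {
        // Uninitialized boolean fields default to false, i.e. system arguments stay enabled.
        private boolean disableSystemArguments;

        Builder disableSystemArguments(boolean disableSystemArguments) {
            this.disableSystemArguments = disableSystemArguments;
            return this;
        }

        InferenceSketch build() {
            return new InferenceSketch(disableSystemArguments);
        }
    }
}
```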



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/CallBindingCallContext.java:
##########
@@ -62,6 +65,7 @@ public final class CallBindingCallContext extends AbstractSqlCallContext {
     private final List<DataType> argumentDataTypes;
     private final @Nullable DataType outputType;
     private final @Nullable List<StaticArgument> staticArguments;
+    private final SqlValidator validator;

Review Comment:
   The validator does not really fit in here; can we avoid it? SqlModelCall should already have resolved types. I think TableArgCall works the same way? We should synchronize the two if possible.



##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ModelTypeUtils.java:
##########
@@ -0,0 +1,260 @@
+    private static boolean validateTableAndDescriptorArguments(
+            TableSemantics tableSemantics, ColumnList descriptorColumns, boolean throwOnFailure) {
+
+        // Check that descriptor column names exist in table columns
+        List<Field> tableFields = DataType.getFields(tableSemantics.dataType());

Review Comment:
   use DataType.getFieldNames
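
`DataType.getFieldNames(tableSemantics.dataType())` already returns the column names as a `List<String>`, so the `Field` list and the `map(Field::getName)` stream step become unnecessary. A self-contained sketch of the resulting check on plain name lists, with `DescriptorColumnCheck`, `findMissingColumn`, and `errorMessage` as hypothetical names:

```java
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/** Hypothetical sketch of the descriptor-column existence check. */
final class DescriptorColumnCheck {

    private DescriptorColumnCheck() {}

    /** Returns the first descriptor column that is not a table column, if any. */
    static Optional<String> findMissingColumn(
            List<String> tableFieldNames, List<String> descriptorColumnNames) {
        // A set gives constant-time membership checks.
        Set<String> available = new HashSet<>(tableFieldNames);
        for (String column : descriptorColumnNames) {
            if (!available.contains(column)) {
                return Optional.of(column);
            }
        }
        return Optional.empty();
    }

    /** Builds the error message from the original list, so columns appear in table order. */
    static String errorMessage(String missingColumn, List<String> tableFieldNames) {
        return String.format(
                "Descriptor column '%s' not found in table columns. Available columns: %s.",
                missingColumn, String.join(", ", tableFieldNames));
    }
}
```

Joining the original list rather than a set keeps the "Available columns" message in table order; the `Collectors.toSet()` result in the diff gives no ordering guarantee.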



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

