This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/udfSemanticCheck
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 80029873b84b634be068142f82f324594bbcaf9f
Author: Minghui Liu <[email protected]>
AuthorDate: Mon Jun 27 17:21:01 2022 +0800

    fix identifyOutputColumnType for expression
---
 .../db/mpp/plan/analyze/ExpressionAnalyzer.java    | 56 +++++++++++++++++-----
 .../plan/expression/multi/FunctionExpression.java  |  2 +-
 .../iotdb/db/mpp/plan/parser/ASTVisitor.java       |  3 +-
 3 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
index 229b6ebd3c..5ef8c22045 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
@@ -107,12 +107,13 @@ public class ExpressionAnalyzer {
    * @param expression expression to be checked
    * @return true if this expression is valid
    */
-  public static ResultColumn.ColumnType identifyOutputColumnType(Expression expression) {
+  public static ResultColumn.ColumnType identifyOutputColumnType(
+      Expression expression, boolean isRoot) {
     if (expression instanceof BinaryExpression) {
       ResultColumn.ColumnType leftType =
-          identifyOutputColumnType(((BinaryExpression) expression).getLeftExpression());
+          identifyOutputColumnType(((BinaryExpression) expression).getLeftExpression(), false);
       ResultColumn.ColumnType rightType =
-          identifyOutputColumnType(((BinaryExpression) expression).getRightExpression());
+          identifyOutputColumnType(((BinaryExpression) expression).getRightExpression(), false);
       if ((leftType == ResultColumn.ColumnType.RAW
               && rightType == ResultColumn.ColumnType.AGGREGATION)
           || (leftType == ResultColumn.ColumnType.AGGREGATION
@@ -120,7 +121,8 @@ public class ExpressionAnalyzer {
         throw new SemanticException(
             "Raw data and aggregation result hybrid calculation is not supported.");
       }
-      if (leftType == ResultColumn.ColumnType.CONSTANT
+      if (isRoot
+          && leftType == ResultColumn.ColumnType.CONSTANT
           && rightType == ResultColumn.ColumnType.CONSTANT) {
         throw new SemanticException("Constant column is not supported.");
       }
@@ -129,18 +131,48 @@ public class ExpressionAnalyzer {
       }
       return rightType;
     } else if (expression instanceof UnaryExpression) {
-      return identifyOutputColumnType(((UnaryExpression) expression).getExpression());
+      return identifyOutputColumnType(((UnaryExpression) expression).getExpression(), false);
     } else if (expression instanceof FunctionExpression) {
-      if (!expression.isBuiltInAggregationFunctionExpression()) {
-        return ResultColumn.ColumnType.RAW;
-      }
-      for (Expression childExpression : expression.getExpressions()) {
-        if (identifyOutputColumnType(childExpression) == ResultColumn.ColumnType.AGGREGATION) {
+      List<Expression> inputExpressions = expression.getExpressions();
+      if (expression.isBuiltInAggregationFunctionExpression()) {
+        for (Expression inputExpression : inputExpressions) {
+          if (identifyOutputColumnType(inputExpression, false)
+              == ResultColumn.ColumnType.AGGREGATION) {
+            throw new SemanticException(
+                "Aggregation results cannot be as input of the aggregation function.");
+          }
+        }
+        return ResultColumn.ColumnType.AGGREGATION;
+      } else {
+        ResultColumn.ColumnType checkedType = null;
+        int lastCheckedIndex = 0;
+        for (int i = 0; i < inputExpressions.size(); i++) {
+          ResultColumn.ColumnType columnType =
+              identifyOutputColumnType(inputExpressions.get(i), false);
+          if (columnType != ResultColumn.ColumnType.CONSTANT) {
+            checkedType = columnType;
+            lastCheckedIndex = i;
+            break;
+          }
+        }
+        if (checkedType == null) {
           throw new SemanticException(
-              "Aggregation results cannot be as input of the aggregation function.");
+              String.format(
+                  "Input of '%s' is illegal.",
+                  ((FunctionExpression) expression).getFunctionName()));
+        }
+        for (int i = lastCheckedIndex; i < inputExpressions.size(); i++) {
+          ResultColumn.ColumnType columnType =
+              identifyOutputColumnType(inputExpressions.get(i), false);
+          if (columnType != ResultColumn.ColumnType.CONSTANT && columnType != checkedType) {
+            throw new SemanticException(
+                String.format(
+                    "Raw data and aggregation result hybrid input of '%s' is not supported.",
+                    ((FunctionExpression) expression).getFunctionName()));
+          }
         }
+        return checkedType;
       }
-      return ResultColumn.ColumnType.AGGREGATION;
     } else if (expression instanceof TimeSeriesOperand || expression instanceof TimestampOperand) {
       return ResultColumn.ColumnType.RAW;
     } else if (expression instanceof ConstantOperand) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index 88d38aebd3..870b91d47f 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -259,7 +259,7 @@ public class FunctionExpression extends Expression {
         expression.inferTypes(typeProvider);
       }
 
-      if (isTimeSeriesGeneratingFunctionExpression()) {
+      if (!isBuiltInAggregationFunctionExpression()) {
         typeProvider.setType(
             expressionString,
             new UDTFTypeInferrer(functionName)
diff --git 
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java 
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
index 2dda2f3e2e..68d32d461d 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
@@ -745,7 +745,8 @@ public class ASTVisitor extends IoTDBSqlParserBaseVisitor<Statement> {
     if (resultColumnContext.AS() != null) {
       alias = parseAlias(resultColumnContext.alias());
     }
-    ResultColumn.ColumnType columnType = ExpressionAnalyzer.identifyOutputColumnType(expression);
+    ResultColumn.ColumnType columnType =
+        ExpressionAnalyzer.identifyOutputColumnType(expression, true);
     return new ResultColumn(expression, alias, columnType);
   }
 

Reply via email to