This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new b181a9f099 [feature](Nereids) support array type in fold constant framework (#23373)
b181a9f099 is described below
commit b181a9f09905aea17f312cc93a837c0203892903
Author: morrySnow <[email protected]>
AuthorDate: Mon Aug 28 10:47:43 2023 +0800
[feature](Nereids) support array type in fold constant framework (#23373)
1. use legacy planner way to process constant folding result from be
2. support signature with complex type for constant folding on fe
---
.../expression/rules/FoldConstantRuleOnBE.java | 62 +++++++++++++++++-----
.../trees/expressions/ExpressionEvaluator.java | 46 +++++++++++-----
2 files changed, 83 insertions(+), 25 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
index f381f328fd..6ed045a300 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
@@ -19,12 +19,12 @@ package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.ExprId;
-import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.PrimitiveType;
-import org.apache.doris.catalog.Type;
+import org.apache.doris.catalog.ScalarType;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.UserException;
+import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
@@ -33,9 +33,15 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DecimalV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.PConstantExprResult;
+import org.apache.doris.proto.Types.PScalarType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.system.Backend;
@@ -46,6 +52,7 @@ import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.thrift.TQueryGlobals;
import org.apache.doris.thrift.TQueryOptions;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -69,8 +76,8 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
@Override
- public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
- Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr,
ctx);
+ public Expression rewrite(Expression expression, ExpressionRewriteContext
ctx) {
+ expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx);
return foldByBE(expression, ctx);
}
@@ -175,26 +182,57 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
if (result.getStatus().getStatusCode() == 0) {
for (Entry<String, InternalService.PExprResultMap> e :
result.getExprResultMapMap().entrySet()) {
for (Entry<String, InternalService.PExprResult> e1 :
e.getValue().getMapMap().entrySet()) {
+ PScalarType pScalarType = e1.getValue().getType();
+ TPrimitiveType tPrimitiveType =
TPrimitiveType.findByValue(pScalarType.getType());
+ PrimitiveType primitiveType =
PrimitiveType.fromThrift(Objects.requireNonNull(tPrimitiveType));
Expression ret;
if (e1.getValue().getSuccess()) {
- TPrimitiveType type =
TPrimitiveType.findByValue(e1.getValue().getType().getType());
- Type t =
Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type)));
- Expr staleExpr =
LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t));
- // Nereids type
- DataType t1 =
DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString());
- ret =
Literal.of(staleExpr.getStringValue()).castTo(t1);
+ DataType type;
+ if (PrimitiveType.ARRAY == primitiveType
+ || PrimitiveType.MAP == primitiveType
+ || PrimitiveType.STRUCT == primitiveType
+ || PrimitiveType.AGG_STATE ==
primitiveType) {
+ ret = constMap.get(e1.getKey());
+ } else {
+ if (primitiveType == PrimitiveType.CHAR) {
+
Preconditions.checkState(pScalarType.hasLen(),
+ "be return char type without len");
+ type =
CharType.createCharType(pScalarType.getLen());
+ } else if (primitiveType ==
PrimitiveType.VARCHAR) {
+
Preconditions.checkState(pScalarType.hasLen(),
+ "be return varchar type without
len");
+ type =
VarcharType.createVarcharType(pScalarType.getLen());
+ } else if (primitiveType ==
PrimitiveType.DECIMALV2) {
+ type = DecimalV2Type.createDecimalV2Type(
+ pScalarType.getPrecision(),
pScalarType.getScale());
+ } else if (primitiveType ==
PrimitiveType.DATETIMEV2) {
+ type =
DateTimeV2Type.of(pScalarType.getScale());
+ } else if (primitiveType ==
PrimitiveType.DECIMAL32
+ || primitiveType ==
PrimitiveType.DECIMAL64
+ || primitiveType ==
PrimitiveType.DECIMAL128) {
+ type = DecimalV3Type.createDecimalV3Type(
+ pScalarType.getPrecision(),
pScalarType.getScale());
+ } else {
+ type =
DataType.fromCatalogType(ScalarType.createType(
+
PrimitiveType.fromThrift(tPrimitiveType)));
+ }
+ ret =
Literal.of(e1.getValue().getContent()).castTo(type);
+ }
} else {
ret = constMap.get(e1.getKey());
}
+ LOG.debug("Be constant folding convert {} to {}",
e1.getKey(), ret);
resultMap.put(e1.getKey(), ret);
}
}
} else {
- LOG.warn("failed to get const expr value from be: {}",
result.getStatus().getErrorMsgsList());
+ LOG.warn("query {} failed to get const expr value from be: {}",
+ DebugUtil.printId(context.queryId()),
result.getStatus().getErrorMsgsList());
}
} catch (Exception e) {
- LOG.warn("failed to get const expr value from be: {}",
e.getMessage());
+ LOG.warn("query {} failed to get const expr value from be: {}",
+ DebugUtil.printId(context.queryId()), e.getMessage());
}
return resultMap;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
index 0e7ef81cc4..2964a4eeb3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
@@ -30,8 +30,11 @@ import org.apache.doris.nereids.trees.expressions.functions.executable.NumericAr
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.MapType;
+import org.apache.doris.nereids.types.StructType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
@@ -106,9 +109,6 @@ public enum ExpressionEvaluator {
private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers =
functions.get(signature.getName());
- if (functionInvokers == null) {
- return null;
- }
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
@@ -134,9 +134,8 @@ public enum ExpressionEvaluator {
if (functions != null) {
return;
}
- ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder =
- new ImmutableMultimap.Builder<String, FunctionInvoker>();
- List<Class> classes = ImmutableList.of(
+ ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new
ImmutableMultimap.Builder<>();
+ List<Class<?>> classes = ImmutableList.of(
DateTimeAcquire.class,
DateTimeExtractAndTransform.class,
ExecutableFunctions.class,
@@ -144,7 +143,7 @@ public enum ExpressionEvaluator {
DateTimeArithmetic.class,
NumericArithmetic.class
);
- for (Class cls : classes) {
+ for (Class<?> cls : classes) {
for (Method method : cls.getDeclaredMethods()) {
ExecFunctionList annotationList =
method.getAnnotation(ExecFunctionList.class);
if (annotationList != null) {
@@ -165,18 +164,39 @@ public enum ExpressionEvaluator {
DataType returnType =
DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
- if (type.equalsIgnoreCase("DECIMALV3")) {
- argTypes.add(DecimalV3Type.WILDCARD);
- } else {
- argTypes.add(DataType.convertFromString(type));
- }
+
argTypes.add(replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
FunctionSignature signature = new FunctionSignature(name,
- argTypes.toArray(new DataType[argTypes.size()]),
returnType);
+ argTypes.toArray(new DataType[0]), returnType);
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}
+ private DataType replaceDecimalV3WithWildcard(DataType input) {
+ if (input instanceof ArrayType) {
+ DataType item = replaceDecimalV3WithWildcard(((ArrayType)
input).getItemType());
+ if (item == ((ArrayType) input).getItemType()) {
+ return input;
+ }
+ return ArrayType.of(item);
+ } else if (input instanceof MapType) {
+ DataType keyType = replaceDecimalV3WithWildcard(((MapType)
input).getKeyType());
+ DataType valueType = replaceDecimalV3WithWildcard(((MapType)
input).getValueType());
+ if (keyType == ((MapType) input).getKeyType() && valueType ==
((MapType) input).getValueType()) {
+ return input;
+ }
+ return MapType.of(keyType, valueType);
+ } else if (input instanceof StructType) {
+ // TODO: support struct type
+ return input;
+ } else {
+ if (input instanceof DecimalV3Type) {
+ return DecimalV3Type.WILDCARD;
+ }
+ return input;
+ }
+ }
+
/**
* function invoker.
*/
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]