This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new a4950fe48f Remove SqlKind from FunctionCall (#13293) a4950fe48f is described below commit a4950fe48f9e58a1fc65471d2a66b4c979444f10 Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com> AuthorDate: Sun Jun 2 11:31:37 2024 -0700 Remove SqlKind from FunctionCall (#13293) --- .../pinot/common/utils/request/RequestUtils.java | 9 +- pinot-common/src/main/proto/expressions.proto | 9 +- .../query/parser/CalciteRexExpressionParser.java | 220 ++++++--------------- .../planner/logical/RelToPlanNodeConverter.java | 5 +- .../pinot/query/planner/logical/RexExpression.java | 65 ++---- .../query/planner/logical/RexExpressionUtils.java | 84 ++++---- .../serde/ProtoExpressionToRexExpression.java | 6 +- .../serde/RexExpressionToProtoExpression.java | 21 +- .../query/runtime/operator/AggregateOperator.java | 6 +- .../runtime/operator/operands/FunctionOperand.java | 7 +- .../operator/operands/TransformOperandFactory.java | 22 +-- .../runtime/operator/utils/AggregationUtils.java | 5 +- .../runtime/operator/utils/OperatorUtils.java | 58 ------ .../plan/server/ServerPlanRequestVisitor.java | 12 +- .../runtime/operator/AggregateOperatorTest.java | 9 +- .../query/runtime/operator/FilterOperatorTest.java | 15 +- .../runtime/operator/HashJoinOperatorTest.java | 8 +- .../runtime/operator/TransformOperatorTest.java | 12 +- .../operator/WindowAggregateOperatorTest.java | 24 ++- 19 files changed, 191 insertions(+), 406 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java index 7c1795f009..5b5013550e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.math.BigDecimal; import java.util.ArrayList; @@ -408,18 +409,16 @@ public class RequestUtils { private static final Map<String, String> CANONICAL_NAME_TO_SPECIAL_KEY_MAP; static { - CANONICAL_NAME_TO_SPECIAL_KEY_MAP = new HashMap<>(); + ImmutableMap.Builder<String, String> builder = ImmutableMap.builder(); for (FilterKind filterKind : FilterKind.values()) { - CANONICAL_NAME_TO_SPECIAL_KEY_MAP.put(canonicalizeFunctionName(filterKind.name()), filterKind.name()); + builder.put(canonicalizeFunctionName(filterKind.name()), filterKind.name()); } - CANONICAL_NAME_TO_SPECIAL_KEY_MAP.put("stdistance", "st_distance"); + CANONICAL_NAME_TO_SPECIAL_KEY_MAP = builder.build(); } /** * Converts the function name into its canonical form, but preserving the special keys. * - Keep FilterKind.name() as is because we need to read the FilterKind via FilterKind.valueOf(). - * - Keep ST_Distance as is because we use exact match when applying geo-spatial index up to release 0.10.0. - * TODO: Remove the ST_Distance special handling after releasing 0.11.0. */ public static String canonicalizeFunctionNamePreservingSpecialKey(String functionName) { String canonicalName = canonicalizeFunctionName(functionName); diff --git a/pinot-common/src/main/proto/expressions.proto b/pinot-common/src/main/proto/expressions.proto index 37150658bc..ecfe3034cc 100644 --- a/pinot-common/src/main/proto/expressions.proto +++ b/pinot-common/src/main/proto/expressions.proto @@ -64,11 +64,10 @@ message Literal { } message FunctionCall { - int32 sqlKind = 1; - ColumnDataType dataType = 2; - string functionName = 3; - repeated RexExpression functionOperands = 4; - bool isDistinct = 5; + ColumnDataType dataType = 1; + string functionName = 2; + repeated RexExpression functionOperands = 3; + bool isDistinct = 4; } message RexExpression { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index 1862adf95e..f61e9edd3e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -19,28 +19,16 @@ package org.apache.pinot.query.parser; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Iterator; import java.util.List; -import java.util.Map; import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rel.RelFieldCollation.NullDirection; -import org.apache.calcite.sql.SqlKind; import org.apache.pinot.common.request.Expression; -import org.apache.pinot.common.request.ExpressionType; -import org.apache.pinot.common.request.Function; import org.apache.pinot.common.request.PinotQuery; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.SortNode; -import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.utils.ByteArray; -import org.apache.pinot.sql.FilterKind; import org.apache.pinot.sql.parsers.ParserUtils; -import org.apache.pinot.sql.parsers.SqlCompilationException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** @@ -54,69 +42,45 @@ public class CalciteRexExpressionParser { private CalciteRexExpressionParser() { } - private static final Logger LOGGER = LoggerFactory.getLogger(CalciteRexExpressionParser.class); - private static final Map<String, String> CANONICAL_NAME_TO_SPECIAL_KEY_MAP; - private static final String ARRAY_TO_MV_FUNCTION_NAME = "arraytomv"; - - static { - CANONICAL_NAME_TO_SPECIAL_KEY_MAP = new HashMap<>(); - // adding filter kind special handling - for (FilterKind filterKind : FilterKind.values()) { - CANONICAL_NAME_TO_SPECIAL_KEY_MAP.put(RequestUtils.canonicalizeFunctionName(filterKind.name()), - filterKind.name()); - } - // adding SqlKind.OTHERS and SqlKind.OTHER_FUNCTIONS that have canonical names. - CANONICAL_NAME_TO_SPECIAL_KEY_MAP.put("||", "concat"); - } + // The following function names are canonical names. + private static final String AND = "AND"; + private static final String OR = "OR"; + private static final String FILTER = "filter"; + private static final String ASC = "asc"; + private static final String DESC = "desc"; + private static final String NULLS_FIRST = "nullsfirst"; + private static final String NULLS_LAST = "nullslast"; + private static final String COUNT = "count"; + private static final String ARRAY_TO_MV = "arraytomv"; // -------------------------------------------------------------------------- // Relational conversion Utils // -------------------------------------------------------------------------- - public static List<Expression> convertProjectList(List<RexExpression> projectList, PinotQuery pinotQuery) { - List<Expression> selectExpr = new ArrayList<>(); - final Iterator<RexExpression> iterator = projectList.iterator(); - while (iterator.hasNext()) { - final RexExpression next = iterator.next(); - selectExpr.add(toExpression(next, pinotQuery)); + public static List<Expression> convertRexNodes(List<RexExpression> rexNodes, PinotQuery pinotQuery) { + List<Expression> expressions = new ArrayList<>(rexNodes.size()); + for (RexExpression rexNode : rexNodes) { + expressions.add(toExpression(rexNode, pinotQuery)); } - return selectExpr; + return expressions; } - public static List<Expression> convertAggregateList(List<Expression> groupSetList, List<RexExpression> aggCallList, + public static List<Expression> convertAggregateList(List<Expression> groupByList, List<RexExpression> aggCallList, List<Integer> filterArgIndices, PinotQuery pinotQuery) { - List<Expression> selectExpr = new ArrayList<>(groupSetList); - - for (int idx = 0; idx < aggCallList.size(); idx++) { - final RexExpression aggCall = aggCallList.get(idx); - int filterArgIdx = filterArgIndices.get(idx); + int numAggCalls = aggCallList.size(); + List<Expression> expressions = new ArrayList<>(groupByList.size() + numAggCalls); + expressions.addAll(groupByList); + for (int i = 0; i < numAggCalls; i++) { + Expression aggFunction = toExpression(aggCallList.get(i), pinotQuery); + int filterArgIdx = filterArgIndices.get(i); if (filterArgIdx == -1) { - selectExpr.add(toExpression(aggCall, pinotQuery)); + expressions.add(aggFunction); } else { - selectExpr.add(toExpression(new RexExpression.FunctionCall(SqlKind.FILTER, aggCall.getDataType(), "FILTER", - Arrays.asList(aggCall, new RexExpression.InputRef(filterArgIdx))), pinotQuery)); + expressions.add( + RequestUtils.getFunctionExpression(FILTER, aggFunction, pinotQuery.getSelectList().get(filterArgIdx))); } } - - return selectExpr; - } - - public static List<Expression> convertGroupByList(List<RexExpression> rexNodeList, PinotQuery pinotQuery) { - List<Expression> groupByExpr = new ArrayList<>(); - - final Iterator<RexExpression> iterator = rexNodeList.iterator(); - while (iterator.hasNext()) { - final RexExpression next = iterator.next(); - groupByExpr.add(toExpression(next, pinotQuery)); - } - - return groupByExpr; - } - - private static List<Expression> convertDistinctSelectList(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) { - List<Expression> selectExpr = new ArrayList<>(); - selectExpr.add(convertDistinctAndSelectListToFunctionExpression(rexCall, pinotQuery)); - return selectExpr; + return expressions; } public static List<Expression> convertOrderByList(SortNode node, PinotQuery pinotQuery) { @@ -134,65 +98,28 @@ public class CalciteRexExpressionParser { private static Expression convertOrderBy(RexExpression rexNode, Direction direction, NullDirection nullDirection, PinotQuery pinotQuery) { + Expression expression = toExpression(rexNode, pinotQuery); if (direction == Direction.ASCENDING) { - Expression expression = getFunctionExpression("asc"); - expression.getFunctionCall().addToOperands(toExpression(rexNode, pinotQuery)); + Expression asc = RequestUtils.getFunctionExpression(ASC, expression); // NOTE: Add explicit NULL direction only if it is not the default behavior (default behavior treats NULL as the // largest value) - if (nullDirection == NullDirection.FIRST) { - Expression nullFirstExpression = getFunctionExpression("nullsfirst"); - nullFirstExpression.getFunctionCall().addToOperands(expression); - return nullFirstExpression; - } else { - return expression; - } + return nullDirection == NullDirection.FIRST ? RequestUtils.getFunctionExpression(NULLS_FIRST, asc) : asc; } else { - Expression expression = getFunctionExpression("desc"); - expression.getFunctionCall().addToOperands(toExpression(rexNode, pinotQuery)); + Expression desc = RequestUtils.getFunctionExpression(DESC, expression); // NOTE: Add explicit NULL direction only if it is not the default behavior (default behavior treats NULL as the // largest value) - if (nullDirection == NullDirection.LAST) { - Expression nullLastExpression = getFunctionExpression("nullslast"); - nullLastExpression.getFunctionCall().addToOperands(expression); - return nullLastExpression; - } else { - return expression; - } - } - } - - private static Expression convertDistinctAndSelectListToFunctionExpression(RexExpression.FunctionCall rexCall, - PinotQuery pinotQuery) { - Expression functionExpression = getFunctionExpression("distinct"); - for (RexExpression node : rexCall.getFunctionOperands()) { - Expression columnExpression = toExpression(node, pinotQuery); - if (columnExpression.getType() == ExpressionType.IDENTIFIER && columnExpression.getIdentifier().getName() - .equals("*")) { - throw new SqlCompilationException( - "Syntax error: Pinot currently does not support DISTINCT with *. Please specify each column name after " - + "DISTINCT keyword"); - } else if (columnExpression.getType() == ExpressionType.FUNCTION) { - Function functionCall = columnExpression.getFunctionCall(); - String function = functionCall.getOperator(); - if (AggregationFunctionType.isAggregationFunction(function)) { - throw new SqlCompilationException( - "Syntax error: Use of DISTINCT with aggregation functions is not supported"); - } - } - functionExpression.getFunctionCall().addToOperands(columnExpression); + return nullDirection == NullDirection.LAST ? RequestUtils.getFunctionExpression(NULLS_LAST, desc) : desc; } - return functionExpression; } public static Expression toExpression(RexExpression rexNode, PinotQuery pinotQuery) { - LOGGER.debug("Current processing RexNode: {}, node.getKind(): {}", rexNode, rexNode.getKind()); - switch (rexNode.getKind()) { - case INPUT_REF: - return inputRefToIdentifier((RexExpression.InputRef) rexNode, pinotQuery); - case LITERAL: - return compileLiteralExpression(((RexExpression.Literal) rexNode).getValue()); - default: - return compileFunctionExpression((RexExpression.FunctionCall) rexNode, pinotQuery); + if (rexNode instanceof RexExpression.InputRef) { + return inputRefToIdentifier((RexExpression.InputRef) rexNode, pinotQuery); + } else if (rexNode instanceof RexExpression.Literal) { + return compileLiteralExpression(((RexExpression.Literal) rexNode).getValue()); + } else { + assert rexNode instanceof RexExpression.FunctionCall; + return compileFunctionExpression((RexExpression.FunctionCall) rexNode, pinotQuery); } } @@ -213,41 +140,24 @@ public class CalciteRexExpressionParser { } private static Expression compileFunctionExpression(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) { - SqlKind functionKind = rexCall.getKind(); - String functionName; - switch (functionKind) { - case AND: - return compileAndExpression(rexCall, pinotQuery); - case OR: - return compileOrExpression(rexCall, pinotQuery); - case OTHER: - case OTHER_FUNCTION: - functionName = canonicalizeFunctionName(rexCall.getFunctionName()); - // Special handle for leaf stage multi-value columns, as the default behavior for filter and group by is not - // sql standard, so need to use `array_to_mv` to convert the array to v1 multi-value column for behavior - // consistency meanwhile not violating the sql standard. - if (ARRAY_TO_MV_FUNCTION_NAME.equals(functionName)) { - return toExpression(rexCall.getFunctionOperands().get(0), pinotQuery); - } - break; - default: - functionName = canonicalizeFunctionName(functionKind.name()); - break; + String functionName = rexCall.getFunctionName(); + if (functionName.equals(AND)) { + return compileAndExpression(rexCall, pinotQuery); + } + if (functionName.equals(OR)) { + return compileOrExpression(rexCall, pinotQuery); } + String canonicalName = RequestUtils.canonicalizeFunctionNamePreservingSpecialKey(functionName); List<RexExpression> childNodes = rexCall.getFunctionOperands(); - List<Expression> operands = new ArrayList<>(childNodes.size()); - for (RexExpression childNode : childNodes) { - operands.add(toExpression(childNode, pinotQuery)); + if (canonicalName.equals(COUNT) && childNodes.isEmpty()) { + return RequestUtils.getFunctionExpression(COUNT, RequestUtils.getIdentifierExpression("*")); } - // for COUNT, add a star (*) identifier to operand list b/c V1 doesn't handle empty operand functions. - if (functionKind == SqlKind.COUNT && operands.isEmpty()) { - operands.add(RequestUtils.getIdentifierExpression("*")); - } else { - ParserUtils.validateFunction(functionName, operands); + if (canonicalName.equals(ARRAY_TO_MV)) { + return toExpression(childNodes.get(0), pinotQuery); } - Expression functionExpression = getFunctionExpression(functionName); - functionExpression.getFunctionCall().setOperands(operands); - return functionExpression; + List<Expression> operands = convertRexNodes(childNodes, pinotQuery); + ParserUtils.validateFunction(canonicalName, operands); + return RequestUtils.getFunctionExpression(canonicalName, operands); } /** @@ -256,16 +166,15 @@ public class CalciteRexExpressionParser { private static Expression compileAndExpression(RexExpression.FunctionCall andNode, PinotQuery pinotQuery) { List<Expression> operands = new ArrayList<>(); for (RexExpression childNode : andNode.getFunctionOperands()) { - if (childNode.getKind() == SqlKind.AND) { + if (childNode instanceof RexExpression.FunctionCall && ((RexExpression.FunctionCall) childNode).getFunctionName() + .equals(AND)) { Expression childAndExpression = compileAndExpression((RexExpression.FunctionCall) childNode, pinotQuery); operands.addAll(childAndExpression.getFunctionCall().getOperands()); } else { operands.add(toExpression(childNode, pinotQuery)); } } - Expression andExpression = getFunctionExpression(SqlKind.AND.name()); - andExpression.getFunctionCall().setOperands(operands); - return andExpression; + return RequestUtils.getFunctionExpression(AND, operands); } /** @@ -274,27 +183,14 @@ public class CalciteRexExpressionParser { private static Expression compileOrExpression(RexExpression.FunctionCall orNode, PinotQuery pinotQuery) { List<Expression> operands = new ArrayList<>(); for (RexExpression childNode : orNode.getFunctionOperands()) { - if (childNode.getKind() == SqlKind.OR) { + if (childNode instanceof RexExpression.FunctionCall && ((RexExpression.FunctionCall) childNode).getFunctionName() + .equals(OR)) { Expression childAndExpression = compileOrExpression((RexExpression.FunctionCall) childNode, pinotQuery); operands.addAll(childAndExpression.getFunctionCall().getOperands()); } else { operands.add(toExpression(childNode, pinotQuery)); } } - Expression andExpression = getFunctionExpression(SqlKind.OR.name()); - andExpression.getFunctionCall().setOperands(operands); - return andExpression; - } - - private static Expression getFunctionExpression(String canonicalName) { - Expression expression = new Expression(ExpressionType.FUNCTION); - Function function = new Function(canonicalName); - expression.setFunctionCall(function); - return expression; - } - - private static String canonicalizeFunctionName(String functionName) { - String canonicalizeName = RequestUtils.canonicalizeFunctionName(functionName); - return CANONICAL_NAME_TO_SPECIAL_KEY_MAP.getOrDefault(canonicalizeName, canonicalizeName); + return RequestUtils.getFunctionExpression(OR, operands); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index b9bda8b144..e0a0dfd3a1 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -170,7 +170,7 @@ public final class RelToPlanNodeConverter { private static PlanNode convertLogicalProject(LogicalProject node, int currentStageId) { return new ProjectNode(currentStageId, toDataSchema(node.getRowType()), - node.getProjects().stream().map(RexExpressionUtils::fromRexNode).collect(Collectors.toList())); + RexExpressionUtils.fromRexNodes(node.getProjects())); } private static PlanNode convertLogicalFilter(LogicalFilter node, int currentStageId) { @@ -196,8 +196,7 @@ public final class RelToPlanNodeConverter { // Parse out all equality JOIN conditions JoinInfo joinInfo = node.analyzeCondition(); JoinNode.JoinKeys joinKeys = new JoinNode.JoinKeys(joinInfo.leftKeys, joinInfo.rightKeys); - List<RexExpression> joinClause = - joinInfo.nonEquiConditions.stream().map(RexExpressionUtils::fromRexNode).collect(Collectors.toList()); + List<RexExpression> joinClause = RexExpressionUtils.fromRexNodes(joinInfo.nonEquiConditions); return new JoinNode(currentStageId, toDataSchema(node.getRowType()), toDataSchema(node.getLeft().getRowType()), toDataSchema(node.getRight().getRowType()), joinType, joinKeys, joinClause, node.getHints()); } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java index d2ddfe00f7..da5185ef55 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java @@ -20,7 +20,6 @@ package org.apache.pinot.query.planner.logical; import java.util.List; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlKind; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.query.planner.serde.ProtoProperties; import org.checkerframework.checker.nullness.qual.Nullable; @@ -31,13 +30,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; */ public interface RexExpression { - SqlKind getKind(); - - ColumnDataType getDataType(); - class InputRef implements RexExpression { - @ProtoProperties - private SqlKind _sqlKind; @ProtoProperties private int _index; @@ -45,28 +38,15 @@ public interface RexExpression { } public InputRef(int index) { - _sqlKind = SqlKind.INPUT_REF; _index = index; } public int getIndex() { return _index; } - - @Override - public SqlKind getKind() { - return _sqlKind; - } - - @Override - public ColumnDataType getDataType() { - throw new IllegalStateException("InputRef does not have data type"); - } } class Literal implements RexExpression { - @ProtoProperties - private SqlKind _sqlKind; @ProtoProperties private ColumnDataType _dataType; @ProtoProperties @@ -79,32 +59,21 @@ public interface RexExpression { * NOTE: Value is the internal stored value for the data type. E.g. BOOLEAN -> int, TIMESTAMP -> long. */ public Literal(ColumnDataType dataType, @Nullable Object value) { - _sqlKind = SqlKind.LITERAL; _dataType = dataType; _value = value; } - public Object getValue() { - return _value; - } - - @Override - public SqlKind getKind() { - return _sqlKind; - } - - @Override public ColumnDataType getDataType() { return _dataType; } + + @Nullable + public Object getValue() { + return _value; + } } class FunctionCall implements RexExpression { - // the underlying SQL operator kind of this function. - // It can be either a standard SQL operator or an extended function kind. - // @see #SqlKind.FUNCTION, #SqlKind.OTHER, #SqlKind.OTHER_FUNCTION - @ProtoProperties - private SqlKind _sqlKind; // the return data type of the function. @ProtoProperties private ColumnDataType _dataType; @@ -121,20 +90,22 @@ public interface RexExpression { public FunctionCall() { } - public FunctionCall(SqlKind sqlKind, ColumnDataType dataType, String functionName, - List<RexExpression> functionOperands) { - this(sqlKind, dataType, functionName, functionOperands, false); + public FunctionCall(ColumnDataType dataType, String functionName, List<RexExpression> functionOperands) { + this(dataType, functionName, functionOperands, false); } - public FunctionCall(SqlKind sqlKind, ColumnDataType dataType, String functionName, - List<RexExpression> functionOperands, boolean isDistinct) { - _sqlKind = sqlKind; + public FunctionCall(ColumnDataType dataType, String functionName, List<RexExpression> functionOperands, + boolean isDistinct) { _dataType = dataType; _functionName = functionName; _functionOperands = functionOperands; _isDistinct = isDistinct; } + public ColumnDataType getDataType() { + return _dataType; + } + public String getFunctionName() { return _functionName; } @@ -146,15 +117,5 @@ public interface RexExpression { public boolean isDistinct() { return _isDistinct; } - - @Override - public SqlKind getKind() { - return _sqlKind; - } - - @Override - public ColumnDataType getDataType() { - return _dataType; - } } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java index c2e9890358..82b3b45b7f 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java @@ -28,15 +28,14 @@ import java.util.GregorianCalendar; import java.util.Iterator; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import org.apache.calcite.avatica.util.ByteString; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.util.NlsString; import org.apache.calcite.util.Sarg; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -61,6 +60,14 @@ public class RexExpressionUtils { } } + public static List<RexExpression> fromRexNodes(List<RexNode> rexNodes) { + List<RexExpression> rexExpressions = new ArrayList<>(rexNodes.size()); + for (RexNode rexNode : rexNodes) { + rexExpressions.add(fromRexNode(rexNode)); + } + return rexExpressions; + } + public static RexExpression.InputRef fromRexInputRef(RexInputRef rexInputRef) { return new RexExpression.InputRef(rexInputRef.getIndex()); } @@ -98,9 +105,7 @@ public class RexExpressionUtils { } public static RexExpression fromRexCall(RexCall rexCall) { - switch (rexCall.getKind()) { - case CASE: - return handleCase(rexCall); + switch (rexCall.op.kind) { case CAST: return handleCast(rexCall); case REINTERPRET: @@ -108,54 +113,53 @@ public class RexExpressionUtils { case SEARCH: return handleSearch(rexCall); default: - List<RexExpression> operands = - rexCall.getOperands().stream().map(RexExpressionUtils::fromRexNode).collect(Collectors.toList()); - return new RexExpression.FunctionCall(rexCall.getKind(), - RelToPlanNodeConverter.convertToColumnDataType(rexCall.getType()), rexCall.getOperator().getName(), - operands); + return new RexExpression.FunctionCall(RelToPlanNodeConverter.convertToColumnDataType(rexCall.type), + getFunctionName(rexCall.op), fromRexNodes(rexCall.operands)); } } - private static RexExpression.FunctionCall handleCase(RexCall rexCall) { - List<RexExpression> operands = - rexCall.getOperands().stream().map(RexExpressionUtils::fromRexNode).collect(Collectors.toList()); - return new RexExpression.FunctionCall(SqlKind.CASE, - RelToPlanNodeConverter.convertToColumnDataType(rexCall.getType()), "caseWhen", operands); + private static String getFunctionName(SqlOperator operator) { + switch (operator.kind) { + case OTHER: + // NOTE: SqlStdOperatorTable.CONCAT has OTHER kind and "||" as name + return operator.getName().equals("||") ? "CONCAT" : operator.getName(); + case OTHER_FUNCTION: + return operator.getName(); + default: + return operator.kind.name(); + } } private static RexExpression.FunctionCall handleCast(RexCall rexCall) { // CAST is being rewritten into "rexCall.CAST<targetType>(inputValue)", // - e.g. result type has already been converted into the CAST RexCall, so we assert single operand. - List<RexExpression> operands = - rexCall.getOperands().stream().map(RexExpressionUtils::fromRexNode).collect(Collectors.toList()); - Preconditions.checkState(operands.size() == 1, "CAST takes exactly 2 arguments"); - RelDataType castType = rexCall.getType(); - operands.add(new RexExpression.Literal(ColumnDataType.STRING, - RelToPlanNodeConverter.convertToColumnDataType(castType).name())); - return new RexExpression.FunctionCall(SqlKind.CAST, RelToPlanNodeConverter.convertToColumnDataType(castType), - "CAST", operands); + assert rexCall.operands.size() == 1; + List<RexExpression> operands = new ArrayList<>(2); + operands.add(fromRexNode(rexCall.operands.get(0))); + ColumnDataType castType = RelToPlanNodeConverter.convertToColumnDataType(rexCall.type); + operands.add(new RexExpression.Literal(ColumnDataType.STRING, castType.name())); + return new RexExpression.FunctionCall(castType, SqlKind.CAST.name(), operands); } /** * Reinterpret is a pass-through function that does not change the type of the input. */ private static RexExpression handleReinterpret(RexCall rexCall) { - List<RexNode> operands = rexCall.getOperands(); - Preconditions.checkState(operands.size() == 1, "REINTERPRET takes only 1 argument"); - return fromRexNode(operands.get(0)); + assert rexCall.operands.size() == 1; + return fromRexNode(rexCall.operands.get(0)); } private static RexExpression handleSearch(RexCall rexCall) { - List<RexNode> operands = rexCall.getOperands(); - RexInputRef rexInputRef = (RexInputRef) operands.get(0); - RexLiteral rexLiteral = (RexLiteral) operands.get(1); + assert rexCall.operands.size() == 2; + RexInputRef rexInputRef = (RexInputRef) rexCall.operands.get(0); + RexLiteral rexLiteral = (RexLiteral) rexCall.operands.get(1); ColumnDataType dataType = RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType()); Sarg sarg = rexLiteral.getValueAs(Sarg.class); if (sarg.isPoints()) { - return new RexExpression.FunctionCall(SqlKind.IN, dataType, SqlKind.IN.name(), + return new RexExpression.FunctionCall(dataType, SqlKind.IN.name(), toFunctionOperands(rexInputRef, sarg.rangeSet.asRanges(), dataType)); } else if (sarg.isComplementedPoints()) { - return new RexExpression.FunctionCall(SqlKind.NOT_IN, dataType, SqlKind.NOT_IN.name(), + return new RexExpression.FunctionCall(dataType, SqlKind.NOT_IN.name(), toFunctionOperands(rexInputRef, sarg.rangeSet.complement().asRanges(), dataType)); } else { Set<Range<?>> ranges = sarg.rangeSet.asRanges(); @@ -190,7 +194,7 @@ public class RexExpressionUtils { } } ImmutableList<RexExpression> operands = ImmutableList.of(result, newExp); - result = new RexExpression.FunctionCall(SqlKind.OR, ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands); + result = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands); } return result; } @@ -212,7 +216,7 @@ public class RexExpressionUtils { RexExpression upperConstraint = convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint()); ImmutableList<RexExpression> operands = ImmutableList.of(lowerConstraint, upperConstraint); - return new RexExpression.FunctionCall(SqlKind.AND, ColumnDataType.BOOLEAN, SqlKind.AND.name(), operands); + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.AND.name(), operands); } } @@ -221,7 +225,7 @@ public class RexExpressionUtils { SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.GREATER_THAN : SqlKind.GREATER_THAN_OR_EQUAL; RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint)); ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal); - return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands); + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), operands); } private static RexExpression convertUpperBound(RexExpression.InputRef inputRef, ColumnDataType dataType, @@ -229,7 +233,7 @@ public class RexExpressionUtils { SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.LESS_THAN : SqlKind.LESS_THAN_OR_EQUAL; RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint)); ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal); - return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands); + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), operands); } /** @@ -246,13 +250,9 @@ public class RexExpressionUtils { } public static RexExpression fromAggregateCall(AggregateCall aggregateCall) { - List<RexExpression> operands = new ArrayList<>(aggregateCall.rexList.size()); - for (RexNode rexNode : aggregateCall.rexList) { - operands.add(fromRexNode(rexNode)); - } - return new RexExpression.FunctionCall(aggregateCall.getAggregation().getKind(), - RelToPlanNodeConverter.convertToColumnDataType(aggregateCall.getType()), - aggregateCall.getAggregation().getName(), operands, aggregateCall.isDistinct()); + return new RexExpression.FunctionCall(RelToPlanNodeConverter.convertToColumnDataType(aggregateCall.type), + getFunctionName(aggregateCall.getAggregation()), fromRexNodes(aggregateCall.rexList), + aggregateCall.isDistinct()); } public static List<RexExpression> fromInputRefs(Iterable<Integer> inputRefs) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java index b3fc186a8c..21f22cf062 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java @@ -20,7 +20,6 @@ package org.apache.pinot.query.planner.serde; import java.util.List; import java.util.stream.Collectors; -import org.apache.calcite.sql.SqlKind; import org.apache.commons.lang3.SerializationUtils; import org.apache.pinot.common.proto.Expressions; import org.apache.pinot.common.utils.DataSchema; @@ -57,9 +56,8 @@ public class ProtoExpressionToRexExpression { List<RexExpression> functionOperands = functionCall.getFunctionOperandsList().stream().map(ProtoExpressionToRexExpression::process) .collect(Collectors.toList()); - return new RexExpression.FunctionCall(SqlKind.values()[functionCall.getSqlKind()], - convertColumnDataType(functionCall.getDataType()), functionCall.getFunctionName(), functionOperands, - functionCall.getIsDistinct()); + return new RexExpression.FunctionCall(convertColumnDataType(functionCall.getDataType()), + functionCall.getFunctionName(), functionOperands, functionCall.getIsDistinct()); } private static RexExpression deserializeLiteral(Expressions.Literal literal) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java index e8220a7289..0c5f82a95b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java @@ -20,8 +20,8 @@ package org.apache.pinot.query.planner.serde; import com.google.protobuf.ByteString; import java.io.Serializable; +import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; import org.apache.commons.lang3.SerializationUtils; import org.apache.pinot.common.proto.Expressions; import org.apache.pinot.common.utils.DataSchema; @@ -41,11 +41,10 @@ public class RexExpressionToProtoExpression { return serializeInputRef((RexExpression.InputRef) expression); } else if (expression instanceof RexExpression.Literal) { return serializeLiteral((RexExpression.Literal) expression); - } else if (expression instanceof RexExpression.FunctionCall) { + } else { + assert expression instanceof RexExpression.FunctionCall; return serializeFunctionCall((RexExpression.FunctionCall) expression); } - - throw new RuntimeException(String.format("Unknown Type Expression Type: %s", expression.getKind())); } private static Expressions.RexExpression serializeInputRef(RexExpression.InputRef inputRef) { @@ -54,15 +53,15 @@ public class RexExpressionToProtoExpression { } private static Expressions.RexExpression serializeFunctionCall(RexExpression.FunctionCall functionCall) { - List<Expressions.RexExpression> functionOperands = - functionCall.getFunctionOperands().stream().map(RexExpressionToProtoExpression::process) - .collect(Collectors.toList()); + List<RexExpression> operands = functionCall.getFunctionOperands(); + List<Expressions.RexExpression> protoOperands = new ArrayList<>(operands.size()); + for (RexExpression operand : operands) { + protoOperands.add(process(operand)); + } Expressions.FunctionCall.Builder protoFunctionCallBuilder = - Expressions.FunctionCall.newBuilder().setSqlKind(functionCall.getKind().ordinal()) - .setDataType(convertColumnDataType(functionCall.getDataType())) - .setFunctionName(functionCall.getFunctionName()).addAllFunctionOperands(functionOperands) + Expressions.FunctionCall.newBuilder().setDataType(convertColumnDataType(functionCall.getDataType())) + .setFunctionName(functionCall.getFunctionName()).addAllFunctionOperands(protoOperands) .setIsDistinct(functionCall.isDistinct()); - return Expressions.RexExpression.newBuilder().setFunctionCall(protoFunctionCallBuilder).build(); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index a19ff64d4e..9d865088fc 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -26,7 +26,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import javax.annotation.Nullable; -import org.apache.calcite.sql.SqlKind; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.datatable.StatMap; import org.apache.pinot.common.request.context.ExpressionContext; @@ -249,10 +248,7 @@ public class AggregateOperator extends MultiStageOperator { int numKeys = groupSet.size(); int[] groupKeyIds = new int[numKeys]; for (int i = 0; i < numKeys; i++) { - RexExpression rexExp = groupSet.get(i); - Preconditions.checkState(rexExp.getKind() == SqlKind.INPUT_REF, "Group key must be an input reference, got: %s", - rexExp.getKind()); - groupKeyIds[i] = ((RexExpression.InputRef) rexExp).getIndex(); + groupKeyIds[i] = ((RexExpression.InputRef) groupSet.get(i)).getIndex(); } return groupKeyIds; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index cccc065be3..89c5585a31 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -43,12 +43,13 @@ public class FunctionOperand implements TransformOperand { private final List<TransformOperand> _operands; private final Object[] _reusableOperandHolder; - public FunctionOperand(RexExpression.FunctionCall functionCall, String canonicalName, DataSchema dataSchema) { + public FunctionOperand(RexExpression.FunctionCall functionCall, DataSchema dataSchema) { _resultType = functionCall.getDataType(); List<RexExpression> operands = functionCall.getFunctionOperands(); int numOperands = operands.size(); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, numOperands); - Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); + FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionCall.getFunctionName(), numOperands); + Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", + functionCall.getFunctionName()); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) { Class<?>[] parameterClasses = _functionInvoker.getParameterClasses(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java index 8ec4b055ea..1c39351bce 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java @@ -22,7 +22,6 @@ import com.google.common.base.Preconditions; import java.util.List; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; -import org.apache.pinot.query.runtime.operator.utils.OperatorUtils; public class TransformOperandFactory { @@ -44,8 +43,7 @@ public class TransformOperandFactory { private static TransformOperand getTransformOperand(RexExpression.FunctionCall functionCall, DataSchema dataSchema) { List<RexExpression> operands = functionCall.getFunctionOperands(); int numOperands = operands.size(); - String canonicalName = OperatorUtils.canonicalizeFunctionName(functionCall.getFunctionName()); - switch (canonicalName) { + switch (functionCall.getFunctionName()) { case "AND": Preconditions.checkState(numOperands >= 2, "AND takes >=2 arguments, got: %s", numOperands); return new FilterOperand.And(operands, dataSchema); @@ -61,26 +59,26 @@ public class TransformOperandFactory { case "NOT_IN": Preconditions.checkState(numOperands >= 2, "NOT_IN takes >=2 arguments, got: %s", numOperands); return new FilterOperand.In(operands, dataSchema, true); - case "ISTRUE": + case "IS_TRUE": Preconditions.checkState(numOperands == 1, "IS_TRUE takes one argument, got: %s", numOperands); return new FilterOperand.IsTrue(operands.get(0), dataSchema); - case "ISNOTTRUE": + case "IS_NOT_TRUE": Preconditions.checkState(numOperands == 1, "IS_NOT_TRUE takes one argument, got: %s", numOperands); return new FilterOperand.IsNotTrue(operands.get(0), dataSchema); - case "equals": + case "EQUALS": return new FilterOperand.Predicate(operands, dataSchema, v -> v == 0); - case "notEquals": + case "NOT_EQUALS": return new FilterOperand.Predicate(operands, dataSchema, v -> v != 0); - case "greaterThan": + case "GREATER_THAN": return new FilterOperand.Predicate(operands, dataSchema, v -> v > 0); - case "greaterThanOrEqual": + case "GREATER_THAN_OR_EQUAL": return new FilterOperand.Predicate(operands, dataSchema, v -> v >= 0); - case "lessThan": + case "LESS_THAN": return new FilterOperand.Predicate(operands, dataSchema, v -> v < 0); - case "lessThanOrEqual": + case "LESS_THAN_OR_EQUAL": return new FilterOperand.Predicate(operands, dataSchema, v -> v <= 0); default: - return new FunctionOperand(functionCall, canonicalName, dataSchema); + return new FunctionOperand(functionCall, dataSchema); } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java index 0ea7b5df87..94148d9aec 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java @@ -214,8 +214,9 @@ public class AggregationUtils { _dataType = inputSchema.getColumnDataType(_inputRef); } else { _inputRef = -1; - _literal = ((RexExpression.Literal) rexExpression).getValue(); - _dataType = rexExpression.getDataType(); + RexExpression.Literal literal = (RexExpression.Literal) rexExpression; + _literal = literal.getValue(); + _dataType = literal.getDataType(); } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java deleted file mode 100644 index a10cde39dc..0000000000 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.query.runtime.operator.utils; - -import java.util.HashMap; -import java.util.Map; -import org.apache.commons.lang.StringUtils; - - -public class OperatorUtils { - private static final Map<String, String> OPERATOR_TOKEN_MAPPING = new HashMap<>(); - - static { - OPERATOR_TOKEN_MAPPING.put("=", "equals"); - OPERATOR_TOKEN_MAPPING.put(">", "greaterThan"); - OPERATOR_TOKEN_MAPPING.put("<", "lessThan"); - OPERATOR_TOKEN_MAPPING.put("<=", "lessThanOrEqual"); - OPERATOR_TOKEN_MAPPING.put(">=", "greaterThanOrEqual"); - OPERATOR_TOKEN_MAPPING.put("<>", "notEquals"); - OPERATOR_TOKEN_MAPPING.put("!=", "notEquals"); - OPERATOR_TOKEN_MAPPING.put("+", "plus"); - OPERATOR_TOKEN_MAPPING.put("-", "minus"); - OPERATOR_TOKEN_MAPPING.put("*", "times"); - OPERATOR_TOKEN_MAPPING.put("/", "divide"); - OPERATOR_TOKEN_MAPPING.put("||", "concat"); - } - - private OperatorUtils() { - // do not instantiate. - } - - /** - * Canonicalize function name since Logical plan uses Parser.jj extracted tokens. - * @param functionName input Function name - * @return Canonicalize form of the input function name - */ - public static String canonicalizeFunctionName(String functionName) { - functionName = StringUtils.remove(functionName, " "); - functionName = OPERATOR_TOKEN_MAPPING.getOrDefault(functionName, functionName); - return functionName; - } -} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java index afab2a8259..6fdde5373a 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.stream.Collectors; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.request.DataSource; +import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.PinotQuery; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.request.RequestUtils; @@ -69,12 +70,11 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); if (pinotQuery.getGroupByList() == null) { - // set group-by list - pinotQuery.setGroupByList(CalciteRexExpressionParser.convertGroupByList(node.getGroupSet(), pinotQuery)); - // set agg list + List<Expression> groupByList = CalciteRexExpressionParser.convertRexNodes(node.getGroupSet(), pinotQuery); + pinotQuery.setGroupByList(groupByList); pinotQuery.setSelectList( - CalciteRexExpressionParser.convertAggregateList(pinotQuery.getGroupByList(), node.getAggCalls(), - node.getFilterArgIndices(), pinotQuery)); + CalciteRexExpressionParser.convertAggregateList(groupByList, node.getAggCalls(), node.getFilterArgIndices(), + pinotQuery)); if (node.getAggType() == AggregateNode.AggType.DIRECT) { pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, "true"); @@ -167,7 +167,7 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla public Void visitProject(ProjectNode node, ServerPlanRequestContext context) { if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); - pinotQuery.setSelectList(CalciteRexExpressionParser.convertProjectList(node.getProjects(), pinotQuery)); + pinotQuery.setSelectList(CalciteRexExpressionParser.convertRexNodes(node.getProjects(), pinotQuery)); } return null; } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index 790bd3a901..70c76c98ce 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -233,8 +233,8 @@ public class AggregateOperatorTest { @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*AVERAGE.*") public void shouldThrowOnUnknownAggFunction() { // Given: - List<RexExpression> calls = ImmutableList.of( - new RexExpression.FunctionCall(SqlKind.AVG, ColumnDataType.INT, "AVERAGE", ImmutableList.of())); + List<RexExpression> calls = + ImmutableList.of(new RexExpression.FunctionCall(ColumnDataType.INT, "AVERAGE", ImmutableList.of())); List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); DataSchema outSchema = new DataSchema(new String[]{"unknown"}, new ColumnDataType[]{DOUBLE}); @@ -299,13 +299,12 @@ public class AggregateOperatorTest { // Then: Assert.assertEquals(block1.getNumRows(), 1, "when group limit reach it should only return that many groups"); Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)"); - StatMap<AggregateOperator.StatKey> aggrStats = - OperatorTestUtil.getStatMap(AggregateOperator.StatKey.class, block2); + StatMap<AggregateOperator.StatKey> aggrStats = OperatorTestUtil.getStatMap(AggregateOperator.StatKey.class, block2); Assert.assertTrue(aggrStats.getBoolean(AggregateOperator.StatKey.NUM_GROUPS_LIMIT_REACHED), "num groups limit should be reached"); } private static RexExpression.FunctionCall getSum(RexExpression arg) { - return new RexExpression.FunctionCall(SqlKind.SUM, ColumnDataType.INT, "SUM", ImmutableList.of(arg)); + return new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.SUM.name(), ImmutableList.of(arg)); } } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java index 27825f916a..1d610aed0a 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java @@ -168,7 +168,7 @@ public class FilterOperatorTest { }); Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{1, 1}, new Object[]{0, 0}, new Object[]{1, 0})); - RexExpression.FunctionCall andCall = new RexExpression.FunctionCall(SqlKind.AND, ColumnDataType.BOOLEAN, "AND", + RexExpression.FunctionCall andCall = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.AND.name(), ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.InputRef(1))); FilterOperator op = @@ -188,7 +188,7 @@ public class FilterOperatorTest { }); Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{1, 1}, new Object[]{0, 0}, new Object[]{1, 0})); - RexExpression.FunctionCall orCall = new RexExpression.FunctionCall(SqlKind.OR, ColumnDataType.BOOLEAN, "OR", + RexExpression.FunctionCall orCall = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.OR.name(), ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.InputRef(1))); FilterOperator op = @@ -210,7 +210,7 @@ public class FilterOperatorTest { }); Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{1, 1}, new Object[]{0, 0}, new Object[]{1, 0})); - RexExpression.FunctionCall notCall = new RexExpression.FunctionCall(SqlKind.NOT, ColumnDataType.BOOLEAN, "NOT", + RexExpression.FunctionCall notCall = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.NOT.name(), ImmutableList.of(new RexExpression.InputRef(0))); FilterOperator op = @@ -231,7 +231,7 @@ public class FilterOperatorTest { Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{1, 2}, new Object[]{3, 2}, new Object[]{1, 1})); RexExpression.FunctionCall greaterThan = - new RexExpression.FunctionCall(SqlKind.GREATER_THAN, ColumnDataType.BOOLEAN, "greaterThan", + new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.GREATER_THAN.name(), ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.InputRef(1))); FilterOperator op = new FilterOperator(OperatorTestUtil.getTracingContext(), _upstreamOperator, inputSchema, greaterThan); @@ -251,7 +251,7 @@ public class FilterOperatorTest { Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{"starTree"}, new Object[]{"treeStar"})); RexExpression.FunctionCall startsWith = - new RexExpression.FunctionCall(SqlKind.OTHER, ColumnDataType.BOOLEAN, "startsWith", + new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.STARTS_WITH.name(), ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.STRING, "star"))); FilterOperator op = new FilterOperator(OperatorTestUtil.getTracingContext(), _upstreamOperator, inputSchema, startsWith); @@ -271,9 +271,8 @@ public class FilterOperatorTest { }); Mockito.when(_upstreamOperator.nextBlock()) .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{"starTree"}, new Object[]{"treeStar"})); - RexExpression.FunctionCall startsWith = - new RexExpression.FunctionCall(SqlKind.OTHER, ColumnDataType.BOOLEAN, "startsWithError", - ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.STRING, "star"))); + RexExpression.FunctionCall startsWith = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, "startsWithError", + ImmutableList.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.STRING, "star"))); new FilterOperator(OperatorTestUtil.getTracingContext(), _upstreamOperator, inputSchema, startsWith); } } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java index 45afa6dbc6..91d4dee224 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java @@ -325,7 +325,8 @@ public class HashJoinOperatorTest { List<RexExpression> functionOperands = new ArrayList<>(); functionOperands.add(new RexExpression.InputRef(1)); functionOperands.add(new RexExpression.InputRef(3)); - joinClauses.add(new RexExpression.FunctionCall(SqlKind.NOT_EQUALS, ColumnDataType.BOOLEAN, "<>", functionOperands)); + joinClauses.add( + new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.NOT_EQUALS.name(), functionOperands)); DataSchema resultSchema = new DataSchema(new String[]{"int_col1", "string_col1", "int_col2", "string_col2"}, new ColumnDataType[]{ ColumnDataType.INT, ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.STRING @@ -363,7 +364,8 @@ public class HashJoinOperatorTest { List<RexExpression> functionOperands = new ArrayList<>(); functionOperands.add(new RexExpression.InputRef(0)); functionOperands.add(new RexExpression.InputRef(2)); - joinClauses.add(new RexExpression.FunctionCall(SqlKind.NOT_EQUALS, ColumnDataType.BOOLEAN, "<>", functionOperands)); + joinClauses.add( + new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.NOT_EQUALS.name(), functionOperands)); DataSchema resultSchema = new DataSchema(new String[]{"int_col1", "string_col1", "int_co2", "string_col2"}, new ColumnDataType[]{ ColumnDataType.INT, ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.STRING @@ -658,7 +660,7 @@ public class HashJoinOperatorTest { TransferableBlock secondBlock = join.nextBlock(); StatMap<HashJoinOperator.StatKey> joinStats = OperatorTestUtil.getStatMap(HashJoinOperator.StatKey.class, secondBlock); - Assert.assertTrue(joinStats.getBoolean(HashJoinOperator.StatKey.MAX_ROWS_IN_JOIN_REACHED), + Assert.assertTrue(joinStats.getBoolean(HashJoinOperator.StatKey.MAX_ROWS_IN_JOIN_REACHED), "Max rows in join should be reached"); } } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java index 123cda06a5..723ffef562 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.calcite.sql.SqlKind; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -36,9 +37,6 @@ import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import static org.apache.calcite.sql.SqlKind.MINUS; -import static org.apache.calcite.sql.SqlKind.PLUS; - public class TransformOperatorTest { private AutoCloseable _mocks; @@ -117,9 +115,9 @@ public class TransformOperatorTest { RexExpression.InputRef ref1 = new RexExpression.InputRef(1); List<RexExpression> functionOperands = ImmutableList.of(ref0, ref1); RexExpression.FunctionCall plus01 = - new RexExpression.FunctionCall(PLUS, ColumnDataType.DOUBLE, "plus", functionOperands); + new RexExpression.FunctionCall(ColumnDataType.DOUBLE, SqlKind.PLUS.name(), functionOperands); RexExpression.FunctionCall minus01 = - new RexExpression.FunctionCall(MINUS, ColumnDataType.DOUBLE, "minus", functionOperands); + new RexExpression.FunctionCall(ColumnDataType.DOUBLE, SqlKind.MINUS.name(), functionOperands); DataSchema resultSchema = new DataSchema(new String[]{"plusR", "minusR"}, new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE}); TransformOperator op = new TransformOperator(OperatorTestUtil.getTracingContext(), _upstreamOp, resultSchema, @@ -145,9 +143,9 @@ public class TransformOperatorTest { RexExpression.InputRef ref1 = new RexExpression.InputRef(1); List<RexExpression> functionOperands = ImmutableList.of(ref0, ref1); RexExpression.FunctionCall plus01 = - new RexExpression.FunctionCall(PLUS, ColumnDataType.DOUBLE, "plus", functionOperands); + new RexExpression.FunctionCall(ColumnDataType.DOUBLE, SqlKind.PLUS.name(), functionOperands); RexExpression.FunctionCall minus01 = - new RexExpression.FunctionCall(MINUS, ColumnDataType.DOUBLE, "minus", functionOperands); + new RexExpression.FunctionCall(ColumnDataType.DOUBLE, SqlKind.MINUS.name(), functionOperands); DataSchema resultSchema = new DataSchema(new String[]{"plusR", "minusR"}, new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE}); TransformOperator op = new TransformOperator(OperatorTestUtil.getTracingContext(), _upstreamOp, resultSchema, diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index 56761b2294..70dcbc836f 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -270,8 +270,8 @@ public class WindowAggregateOperatorTest { + "WindowFunction for function name: AVERAGE.*") public void testShouldThrowOnUnknownAggFunction() { // Given: - List<RexExpression> calls = ImmutableList.of( - new RexExpression.FunctionCall(SqlKind.AVG, ColumnDataType.INT, "AVERAGE", ImmutableList.of())); + List<RexExpression> calls = + ImmutableList.of(new RexExpression.FunctionCall(ColumnDataType.INT, "AVERAGE", ImmutableList.of())); List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); DataSchema outSchema = new DataSchema(new String[]{"unknown"}, new ColumnDataType[]{DOUBLE}); DataSchema inSchema = new DataSchema(new String[]{"unknown"}, new ColumnDataType[]{DOUBLE}); @@ -290,7 +290,7 @@ public class WindowAggregateOperatorTest { // TODO: Remove this test when support is added for NTILE function // Given: List<RexExpression> calls = - ImmutableList.of(new RexExpression.FunctionCall(SqlKind.RANK, ColumnDataType.INT, "NTILE", ImmutableList.of())); + ImmutableList.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.NTILE.name(), ImmutableList.of())); List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); DataSchema outSchema = new DataSchema(new String[]{"unknown"}, new ColumnDataType[]{DOUBLE}); DataSchema inSchema = new DataSchema(new String[]{"unknown"}, new ColumnDataType[]{DOUBLE}); @@ -308,8 +308,8 @@ public class WindowAggregateOperatorTest { throws ProcessingException { // Given: List<RexExpression> calls = - ImmutableList.of(new RexExpression.FunctionCall(SqlKind.RANK, ColumnDataType.INT, "RANK", ImmutableList.of()), - new RexExpression.FunctionCall(SqlKind.DENSE_RANK, ColumnDataType.INT, "DENSE_RANK", ImmutableList.of())); + ImmutableList.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.RANK.name(), ImmutableList.of()), + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.DENSE_RANK.name(), ImmutableList.of())); List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); List<RexExpression> order = ImmutableList.of(new RexExpression.InputRef(1)); @@ -368,7 +368,7 @@ public class WindowAggregateOperatorTest { throws ProcessingException { // Given: List<RexExpression> calls = ImmutableList.of( - new RexExpression.FunctionCall(SqlKind.ROW_NUMBER, ColumnDataType.INT, "ROW_NUMBER", ImmutableList.of())); + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.ROW_NUMBER.name(), ImmutableList.of())); List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0)); List<RexExpression> order = ImmutableList.of(new RexExpression.InputRef(1)); @@ -533,8 +533,8 @@ public class WindowAggregateOperatorTest { WindowAggregateOperator operator = new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), _input, group, order, Arrays.asList(RelFieldCollation.Direction.ASCENDING), Arrays.asList(RelFieldCollation.NullDirection.LAST), - calls, Integer.MIN_VALUE, 0, WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, - inSchema, getWindowHints(ImmutableMap.of())); + calls, Integer.MIN_VALUE, 0, WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, + getWindowHints(ImmutableMap.of())); // When: TransferableBlock block1 = operator.nextBlock(); @@ -636,8 +636,7 @@ public class WindowAggregateOperatorTest { WindowAggregateOperator operator = new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), _input, group, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, Integer.MAX_VALUE, - WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, - getWindowHints(hintsMap)); + WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, getWindowHints(hintsMap)); // When: TransferableBlock block = operator.nextBlock(); @@ -665,8 +664,7 @@ public class WindowAggregateOperatorTest { WindowAggregateOperator operator = new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), _input, group, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), calls, Integer.MIN_VALUE, Integer.MAX_VALUE, - WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, - getWindowHints(hintsMap)); + WindowNode.WindowFrameType.RANGE, Collections.emptyList(), outSchema, inSchema, getWindowHints(hintsMap)); // When: TransferableBlock firstBlock = operator.nextBlock(); @@ -682,7 +680,7 @@ public class WindowAggregateOperatorTest { } private static RexExpression.FunctionCall getSum(RexExpression arg) { - return new RexExpression.FunctionCall(SqlKind.SUM, ColumnDataType.INT, "SUM", ImmutableList.of(arg)); + return new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.SUM.name(), ImmutableList.of(arg)); } private static AbstractPlanNode.NodeHint getWindowHints(Map<String, String> hintsMap) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org