This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new f8de958174 Use argument type to lookup function for literal only query
(#13673)
f8de958174 is described below
commit f8de958174e209ab7f11572149ce1723d63b5af3
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Aug 28 11:50:02 2024 -0700
Use argument type to lookup function for literal only query (#13673)
---
.../BaseSingleStageBrokerRequestHandler.java | 120 ++++++--------------
.../common/request/context/LiteralContext.java | 36 +-----
.../pinot/common/utils/request/RequestUtils.java | 98 ++++++++++++++--
.../rewriter/CompileTimeFunctionsInvoker.java | 66 ++++++-----
.../pinot/sql/parsers/CalciteSqlCompilerTest.java | 125 ++++++++++-----------
5 files changed, 225 insertions(+), 220 deletions(-)
diff --git
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
index 28eebf205b..83f83188dc 100644
---
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
+++
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
@@ -77,6 +77,7 @@ import org.apache.pinot.common.response.ProcessingException;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.common.response.broker.ResultTable;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.common.utils.DatabaseUtils;
import org.apache.pinot.common.utils.config.QueryOptionsUtils;
import org.apache.pinot.common.utils.request.RequestUtils;
@@ -101,8 +102,6 @@ import
org.apache.pinot.spi.exception.BadQueryRequestException;
import org.apache.pinot.spi.exception.DatabaseConflictException;
import org.apache.pinot.spi.trace.RequestContext;
import org.apache.pinot.spi.trace.Tracing;
-import org.apache.pinot.spi.utils.BigDecimalUtils;
-import org.apache.pinot.spi.utils.BytesUtils;
import org.apache.pinot.spi.utils.CommonConstants;
import org.apache.pinot.spi.utils.CommonConstants.Broker;
import
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
@@ -1489,16 +1488,18 @@ public abstract class
BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ
private BrokerResponseNative processLiteralOnlyQuery(long requestId,
PinotQuery pinotQuery,
RequestContext requestContext) {
BrokerResponseNative brokerResponse = new BrokerResponseNative();
- List<String> columnNames = new ArrayList<>();
- List<DataSchema.ColumnDataType> columnTypes = new ArrayList<>();
- List<Object> row = new ArrayList<>();
- for (Expression expression : pinotQuery.getSelectList()) {
- computeResultsForExpression(expression, columnNames, columnTypes, row);
- }
- DataSchema dataSchema =
- new DataSchema(columnNames.toArray(new String[0]),
columnTypes.toArray(new DataSchema.ColumnDataType[0]));
- List<Object[]> rows = new ArrayList<>();
- rows.add(row.toArray());
+ List<Expression> selectList = pinotQuery.getSelectList();
+ int numColumns = selectList.size();
+ String[] columnNames = new String[numColumns];
+ ColumnDataType[] columnTypes = new ColumnDataType[numColumns];
+ Object[] values = new Object[numColumns];
+ for (int i = 0; i < numColumns; i++) {
+ computeResultsForExpression(selectList.get(i), columnNames, columnTypes,
values, i);
+ values[i] = columnTypes[i].format(values[i]);
+ }
+ DataSchema dataSchema = new DataSchema(columnNames, columnTypes);
+ List<Object[]> rows = new ArrayList<>(1);
+ rows.add(values);
ResultTable resultTable = new ResultTable(dataSchema, rows);
brokerResponse.setResultTable(resultTable);
brokerResponse.setTimeUsedMs(System.currentTimeMillis() -
requestContext.getRequestArrivalTimeMillis());
@@ -1510,87 +1511,30 @@ public abstract class
BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ
}
// TODO(xiangfu): Move Literal function computation here from Calcite Parser.
- private void computeResultsForExpression(Expression e, List<String>
columnNames,
- List<DataSchema.ColumnDataType> columnTypes, List<Object> row) {
- if (e.getType() == ExpressionType.LITERAL) {
- computeResultsForLiteral(e.getLiteral(), columnNames, columnTypes, row);
- }
- if (e.getType() == ExpressionType.FUNCTION) {
- if (e.getFunctionCall().getOperator().equals("as")) {
- String columnName =
e.getFunctionCall().getOperands().get(1).getIdentifier().getName();
- computeResultsForExpression(e.getFunctionCall().getOperands().get(0),
columnNames, columnTypes, row);
- columnNames.set(columnNames.size() - 1, columnName);
+ private void computeResultsForExpression(Expression expression, String[]
columnNames, ColumnDataType[] columnTypes,
+ Object[] values, int index) {
+ ExpressionType type = expression.getType();
+ if (type == ExpressionType.LITERAL) {
+ computeResultsForLiteral(expression.getLiteral(), columnNames,
columnTypes, values, index);
+ } else if (type == ExpressionType.FUNCTION) {
+ Function function = expression.getFunctionCall();
+ String operator = function.getOperator();
+ if (operator.equals("as")) {
+ List<Expression> operands = function.getOperands();
+ computeResultsForExpression(operands.get(0), columnNames, columnTypes,
values, index);
+ columnNames[index] = operands.get(1).getIdentifier().getName();
} else {
- throw new IllegalStateException(
- "No able to compute results for function - " +
e.getFunctionCall().getOperator());
+ throw new IllegalStateException("No able to compute results for
function - " + operator);
}
}
}
- private void computeResultsForLiteral(Literal literal, List<String>
columnNames,
- List<DataSchema.ColumnDataType> columnTypes, List<Object> row) {
- columnNames.add(RequestUtils.prettyPrint(literal));
- switch (literal.getSetField()) {
- case NULL_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.UNKNOWN);
- row.add(null);
- break;
- case BOOL_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.BOOLEAN);
- row.add(literal.getBoolValue());
- break;
- case INT_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.INT);
- row.add(literal.getIntValue());
- break;
- case LONG_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.LONG);
- row.add(literal.getLongValue());
- break;
- case FLOAT_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.FLOAT);
- row.add(Float.intBitsToFloat(literal.getFloatValue()));
- break;
- case DOUBLE_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.DOUBLE);
- row.add(literal.getDoubleValue());
- break;
- case BIG_DECIMAL_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.BIG_DECIMAL);
- row.add(BigDecimalUtils.deserialize(literal.getBigDecimalValue()));
- break;
- case STRING_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.STRING);
- row.add(literal.getStringValue());
- break;
- case BINARY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.BYTES);
- row.add(BytesUtils.toHexString(literal.getBinaryValue()));
- break;
- // TODO: Revisit the array handling. Currently we are setting List into
the row.
- case INT_ARRAY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.INT_ARRAY);
- row.add(literal.getIntArrayValue());
- break;
- case LONG_ARRAY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.LONG_ARRAY);
- row.add(literal.getLongArrayValue());
- break;
- case FLOAT_ARRAY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.FLOAT_ARRAY);
-
row.add(literal.getFloatArrayValue().stream().map(Float::intBitsToFloat).collect(Collectors.toList()));
- break;
- case DOUBLE_ARRAY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.DOUBLE_ARRAY);
- row.add(literal.getDoubleArrayValue());
- break;
- case STRING_ARRAY_VALUE:
- columnTypes.add(DataSchema.ColumnDataType.STRING_ARRAY);
- row.add(literal.getStringArrayValue());
- break;
- default:
- throw new IllegalStateException("Unsupported literal: " + literal);
- }
+ private void computeResultsForLiteral(Literal literal, String[] columnNames,
ColumnDataType[] columnTypes,
+ Object[] values, int index) {
+ columnNames[index] = RequestUtils.prettyPrint(literal);
+ Pair<ColumnDataType, Object> typeAndValue =
RequestUtils.getLiteralTypeAndValue(literal);
+ columnTypes[index] = typeAndValue.getLeft();
+ values[index] = typeAndValue.getRight();
}
/**
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
index 0a2b8ad6e1..c0a55f23f5 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
@@ -23,11 +23,11 @@ import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.Arrays;
-import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.utils.PinotDataType;
+import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
@@ -105,55 +105,31 @@ public class LiteralContext {
break;
case INT_ARRAY_VALUE: {
_type = DataType.INT;
- List<Integer> valueList = literal.getIntArrayValue();
- int numValues = valueList.size();
- int[] values = new int[numValues];
- for (int i = 0; i < numValues; i++) {
- values[i] = valueList.get(i);
- }
- _value = values;
+ _value = RequestUtils.getIntArrayValue(literal);
_pinotDataType = PinotDataType.PRIMITIVE_INT_ARRAY;
break;
}
case LONG_ARRAY_VALUE: {
_type = DataType.LONG;
- List<Long> valueList = literal.getLongArrayValue();
- int numValues = valueList.size();
- long[] values = new long[numValues];
- for (int i = 0; i < numValues; i++) {
- values[i] = valueList.get(i);
- }
- _value = values;
+ _value = RequestUtils.getLongArrayValue(literal);
_pinotDataType = PinotDataType.PRIMITIVE_LONG_ARRAY;
break;
}
case FLOAT_ARRAY_VALUE: {
_type = DataType.FLOAT;
- List<Integer> valueList = literal.getFloatArrayValue();
- int numValues = valueList.size();
- float[] values = new float[numValues];
- for (int i = 0; i < numValues; i++) {
- values[i] = Float.intBitsToFloat(valueList.get(i));
- }
- _value = values;
+ _value = RequestUtils.getFloatArrayValue(literal);
_pinotDataType = PinotDataType.PRIMITIVE_FLOAT_ARRAY;
break;
}
case DOUBLE_ARRAY_VALUE: {
_type = DataType.DOUBLE;
- List<Double> valueList = literal.getDoubleArrayValue();
- int numValues = valueList.size();
- double[] values = new double[numValues];
- for (int i = 0; i < numValues; i++) {
- values[i] = valueList.get(i);
- }
- _value = values;
+ _value = RequestUtils.getDoubleArrayValue(literal);
_pinotDataType = PinotDataType.PRIMITIVE_DOUBLE_ARRAY;
break;
}
case STRING_ARRAY_VALUE:
_type = DataType.STRING;
- _value = literal.getStringArrayValue().toArray(new String[0]);
+ _value = RequestUtils.getStringArrayValue(literal);
_pinotDataType = PinotDataType.STRING_ARRAY;
break;
default:
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
index e8feaeeb07..6147b7f7ea 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
@@ -41,6 +41,7 @@ import javax.annotation.Nullable;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.DataSource;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
@@ -48,6 +49,7 @@ import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.Identifier;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.apache.pinot.spi.utils.BytesUtils;
import org.apache.pinot.spi.utils.CommonConstants.Broker.Request;
@@ -343,21 +345,95 @@ public class RequestUtils {
case BINARY_VALUE:
return literal.getBinaryValue();
case INT_ARRAY_VALUE:
- return
literal.getIntArrayValue().stream().mapToInt(Integer::intValue).toArray();
+ return getIntArrayValue(literal);
case LONG_ARRAY_VALUE:
- return
literal.getLongArrayValue().stream().mapToLong(Long::longValue).toArray();
+ return getLongArrayValue(literal);
case FLOAT_ARRAY_VALUE:
- List<Integer> floatList = literal.getFloatArrayValue();
- int numFloats = floatList.size();
- float[] floatArray = new float[numFloats];
- for (int i = 0; i < numFloats; i++) {
- floatArray[i] = Float.intBitsToFloat(floatList.get(i));
- }
- return floatArray;
+ return getFloatArrayValue(literal);
+ case DOUBLE_ARRAY_VALUE:
+ return getDoubleArrayValue(literal);
+ case STRING_ARRAY_VALUE:
+ return getStringArrayValue(literal);
+ default:
+ throw new IllegalStateException("Unsupported field type: " + type);
+ }
+ }
+
+ public static int[] getIntArrayValue(Literal literal) {
+ List<Integer> list = literal.getIntArrayValue();
+ int size = list.size();
+ int[] array = new int[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = list.get(i);
+ }
+ return array;
+ }
+
+ public static long[] getLongArrayValue(Literal literal) {
+ List<Long> list = literal.getLongArrayValue();
+ int size = list.size();
+ long[] array = new long[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = list.get(i);
+ }
+ return array;
+ }
+
+ public static float[] getFloatArrayValue(Literal literal) {
+ List<Integer> list = literal.getFloatArrayValue();
+ int size = list.size();
+ float[] array = new float[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = Float.intBitsToFloat(list.get(i));
+ }
+ return array;
+ }
+
+ public static double[] getDoubleArrayValue(Literal literal) {
+ List<Double> list = literal.getDoubleArrayValue();
+ int size = list.size();
+ double[] array = new double[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = list.get(i);
+ }
+ return array;
+ }
+
+ public static String[] getStringArrayValue(Literal literal) {
+ return literal.getStringArrayValue().toArray(new String[0]);
+ }
+
+ public static Pair<ColumnDataType, Object> getLiteralTypeAndValue(Literal
literal) {
+ Literal._Fields type = literal.getSetField();
+ switch (type) {
+ case NULL_VALUE:
+ return Pair.of(ColumnDataType.UNKNOWN, null);
+ case BOOL_VALUE:
+ return Pair.of(ColumnDataType.BOOLEAN, literal.getBoolValue());
+ case INT_VALUE:
+ return Pair.of(ColumnDataType.INT, literal.getIntValue());
+ case LONG_VALUE:
+ return Pair.of(ColumnDataType.LONG, literal.getLongValue());
+ case FLOAT_VALUE:
+ return Pair.of(ColumnDataType.FLOAT,
Float.intBitsToFloat(literal.getFloatValue()));
+ case DOUBLE_VALUE:
+ return Pair.of(ColumnDataType.DOUBLE, literal.getDoubleValue());
+ case BIG_DECIMAL_VALUE:
+ return Pair.of(ColumnDataType.BIG_DECIMAL,
BigDecimalUtils.deserialize(literal.getBigDecimalValue()));
+ case STRING_VALUE:
+ return Pair.of(ColumnDataType.STRING, literal.getStringValue());
+ case BINARY_VALUE:
+ return Pair.of(ColumnDataType.BYTES, literal.getBinaryValue());
+ case INT_ARRAY_VALUE:
+ return Pair.of(ColumnDataType.INT_ARRAY, getIntArrayValue(literal));
+ case LONG_ARRAY_VALUE:
+ return Pair.of(ColumnDataType.LONG_ARRAY, getLongArrayValue(literal));
+ case FLOAT_ARRAY_VALUE:
+ return Pair.of(ColumnDataType.FLOAT_ARRAY,
getFloatArrayValue(literal));
case DOUBLE_ARRAY_VALUE:
- return
literal.getDoubleArrayValue().stream().mapToDouble(Double::doubleValue).toArray();
+ return Pair.of(ColumnDataType.DOUBLE_ARRAY,
getDoubleArrayValue(literal));
case STRING_ARRAY_VALUE:
- return literal.getStringArrayValue().toArray(new String[0]);
+ return Pair.of(ColumnDataType.STRING_ARRAY,
getStringArrayValue(literal));
default:
throw new IllegalStateException("Unsupported field type: " + type);
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
index 6a47fa827e..1e10fbed52 100644
---
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
+++
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
@@ -18,20 +18,25 @@
*/
package org.apache.pinot.sql.parsers.rewriter;
+import com.google.common.annotations.VisibleForTesting;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nullable;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
+import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.sql.parsers.SqlCompilationException;
public class CompileTimeFunctionsInvoker implements QueryRewriter {
+
@Override
public PinotQuery rewrite(PinotQuery pinotQuery) {
for (int i = 0; i < pinotQuery.getSelectListSize(); i++) {
@@ -53,7 +58,8 @@ public class CompileTimeFunctionsInvoker implements
QueryRewriter {
return pinotQuery;
}
- protected static Expression invokeCompileTimeFunctionExpression(@Nullable
Expression expression) {
+ @VisibleForTesting
+ public static Expression invokeCompileTimeFunctionExpression(@Nullable
Expression expression) {
if (expression == null || expression.getFunctionCall() == null) {
return expression;
}
@@ -61,38 +67,44 @@ public class CompileTimeFunctionsInvoker implements
QueryRewriter {
List<Expression> operands = function.getOperands();
int numOperands = operands.size();
boolean compilable = true;
+ ColumnDataType[] argumentTypes = new ColumnDataType[numOperands];
+ Object[] arguments = new Object[numOperands];
for (int i = 0; i < numOperands; i++) {
Expression operand =
invokeCompileTimeFunctionExpression(operands.get(i));
- if (operand.getLiteral() == null) {
+ operands.set(i, operand);
+ Literal literal = operand.getLiteral();
+ if (compilable && literal != null) {
+ Pair<ColumnDataType, Object> typeAndValue =
RequestUtils.getLiteralTypeAndValue(literal);
+ argumentTypes[i] = typeAndValue.getLeft();
+ arguments[i] = typeAndValue.getRight();
+ } else {
+ // NOTE: Do not directly 'return expression;' here because we want to
compile all operands even if the current
+ // expression is not compilable.
compilable = false;
}
- operands.set(i, operand);
}
- if (compilable) {
- String canonicalName =
FunctionRegistry.canonicalize(function.getOperator());
- FunctionInfo functionInfo =
FunctionRegistry.lookupFunctionInfo(canonicalName, numOperands);
- if (functionInfo != null) {
- Object[] arguments = new Object[numOperands];
- for (int i = 0; i < numOperands; i++) {
- arguments[i] =
RequestUtils.getLiteralValue(function.getOperands().get(i).getLiteral());
- }
- try {
- FunctionInvoker invoker = new FunctionInvoker(functionInfo);
- Object result;
- if (invoker.getMethod().isVarArgs()) {
- result = invoker.invoke(new Object[] {arguments});
- } else {
- invoker.convertTypes(arguments);
- result = invoker.invoke(arguments);
- }
- return RequestUtils.getLiteralExpression(result);
- } catch (Exception e) {
- throw new SqlCompilationException(
- "Caught exception while invoking method: " +
functionInfo.getMethod() + " with arguments: "
- + Arrays.toString(arguments), e);
- }
+ if (!compilable) {
+ return expression;
+ }
+ String canonicalName =
FunctionRegistry.canonicalize(function.getOperator());
+ FunctionInfo functionInfo =
FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes);
+ if (functionInfo == null) {
+ return expression;
+ }
+ try {
+ FunctionInvoker invoker = new FunctionInvoker(functionInfo);
+ Object result;
+ if (invoker.getMethod().isVarArgs()) {
+ result = invoker.invoke(new Object[]{arguments});
+ } else {
+ invoker.convertTypes(arguments);
+ result = invoker.invoke(arguments);
}
+ return RequestUtils.getLiteralExpression(result);
+ } catch (Exception e) {
+ throw new SqlCompilationException(
+ "Caught exception while invoking method: " +
functionInfo.getMethod() + " with arguments: " + Arrays.toString(
+ arguments), e);
}
- return expression;
}
}
diff --git
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 369dd8b886..35a625505a 100644
---
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -1054,8 +1054,7 @@ public class CalciteSqlCompilerTest {
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(2)
.getLiteral().getStringValue(), "SECONDS");
Assert.assertEquals(
-
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(),
- 1394323200);
+
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(),
1394323200);
}
@Test
@@ -1379,8 +1378,8 @@ public class CalciteSqlCompilerTest {
Assert.fail("Query should have failed compilation");
} catch (Exception e) {
Assert.assertTrue(e instanceof SqlCompilationException);
- Assert.assertTrue(e.getMessage().contains("'group_city' should be
functionally dependent on the columns "
- + "used in GROUP BY clause."));
+ Assert.assertTrue(e.getMessage()
+ .contains("'group_city' should be functionally dependent on the
columns " + "used in GROUP BY clause."));
}
// Valid groupBy non-aggregate function should pass.
@@ -1398,8 +1397,8 @@ public class CalciteSqlCompilerTest {
Assert.fail("Query should have failed compilation");
} catch (Exception e) {
Assert.assertTrue(e instanceof SqlCompilationException);
- Assert.assertTrue(e.getMessage().contains("'secondsSinceEpoch' should be
functionally dependent on the columns "
- + "used in GROUP BY clause."));
+ Assert.assertTrue(e.getMessage().contains(
+ "'secondsSinceEpoch' should be functionally dependent on the columns
" + "used in GROUP BY clause."));
}
// Invalid groupBy clause shouldn't contain aggregate expression, like
sum(rsvp_count), count(*).
@@ -2331,14 +2330,10 @@ public class CalciteSqlCompilerTest {
@Test
public void testCompileTimeExpression() {
- final CompileTimeFunctionsInvoker compileTimeFunctionsInvoker = new
CompileTimeFunctionsInvoker();
long lowerBound = System.currentTimeMillis();
Expression expression = compileToExpression("now()");
Assert.assertNotNull(expression.getFunctionCall());
- PinotQuery pinotQuery = new PinotQuery();
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
long upperBound = System.currentTimeMillis();
long result = expression.getLiteral().getLongValue();
@@ -2347,10 +2342,7 @@ public class CalciteSqlCompilerTest {
lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
expression = compileToExpression("to_epoch_hours(now() + 3600000)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
- Assert.assertNotNull(expression.getLiteral());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
upperBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
result = expression.getLiteral().getLongValue();
Assert.assertTrue(result >= lowerBound && result <= upperBound);
@@ -2358,9 +2350,7 @@ public class CalciteSqlCompilerTest {
lowerBound = System.currentTimeMillis() - ONE_HOUR_IN_MS;
expression = compileToExpression("ago('PT1H')");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
upperBound = System.currentTimeMillis() - ONE_HOUR_IN_MS;
result = expression.getLiteral().getLongValue();
@@ -2369,9 +2359,7 @@ public class CalciteSqlCompilerTest {
lowerBound = System.currentTimeMillis() + ONE_HOUR_IN_MS;
expression = compileToExpression("ago('PT-1H')");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
upperBound = System.currentTimeMillis() + ONE_HOUR_IN_MS;
result = expression.getLiteral().getLongValue();
@@ -2379,9 +2367,7 @@ public class CalciteSqlCompilerTest {
expression = compileToExpression("toDateTime(millisSinceEpoch)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(),
"todatetime");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
@@ -2389,88 +2375,105 @@ public class CalciteSqlCompilerTest {
expression = compileToExpression("encodeUrl('key1=value
1&key2=value@!$2&key3=value%3')");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(),
"key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253");
expression =
compileToExpression("decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253')");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(), "key1=value
1&key2=value@!$2&key3=value%3");
expression = compileToExpression("reverse(playerName)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(), "reverse");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"playerName");
expression = compileToExpression("reverse('playerName')");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(),
"emaNreyalp");
expression = compileToExpression("reverse(123)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(), "321");
expression = compileToExpression("count(*)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(), "count");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"*");
expression = compileToExpression("toBase64(toUtf8('hello!'))");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(), "aGVsbG8h");
expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h'))");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getStringValue(), "hello!");
expression = compileToExpression("fromBase64(foo)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(),
"frombase64");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"foo");
expression = compileToExpression("toBase64(foo)");
Assert.assertNotNull(expression.getFunctionCall());
- pinotQuery.setFilterExpression(expression);
- pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
- expression = pinotQuery.getFilterExpression();
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(),
"tobase64");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"foo");
+
+ expression = compileToExpression("'foo' > 'bar'");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+ expression = compileToExpression("toBase64(toUtf8('hello!')) =
'aGVsbG8h'");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+ expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h')) !=
'hello!'");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertFalse(expression.getLiteral().getBoolValue());
+
+ expression = compileToExpression("123 < 123.000000000000000000001");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertFalse(expression.getLiteral().getBoolValue());
+
+ expression = compileToExpression("cast('123' as big_decimal) <
cast('123.000000000000000000001' as big_decimal)");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+ // Should fall back to DOUBLE comparison
+ expression = compileToExpression("123 < cast('123.000000000000000000001'
as big_decimal)");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertFalse(expression.getLiteral().getBoolValue());
}
@Test
@@ -2599,19 +2602,14 @@ public class CalciteSqlCompilerTest {
String query = "SELECT col1 FROM foo GROUP BY col1, col2";
PinotQuery pinotQuery = compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
- Assert.assertEquals(
- pinotQuery.getSelectList().get(0).getIdentifier().getName(), "col1");
- Assert.assertEquals(
- pinotQuery.getGroupByList().get(0).getIdentifier().getName(), "col1");
- Assert.assertEquals(
- pinotQuery.getGroupByList().get(1).getIdentifier().getName(), "col2");
+
Assert.assertEquals(pinotQuery.getSelectList().get(0).getIdentifier().getName(),
"col1");
+
Assert.assertEquals(pinotQuery.getGroupByList().get(0).getIdentifier().getName(),
"col1");
+
Assert.assertEquals(pinotQuery.getGroupByList().get(1).getIdentifier().getName(),
"col2");
query = "SELECT col1+col2 FROM foo GROUP BY col1,col2";
pinotQuery = compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
- Assert.assertEquals(
- pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(),
- "plus");
+
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(),
"plus");
Assert.assertEquals(
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"col1");
Assert.assertEquals(
@@ -3023,7 +3021,6 @@ public class CalciteSqlCompilerTest {
public void testParserExtensionImpl() {
String customSql = "INSERT INTO db.tbl FROM FILE 'file:///tmp/file1', FILE
'file:///tmp/file2'";
SqlNodeAndOptions sqlNodeAndOptions =
CalciteSqlParser.compileToSqlNodeAndOptions(customSql);
- ;
Assert.assertTrue(sqlNodeAndOptions.getSqlNode() instanceof
SqlInsertFromFile);
Assert.assertEquals(sqlNodeAndOptions.getSqlType(), PinotSqlType.DML);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]