This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 2cfaed3 Support type conversion for all scalar functions (#5849)
2cfaed3 is described below
commit 2cfaed37cf581362b87a36e924cdd5744d430e03
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Aug 12 18:15:14 2020 -0700
Support type conversion for all scalar functions (#5849)
The following parameter classes are supported for type conversion via `PinotDataType`:
- int/Integer
- long/Long
- float/Float
- double/Double
- String
- byte[]
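Also make `FunctionRegistry` ignore underscores in function names: names are canonicalized by removing `_` before lower-casing, so e.g. `json_format` and `jsonFormat` map to the same name.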
Support type conversion in all features that use scalar functions:
- Compile time function in `CalciteSqlParser`
- Record transform/filter during ingestion in `InbuiltFunctionEvaluator`
- Transform during query execution in `ScalarTransformFunctionWrapper`
Add `PostAggregationFunction` to handle post-aggregation calculation using
scalar functions.
Add `ArithmeticFunctions` for all the arithmetic scalar functions:
- plus
- minus
- times
- divide
- mod
- min
- max
- abs
- ceil
- floor
- exp
- ln
- sqrt
---
.../apache/pinot/common/function/FunctionInfo.java | 44 +--
.../pinot/common/function/FunctionInvoker.java | 137 ++++++---
.../pinot/common/function/FunctionRegistry.java | 3 +-
.../pinot/common/function/FunctionUtils.java | 118 ++++++++
.../function/annotations/ScalarFunction.java | 23 +-
.../function/scalar/ArithmeticFunctions.java | 95 ++++++
.../function/{ => scalar}/DateTimeFunctions.java | 89 +++---
.../function/{ => scalar}/JsonFunctions.java | 9 +-
.../function/{ => scalar}/StringFunctions.java | 40 +--
.../apache/pinot/common/utils}/PinotDataType.java | 32 +-
.../apache/pinot/sql/parsers/CalciteSqlParser.java | 3 +-
.../pinot/common/utils}/PinotDataTypeTest.java | 4 +-
.../pinot/sql/parsers/CalciteSqlCompilerTest.java | 57 ++--
.../data/function/FunctionEvaluatorFactory.java | 15 +-
.../data/function/InbuiltFunctionEvaluator.java | 75 ++---
.../recordtransformer/DataTypeTransformer.java | 1 +
.../function/ScalarTransformFunctionWrapper.java | 331 +++++++++------------
.../function/TransformFunctionFactory.java | 13 +-
.../postaggregation/PostAggregationFunction.java | 80 +++++
.../core/data/function/InbuiltFunctionsTest.java | 83 ++++--
.../ScalarTransformFunctionWrapperTest.java | 8 +-
.../PostAggregationFunctionTest.java | 62 ++++
22 files changed, 845 insertions(+), 477 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInfo.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInfo.java
index 5cb81eb..0169823 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInfo.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInfo.java
@@ -22,14 +22,13 @@ import java.lang.reflect.Method;
public class FunctionInfo {
-
private final Method _method;
private final Class<?> _clazz;
public FunctionInfo(Method method, Class<?> clazz) {
- super();
- this._method = method;
- this._clazz = clazz;
+ method.setAccessible(true);
+ _method = method;
+ _clazz = clazz;
}
public Method getMethod() {
@@ -39,41 +38,4 @@ public class FunctionInfo {
public Class<?> getClazz() {
return _clazz;
}
-
- /**
- * Check if the Function is applicable to the argumentTypes.
- * We can only know the types at runtime, so we can validate if the return
type is Object.
- * For e.g funcA( funcB('3.14'), columnA)
- * We can only know return type of funcB and 3.14 (String.class) but
- * we cannot know the type of columnA in advance without knowing the source
schema
- * @param argumentTypes
- * @return
- */
- public boolean isApplicable(Class<?>[] argumentTypes) {
-
- Class<?>[] parameterTypes = _method.getParameterTypes();
-
- if (parameterTypes.length != argumentTypes.length) {
- return false;
- }
-
- for (int i = 0; i < parameterTypes.length; i++) {
- Class<?> type = parameterTypes[i];
- //
- if (!type.isAssignableFrom(argumentTypes[i]) && argumentTypes[i] !=
Object.class) {
- return false;
- }
- }
- return true;
- }
-
- /**
- * Eventually we will need to convert the input datatypes before invoking
the actual method. For now, there is no conversion
- *
- * @param args
- * @return
- */
- public Object[] convertTypes(Object[] args) {
- return args;
- }
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java
index 73c12ed..b185d26 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java
@@ -18,78 +18,117 @@
*/
package org.apache.pinot.common.function;
+import com.google.common.base.Preconditions;
import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
-import java.util.concurrent.TimeUnit;
-import org.joda.time.Duration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.pinot.common.utils.PinotDataType;
/**
- * A simple code to invoke a method in any class using reflection.
- * Eventually this will support annotations on the method but for now its a
simple wrapper on any java method
+ * The {@code FunctionInvoker} is a wrapper on a java method which supports
arguments type conversion and method
+ * invocation via reflection.
*/
public class FunctionInvoker {
+ private final Method _method;
+ private final Class<?>[] _parameterClasses;
+ private final PinotDataType[] _parameterTypes;
+ private final Object _instance;
- private static final Logger LOGGER =
LoggerFactory.getLogger(FunctionInvoker.class);
- // Don't log more than 10 entries in 5 MINUTES
- //TODO:Convert this functionality into a class that can be used in other
places
- private static long EXCEPTION_LIMIT_DURATION = TimeUnit.MINUTES.toMillis(5);
- private static long EXCEPTION_LIMIT_RATE = 10;
- private Method _method;
- private Object _instance;
- private int exceptionCount;
- private long lastExceptionTime = 0;
- private FunctionInfo _functionInfo;
-
- public FunctionInvoker(FunctionInfo functionInfo)
- throws Exception {
- _functionInfo = functionInfo;
+ public FunctionInvoker(FunctionInfo functionInfo) {
_method = functionInfo.getMethod();
- _method.setAccessible(true);
- Class<?> clazz = functionInfo.getClazz();
+ Class<?>[] parameterClasses = _method.getParameterTypes();
+ int numParameters = parameterClasses.length;
+ _parameterClasses = new Class<?>[numParameters];
+ _parameterTypes = new PinotDataType[numParameters];
+ for (int i = 0; i < numParameters; i++) {
+ Class<?> parameterClass = parameterClasses[i];
+ _parameterClasses[i] = parameterClass;
+ _parameterTypes[i] = FunctionUtils.getParameterType(parameterClass);
+ }
if (Modifier.isStatic(_method.getModifiers())) {
_instance = null;
} else {
- Constructor<?> constructor = clazz.getDeclaredConstructor();
- constructor.setAccessible(true);
- _instance = constructor.newInstance();
+ Class<?> clazz = functionInfo.getClazz();
+ try {
+ Constructor<?> constructor =
functionInfo.getClazz().getDeclaredConstructor();
+ constructor.setAccessible(true);
+ _instance = constructor.newInstance();
+ } catch (Exception e) {
+ throw new IllegalStateException("Caught exception while constructing
class: " + clazz, e);
+ }
}
}
- public Class<?>[] getParameterTypes() {
- return _method.getParameterTypes();
+ /**
+ * Returns the underlying java method.
+ */
+ public Method getMethod() {
+ return _method;
}
- public Class<?> getReturnType() {
- return _method.getReturnType();
+ /**
+ * Returns the class of the parameters.
+ */
+ public Class<?>[] getParameterClasses() {
+ return _parameterClasses;
}
- public Object process(Object[] args) {
- try {
- return _method.invoke(_instance, _functionInfo.convertTypes(args));
- } catch (IllegalAccessException | IllegalArgumentException |
InvocationTargetException e) {
- //most likely the exception is in the udf, get the exceptio
- Throwable cause = e.getCause();
- if (cause == null) {
- cause = e;
- }
- //some udf's might be configured incorrectly and we dont want to pollute
the log
- //keep track of the last time an exception was logged and reset the
counter if the last exception is more than the EXCEPTION_LIMIT_DURATION
- if (Duration.millis(System.currentTimeMillis() -
lastExceptionTime).getStandardMinutes()
- > EXCEPTION_LIMIT_DURATION) {
- exceptionCount = 0;
+ /**
+ * Returns the PinotDataType of the parameters for type conversion purpose.
Puts {@code null} for the parameter class
+ * that does not support type conversion.
+ */
+ public PinotDataType[] getParameterTypes() {
+ return _parameterTypes;
+ }
+
+ /**
+ * Converts the type of the given arguments to match the parameter classes.
Leaves the argument as is if type
+ * conversion is not needed or supported.
+ */
+ public void convertTypes(Object[] arguments) {
+ int numParameters = _parameterClasses.length;
+ Preconditions.checkArgument(arguments.length == numParameters,
+ "Wrong number of arguments for method: %s, expected: %s, actual: %s",
_method, numParameters, arguments.length);
+ for (int i = 0; i < numParameters; i++) {
+ // Skip conversion for null
+ Object argument = arguments[i];
+ if (argument == null) {
+ continue;
}
- if (exceptionCount < EXCEPTION_LIMIT_RATE) {
- exceptionCount = exceptionCount + 1;
- LOGGER.error("Exception invoking method:{} with args:{}, exception
message: {}", _method.getName(),
- Arrays.toString(args), cause.getMessage());
+ // Skip conversion if argument can be directly assigned
+ Class<?> parameterClass = _parameterClasses[i];
+ Class<?> argumentClass = argument.getClass();
+ if (parameterClass.isAssignableFrom(argumentClass)) {
+ continue;
}
- return null;
+
+ PinotDataType parameterType = _parameterTypes[i];
+ PinotDataType argumentType =
FunctionUtils.getArgumentType(argumentClass);
+ Preconditions.checkArgument(parameterType != null && argumentType !=
null,
+ "Cannot convert value from class: %s to class: %s", argumentClass,
parameterClass);
+ arguments[i] = parameterType.convert(argument, argumentType);
+ }
+ }
+
+ /**
+ * Returns the class of the result value.
+ */
+ public Class<?> getResultClass() {
+ return _method.getReturnType();
+ }
+
+ /**
+ * Invoke the function with the given arguments. The arguments should match
the parameter classes. Use
+ * {@link #convertTypes(Object[])} to convert the argument types if needed
before calling this method.
+ */
+ public Object invoke(Object[] arguments) {
+ try {
+ return _method.invoke(_instance, arguments);
+ } catch (Exception e) {
+ throw new IllegalStateException(
+ "Caught exception while invoking method: " + _method + " with
arguments: " + Arrays.toString(arguments), e);
}
}
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index 712414e..58be240 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -23,6 +23,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
+import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.annotations.ScalarFunction;
import org.reflections.Reflections;
import org.reflections.scanners.MethodAnnotationsScanner;
@@ -104,6 +105,6 @@ public class FunctionRegistry {
}
private static String canonicalize(String functionName) {
- return functionName.toLowerCase();
+ return StringUtils.remove(functionName, '_').toLowerCase();
}
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java
new file mode 100644
index 0000000..33da3cc
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function;
+
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.common.utils.PinotDataType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+
+
+public class FunctionUtils {
+ private FunctionUtils() {
+ }
+
+ // Types allowed as the function parameter (in the function signature) for
type conversion
+ private static final Map<Class<?>, PinotDataType> PARAMETER_TYPE_MAP = new
HashMap<Class<?>, PinotDataType>() {{
+ put(int.class, PinotDataType.INTEGER);
+ put(Integer.class, PinotDataType.INTEGER);
+ put(long.class, PinotDataType.LONG);
+ put(Long.class, PinotDataType.LONG);
+ put(float.class, PinotDataType.FLOAT);
+ put(Float.class, PinotDataType.FLOAT);
+ put(double.class, PinotDataType.DOUBLE);
+ put(Double.class, PinotDataType.DOUBLE);
+ put(String.class, PinotDataType.STRING);
+ put(byte[].class, PinotDataType.BYTES);
+ }};
+
+ // Types allowed as the function argument (actual value passed into the
function) for type conversion
+ private static final Map<Class<?>, PinotDataType> ARGUMENT_TYPE_MAP = new
HashMap<Class<?>, PinotDataType>() {{
+ put(Byte.class, PinotDataType.BYTE);
+ put(Boolean.class, PinotDataType.BOOLEAN);
+ put(Character.class, PinotDataType.CHARACTER);
+ put(Short.class, PinotDataType.SHORT);
+ put(Integer.class, PinotDataType.INTEGER);
+ put(Long.class, PinotDataType.LONG);
+ put(Float.class, PinotDataType.FLOAT);
+ put(Double.class, PinotDataType.DOUBLE);
+ put(String.class, PinotDataType.STRING);
+ put(byte[].class, PinotDataType.BYTES);
+ }};
+
+ private static final Map<Class<?>, DataType> DATA_TYPE_MAP = new
HashMap<Class<?>, DataType>() {{
+ put(int.class, DataType.INT);
+ put(Integer.class, DataType.INT);
+ put(long.class, DataType.LONG);
+ put(Long.class, DataType.LONG);
+ put(float.class, DataType.FLOAT);
+ put(Float.class, DataType.FLOAT);
+ put(double.class, DataType.DOUBLE);
+ put(Double.class, DataType.DOUBLE);
+ put(String.class, DataType.STRING);
+ put(byte[].class, DataType.BYTES);
+ }};
+
+ private static final Map<Class<?>, ColumnDataType> COLUMN_DATA_TYPE_MAP =
new HashMap<Class<?>, ColumnDataType>() {{
+ put(int.class, ColumnDataType.INT);
+ put(Integer.class, ColumnDataType.INT);
+ put(long.class, ColumnDataType.LONG);
+ put(Long.class, ColumnDataType.LONG);
+ put(float.class, ColumnDataType.FLOAT);
+ put(Float.class, ColumnDataType.FLOAT);
+ put(double.class, ColumnDataType.DOUBLE);
+ put(Double.class, ColumnDataType.DOUBLE);
+ put(String.class, ColumnDataType.STRING);
+ put(byte[].class, ColumnDataType.BYTES);
+ }};
+
+ /**
+ * Returns the corresponding PinotDataType for the given parameter class, or
{@code null} if there is no one matching.
+ */
+ @Nullable
+ public static PinotDataType getParameterType(Class<?> clazz) {
+ return PARAMETER_TYPE_MAP.get(clazz);
+ }
+
+ /**
+ * Returns the corresponding PinotDataType for the given argument class, or
{@code null} if there is no one matching.
+ */
+ @Nullable
+ public static PinotDataType getArgumentType(Class<?> clazz) {
+ return ARGUMENT_TYPE_MAP.get(clazz);
+ }
+
+ /**
+ * Returns the corresponding DataType for the given class, or {@code null}
if there is no one matching.
+ */
+ @Nullable
+ public static DataType getDataType(Class<?> clazz) {
+ return DATA_TYPE_MAP.get(clazz);
+ }
+
+ /**
+ * Returns the corresponding ColumnDataType for the given class, or {@code
null} if there is no one matching.
+ */
+ @Nullable
+ public static ColumnDataType getColumnDataType(Class<?> clazz) {
+ return COLUMN_DATA_TYPE_MAP.get(clazz);
+ }
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/annotations/ScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/annotations/ScalarFunction.java
index c788368..e18ec63 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/annotations/ScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/annotations/ScalarFunction.java
@@ -25,13 +25,26 @@ import java.lang.annotation.Target;
/**
- * Annotation Class for Scalar Functions
- * Methods annotated using the interface are registered in the
FunctionsRegistry
- * and can be used as UDFs during Querying
+ * Annotation Class for Scalar Functions.
+ *
+ * Methods annotated using the interface are registered in the
FunctionsRegistry, and can be used for transform and
+ * filtering during record ingestion, and transform and post-aggregation
during query execution.
+ *
+ * NOTE:
+ * 1. The annotated method must be under the package of name
'org.apache.pinot.*.function.*' to be auto-registered.
+ * 2. The following parameter types are supported for auto type conversion:
+ * - int/Integer
+ * - long/Long
+ * - float/Float
+ * - double/Double
+ * - String
+ * - byte[]
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface ScalarFunction {
- boolean enabled() default true;
- String name() default "";
+
+ boolean enabled() default true;
+
+ String name() default "";
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
new file mode 100644
index 0000000..b9bf9b3
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar;
+
+import org.apache.pinot.common.function.annotations.ScalarFunction;
+
+
+/**
+ * Arithmetic scalar functions.
+ */
+public class ArithmeticFunctions {
+ private ArithmeticFunctions() {
+ }
+
+ @ScalarFunction
+ public static double plus(double a, double b) {
+ return a + b;
+ }
+
+ @ScalarFunction
+ public static double minus(double a, double b) {
+ return a - b;
+ }
+
+ @ScalarFunction
+ public static double times(double a, double b) {
+ return a * b;
+ }
+
+ @ScalarFunction
+ public static double divide(double a, double b) {
+ return a / b;
+ }
+
+ @ScalarFunction
+ public static double mod(double a, double b) {
+ return a % b;
+ }
+
+ @ScalarFunction
+ public static double min(double a, double b) {
+ return Double.min(a, b);
+ }
+
+ @ScalarFunction
+ public static double max(double a, double b) {
+ return Double.max(a, b);
+ }
+
+ @ScalarFunction
+ public static double abs(double a) {
+ return Math.abs(a);
+ }
+
+ @ScalarFunction
+ public static double ceil(double a) {
+ return Math.ceil(a);
+ }
+
+ @ScalarFunction
+ public static double floor(double a) {
+ return Math.floor(a);
+ }
+
+ @ScalarFunction
+ public static double exp(double a) {
+ return Math.exp(a);
+ }
+
+ @ScalarFunction
+ public static double ln(double a) {
+ return Math.log(a);
+ }
+
+ @ScalarFunction
+ public static double sqrt(double a) {
+ return Math.sqrt(a);
+ }
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/DateTimeFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java
similarity index 69%
rename from
pinot-common/src/main/java/org/apache/pinot/common/function/DateTimeFunctions.java
rename to
pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java
index ce53773..375ca2b 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/DateTimeFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java
@@ -16,11 +16,11 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.common.function;
+package org.apache.pinot.common.function.scalar;
import java.util.concurrent.TimeUnit;
+import org.apache.pinot.common.function.DateTimePatternHandler;
import org.apache.pinot.common.function.annotations.ScalarFunction;
-import org.joda.time.format.DateTimeFormat;
/**
@@ -65,11 +65,14 @@ import org.joda.time.format.DateTimeFormat;
* </code>
*/
public class DateTimeFunctions {
+ private DateTimeFunctions() {
+ }
+
/**
* Convert epoch millis to epoch seconds
*/
@ScalarFunction
- static Long toEpochSeconds(Long millis) {
+ public static long toEpochSeconds(long millis) {
return TimeUnit.MILLISECONDS.toSeconds(millis);
}
@@ -77,7 +80,7 @@ public class DateTimeFunctions {
* Convert epoch millis to epoch minutes
*/
@ScalarFunction
- static Long toEpochMinutes(Long millis) {
+ public static long toEpochMinutes(long millis) {
return TimeUnit.MILLISECONDS.toMinutes(millis);
}
@@ -85,7 +88,7 @@ public class DateTimeFunctions {
* Convert epoch millis to epoch hours
*/
@ScalarFunction
- static Long toEpochHours(Long millis) {
+ public static long toEpochHours(long millis) {
return TimeUnit.MILLISECONDS.toHours(millis);
}
@@ -93,7 +96,7 @@ public class DateTimeFunctions {
* Convert epoch millis to epoch days
*/
@ScalarFunction
- static Long toEpochDays(Long millis) {
+ public static long toEpochDays(long millis) {
return TimeUnit.MILLISECONDS.toDays(millis);
}
@@ -101,71 +104,71 @@ public class DateTimeFunctions {
* Convert epoch millis to epoch seconds, round to nearest rounding bucket
*/
@ScalarFunction
- static Long toEpochSecondsRounded(Long millis, Number roundToNearest) {
- return (TimeUnit.MILLISECONDS.toSeconds(millis) /
roundToNearest.intValue()) * roundToNearest.intValue();
+ public static long toEpochSecondsRounded(long millis, long roundToNearest) {
+ return (TimeUnit.MILLISECONDS.toSeconds(millis) / roundToNearest) *
roundToNearest;
}
/**
* Convert epoch millis to epoch minutes, round to nearest rounding bucket
*/
@ScalarFunction
- static Long toEpochMinutesRounded(Long millis, Number roundToNearest) {
- return (TimeUnit.MILLISECONDS.toMinutes(millis) /
roundToNearest.intValue()) * roundToNearest.intValue();
+ public static long toEpochMinutesRounded(long millis, long roundToNearest) {
+ return (TimeUnit.MILLISECONDS.toMinutes(millis) / roundToNearest) *
roundToNearest;
}
/**
* Convert epoch millis to epoch hours, round to nearest rounding bucket
*/
@ScalarFunction
- static Long toEpochHoursRounded(Long millis, Number roundToNearest) {
- return (TimeUnit.MILLISECONDS.toHours(millis) / roundToNearest.intValue())
* roundToNearest.intValue();
+ public static long toEpochHoursRounded(long millis, long roundToNearest) {
+ return (TimeUnit.MILLISECONDS.toHours(millis) / roundToNearest) *
roundToNearest;
}
/**
* Convert epoch millis to epoch days, round to nearest rounding bucket
*/
@ScalarFunction
- static Long toEpochDaysRounded(Long millis, Number roundToNearest) {
- return (TimeUnit.MILLISECONDS.toDays(millis) / roundToNearest.intValue())
* roundToNearest.intValue();
+ public static long toEpochDaysRounded(long millis, long roundToNearest) {
+ return (TimeUnit.MILLISECONDS.toDays(millis) / roundToNearest) *
roundToNearest;
}
/**
* Convert epoch millis to epoch seconds, divided by given bucket, to get
nSecondsSinceEpoch
*/
@ScalarFunction
- static Long toEpochSecondsBucket(Long millis, Number bucket) {
- return TimeUnit.MILLISECONDS.toSeconds(millis) / bucket.intValue();
+ public static long toEpochSecondsBucket(long millis, long bucket) {
+ return TimeUnit.MILLISECONDS.toSeconds(millis) / bucket;
}
/**
* Convert epoch millis to epoch minutes, divided by given bucket, to get
nMinutesSinceEpoch
*/
@ScalarFunction
- static Long toEpochMinutesBucket(Long millis, Number bucket) {
- return TimeUnit.MILLISECONDS.toMinutes(millis) / bucket.intValue();
+ public static long toEpochMinutesBucket(long millis, long bucket) {
+ return TimeUnit.MILLISECONDS.toMinutes(millis) / bucket;
}
/**
* Convert epoch millis to epoch hours, divided by given bucket, to get
nHoursSinceEpoch
*/
@ScalarFunction
- static Long toEpochHoursBucket(Long millis, Number bucket) {
- return TimeUnit.MILLISECONDS.toHours(millis) / bucket.intValue();
+ public static long toEpochHoursBucket(long millis, long bucket) {
+ return TimeUnit.MILLISECONDS.toHours(millis) / bucket;
}
/**
* Convert epoch millis to epoch days, divided by given bucket, to get
nDaysSinceEpoch
*/
@ScalarFunction
- static Long toEpochDaysBucket(Long millis, Number bucket) {
- return TimeUnit.MILLISECONDS.toDays(millis) / bucket.intValue();
+ public static long toEpochDaysBucket(long millis, long bucket) {
+ return TimeUnit.MILLISECONDS.toDays(millis) / bucket;
}
/**
* Converts epoch seconds to epoch millis
*/
@ScalarFunction
- static Long fromEpochSeconds(Long seconds) {
+ public static long fromEpochSeconds(long seconds) {
return TimeUnit.SECONDS.toMillis(seconds);
}
@@ -173,63 +176,63 @@ public class DateTimeFunctions {
* Converts epoch minutes to epoch millis
*/
@ScalarFunction
- static Long fromEpochMinutes(Number minutes) {
- return TimeUnit.MINUTES.toMillis(minutes.longValue());
+ public static long fromEpochMinutes(long minutes) {
+ return TimeUnit.MINUTES.toMillis(minutes);
}
/**
* Converts epoch hours to epoch millis
*/
@ScalarFunction
- static Long fromEpochHours(Number hours) {
- return TimeUnit.HOURS.toMillis(hours.longValue());
+ public static long fromEpochHours(long hours) {
+ return TimeUnit.HOURS.toMillis(hours);
}
/**
* Converts epoch days to epoch millis
*/
@ScalarFunction
- static Long fromEpochDays(Number daysSinceEpoch) {
- return TimeUnit.DAYS.toMillis(daysSinceEpoch.longValue());
+ public static long fromEpochDays(long days) {
+ return TimeUnit.DAYS.toMillis(days);
}
/**
* Converts nSecondsSinceEpoch (seconds that have been divided by a bucket),
to epoch millis
*/
@ScalarFunction
- static Long fromEpochSecondsBucket(Long seconds, Number bucket) {
- return TimeUnit.SECONDS.toMillis(seconds * bucket.intValue());
+ public static long fromEpochSecondsBucket(long seconds, long bucket) {
+ return TimeUnit.SECONDS.toMillis(seconds * bucket);
}
/**
* Converts nMinutesSinceEpoch (minutes that have been divided by a bucket),
to epoch millis
*/
@ScalarFunction
- static Long fromEpochMinutesBucket(Number minutes, Number bucket) {
- return TimeUnit.MINUTES.toMillis(minutes.longValue() * bucket.intValue());
+ public static long fromEpochMinutesBucket(long minutes, long bucket) {
+ return TimeUnit.MINUTES.toMillis(minutes * bucket);
}
/**
* Converts nHoursSinceEpoch (hours that have been divided by a bucket), to
epoch millis
*/
@ScalarFunction
- static Long fromEpochHoursBucket(Number hours, Number bucket) {
- return TimeUnit.HOURS.toMillis(hours.longValue() * bucket.intValue());
+ public static long fromEpochHoursBucket(long hours, long bucket) {
+ return TimeUnit.HOURS.toMillis(hours * bucket);
}
/**
* Converts nDaysSinceEpoch (days that have been divided by a bucket), to
epoch millis
*/
@ScalarFunction
- static Long fromEpochDaysBucket(Number daysSinceEpoch, Number bucket) {
- return TimeUnit.DAYS.toMillis(daysSinceEpoch.longValue() *
bucket.intValue());
+ public static long fromEpochDaysBucket(long days, long bucket) {
+ return TimeUnit.DAYS.toMillis(days * bucket);
}
/**
* Converts epoch millis to DateTime string represented by pattern
*/
@ScalarFunction
- static String toDateTime(Long millis, String pattern) {
+ public static String toDateTime(long millis, String pattern) {
return DateTimePatternHandler.parseEpochMillisToDateTimeString(millis,
pattern);
}
@@ -237,26 +240,24 @@ public class DateTimeFunctions {
* Converts DateTime string represented by pattern to epoch millis
*/
@ScalarFunction
- static Long fromDateTime(String dateTimeString, String pattern) {
+ public static long fromDateTime(String dateTimeString, String pattern) {
return
DateTimePatternHandler.parseDateTimeStringToEpochMillis(dateTimeString,
pattern);
}
-
/**
* Round the given time value to nearest multiple
* @return the original value but rounded to the nearest multiple of @param
roundToNearest
*/
@ScalarFunction
- static Long round(Long timeValue, Number roundToNearest) {
- long roundingValue = roundToNearest.longValue();
- return (timeValue / roundingValue) * roundingValue;
+ public static long round(long timeValue, long roundToNearest) {
+ return (timeValue / roundToNearest) * roundToNearest;
}
/**
* Return current time as epoch millis
*/
@ScalarFunction
- static Long now() {
+ public static long now() {
return System.currentTimeMillis();
}
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/JsonFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java
similarity index 90%
rename from
pinot-common/src/main/java/org/apache/pinot/common/function/JsonFunctions.java
rename to
pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java
index e0d3b06..929e3a3 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/JsonFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.common.function;
+package org.apache.pinot.common.function.scalar;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.util.Map;
@@ -36,12 +36,14 @@ import org.apache.pinot.spi.utils.JsonUtils;
* </code>
*/
public class JsonFunctions {
+ private JsonFunctions() {
+ }
/**
* Convert Map to Json String
*/
@ScalarFunction
- static String toJsonMapStr(Map map)
+ public static String toJsonMapStr(Map map)
throws JsonProcessingException {
return JsonUtils.objectToString(map);
}
@@ -50,9 +52,8 @@ public class JsonFunctions {
* Convert object to Json String
*/
@ScalarFunction
- static String json_format(Object object)
+ public static String jsonFormat(Object object)
throws JsonProcessingException {
return JsonUtils.objectToString(object);
}
-
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/StringFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java
similarity index 84%
rename from
pinot-common/src/main/java/org/apache/pinot/common/function/StringFunctions.java
rename to
pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java
index 197f238..141936b 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/StringFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java
@@ -16,8 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.common.function;
-
+package org.apache.pinot.common.function.scalar;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
@@ -33,6 +32,9 @@ import
org.apache.pinot.common.function.annotations.ScalarFunction;
* <code> SELECT UPPER(playerName) FROM baseballStats LIMIT 10 </code>
*/
public class StringFunctions {
+ private StringFunctions() {
+ }
+
private final static Pattern LTRIM = Pattern.compile("^\\s+");
private final static Pattern RTRIM = Pattern.compile("\\s+$");
@@ -42,7 +44,7 @@ public class StringFunctions {
* @return reversed input in from end to start
*/
@ScalarFunction
- static String reverse(String input) {
+ public static String reverse(String input) {
return StringUtils.reverse(input);
}
@@ -52,7 +54,7 @@ public class StringFunctions {
* @return string in lower case format
*/
@ScalarFunction
- static String lower(String input) {
+ public static String lower(String input) {
return input.toLowerCase();
}
@@ -62,7 +64,7 @@ public class StringFunctions {
* @return string in upper case format
*/
@ScalarFunction
- static String upper(String input) {
+ public static String upper(String input) {
return input.toUpperCase();
}
@@ -73,7 +75,7 @@ public class StringFunctions {
* @return substring from beginIndex to end of the parent string
*/
@ScalarFunction
- static String substr(String input, Integer beginIndex) {
+ public static String substr(String input, int beginIndex) {
return input.substring(beginIndex);
}
@@ -88,7 +90,7 @@ public class StringFunctions {
* @return substring from beginIndex to endIndex
*/
@ScalarFunction
- static String substr(String input, Integer beginIndex, Integer endIndex) {
+ public static String substr(String input, int beginIndex, int endIndex) {
if (endIndex == -1) {
return substr(input, beginIndex);
}
@@ -103,7 +105,7 @@ public class StringFunctions {
* @return The two input strings joined by the seperator
*/
@ScalarFunction
- static String concat(String input1, String input2, String seperator) {
+ public static String concat(String input1, String input2, String seperator) {
String result = input1;
result = result + seperator + input2;
return result;
@@ -115,7 +117,7 @@ public class StringFunctions {
* @return trim spaces from both ends of the string
*/
@ScalarFunction
- static String trim(String input) {
+ public static String trim(String input) {
return input.trim();
}
@@ -124,7 +126,7 @@ public class StringFunctions {
* @return trim spaces from left side of the string
*/
@ScalarFunction
- static String ltrim(String input) {
+ public static String ltrim(String input) {
return LTRIM.matcher(input).replaceAll("");
}
@@ -133,7 +135,7 @@ public class StringFunctions {
* @return trim spaces from right side of the string
*/
@ScalarFunction
- static String rtrim(String input) {
+ public static String rtrim(String input) {
return RTRIM.matcher(input).replaceAll("");
}
@@ -143,7 +145,7 @@ public class StringFunctions {
* @return length of the string
*/
@ScalarFunction
- static Integer length(String input) {
+ public static int length(String input) {
return input.length();
}
@@ -156,7 +158,7 @@ public class StringFunctions {
* @return start index of the Nth instance of subtring in main string
*/
@ScalarFunction
- static Integer strpos(String input, String find, Integer instance) {
+ public static int strpos(String input, String find, int instance) {
return StringUtils.ordinalIndexOf(input, find, instance);
}
@@ -167,7 +169,7 @@ public class StringFunctions {
* @return true if string starts with prefix, false o.w.
*/
@ScalarFunction
- static Boolean startsWith(String input, String prefix) {
+ public static boolean startsWith(String input, String prefix) {
return input.startsWith(prefix);
}
@@ -178,7 +180,7 @@ public class StringFunctions {
* @param substitute new substring to be replaced with target
*/
@ScalarFunction
- static String replace(String input, String find, String substitute) {
+ public static String replace(String input, String find, String substitute) {
return input.replaceAll(find, substitute);
}
@@ -190,7 +192,7 @@ public class StringFunctions {
* @return string padded from the right side with pad to reach final size
*/
@ScalarFunction
- static String rpad(String input, Integer size, String pad) {
+ public static String rpad(String input, int size, String pad) {
return StringUtils.rightPad(input, size, pad);
}
@@ -202,7 +204,7 @@ public class StringFunctions {
* @return string padded from the left side with pad to reach final size
*/
@ScalarFunction
- static String lpad(String input, Integer size, String pad) {
+ public static String lpad(String input, int size, String pad) {
return StringUtils.leftPad(input, size, pad);
}
@@ -212,7 +214,7 @@ public class StringFunctions {
* @return the Unicode codepoint of the first character of the string
*/
@ScalarFunction
- static Integer codepoint(String input) {
+ public static int codepoint(String input) {
return input.codePointAt(0);
}
@@ -222,7 +224,7 @@ public class StringFunctions {
* @return the character corresponding to the Unicode codepoint
*/
@ScalarFunction
- static String chr(Integer codepoint) {
+ public static String chr(int codepoint) {
char[] result = Character.toChars(codepoint);
return new String(result);
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/PinotDataType.java
b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java
similarity index 95%
rename from
pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/PinotDataType.java
rename to
pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java
index d1296b7..97c017e 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/PinotDataType.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java
@@ -16,8 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.core.data.recordtransformer;
+package org.apache.pinot.common.utils;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.BytesUtils;
@@ -632,4 +633,33 @@ public enum PinotDataType {
"Unsupported data type: " + dataType + " in field: " +
fieldSpec.getName());
}
}
+
+ public static PinotDataType getPinotDataType(ColumnDataType columnDataType) {
+ switch (columnDataType) {
+ case INT:
+ return INTEGER;
+ case LONG:
+ return LONG;
+ case FLOAT:
+ return FLOAT;
+ case DOUBLE:
+ return DOUBLE;
+ case STRING:
+ return STRING;
+ case BYTES:
+ return BYTES;
+ case INT_ARRAY:
+ return INTEGER_ARRAY;
+ case LONG_ARRAY:
+ return LONG_ARRAY;
+ case FLOAT_ARRAY:
+ return FLOAT_ARRAY;
+ case DOUBLE_ARRAY:
+ return DOUBLE_ARRAY;
+ case STRING_ARRAY:
+ return STRING_ARRAY;
+ default:
+ throw new IllegalStateException("Cannot convert ColumnDataType: " +
columnDataType + " to PinotDataType");
+ }
+ }
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 99e62fb..80ceae2 100644
---
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -849,7 +849,8 @@ public class CalciteSqlParser {
}
try {
FunctionInvoker invoker = new FunctionInvoker(functionInfo);
- Object result = invoker.process(arguments);
+ invoker.convertTypes(arguments);
+ Object result = invoker.invoke(arguments);
return RequestUtils.getLiteralExpression(result);
} catch (Exception e) {
throw new SqlCompilationException(new
IllegalArgumentException("Unsupported function - " + funcName, e));
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/data/recordtransformer/PinotDataTypeTest.java
b/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java
similarity index 98%
rename from
pinot-core/src/test/java/org/apache/pinot/core/data/recordtransformer/PinotDataTypeTest.java
rename to
pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java
index 674bfbd..7f4200f 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/data/recordtransformer/PinotDataTypeTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java
@@ -16,11 +16,11 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.core.data.recordtransformer;
+package org.apache.pinot.common.utils;
import org.testng.annotations.Test;
-import static org.apache.pinot.core.data.recordtransformer.PinotDataType.*;
+import static org.apache.pinot.common.utils.PinotDataType.*;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;
diff --git
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 956531d..25203f2 100644
---
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -23,6 +23,7 @@ import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.TimeUnit;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.pinot.common.function.AggregationFunctionType;
@@ -1600,45 +1601,55 @@ public class CalciteSqlCompilerTest {
@Test
public void testCompileTimeExpression()
throws SqlParseException {
- // True
long lowerBound = System.currentTimeMillis();
Expression expression = CalciteSqlParser.compileToExpression("now()");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getLiteral() != null);
- long nowTs = expression.getLiteral().getLongValue();
+ Assert.assertNotNull(expression.getLiteral());
long upperBound = System.currentTimeMillis();
- Assert.assertTrue(nowTs >= lowerBound);
- Assert.assertTrue(nowTs <= upperBound);
- expression = CalciteSqlParser.compileToExpression("toDateTime(now(),
'yyyy-MM-dd z')");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ long result = expression.getLiteral().getLongValue();
+ Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+ lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
+ expression = CalciteSqlParser.compileToExpression("to_epoch_hours(now() +
3600000)");
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getLiteral() != null);
- String today = expression.getLiteral().getStringValue();
- String expectedTodayStr =
-
Instant.now().atZone(ZoneId.of("UTC")).format(DateTimeFormatter.ofPattern("yyyy-MM-dd
z"));
- Assert.assertEquals(today, expectedTodayStr);
- expression =
CalciteSqlParser.compileToExpression("toDateTime(playerName)");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getLiteral());
+ upperBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
+ result = expression.getLiteral().getLongValue();
+ Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+ expression =
CalciteSqlParser.compileToExpression("toDateTime(millisSinceEpoch)");
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(),
"TODATETIME");
-
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"playerName");
+ Assert
+
.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"millisSinceEpoch");
+
expression = CalciteSqlParser.compileToExpression("reverse(playerName)");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(), "REVERSE");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"playerName");
+
expression = CalciteSqlParser.compileToExpression("reverse('playerName')");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getLiteral() != null);
+ Assert.assertNotNull(expression.getLiteral());
Assert.assertEquals(expression.getLiteral().getFieldValue(), "emaNreyalp");
+
+ expression = CalciteSqlParser.compileToExpression("reverse(123)");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ Assert.assertEquals(expression.getLiteral().getFieldValue(), "321");
+
expression = CalciteSqlParser.compileToExpression("count(*)");
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
expression =
CalciteSqlParser.invokeCompileTimeFunctionExpression(expression);
- Assert.assertTrue(expression.getFunctionCall() != null);
+ Assert.assertNotNull(expression.getFunctionCall());
Assert.assertEquals(expression.getFunctionCall().getOperator(), "COUNT");
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"*");
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/data/function/FunctionEvaluatorFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/data/function/FunctionEvaluatorFactory.java
index 5e9d749..6b3db98 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/data/function/FunctionEvaluatorFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/data/function/FunctionEvaluatorFactory.java
@@ -87,18 +87,11 @@ public class FunctionEvaluatorFactory {
}
public static FunctionEvaluator getExpressionEvaluator(String
transformExpression) {
- FunctionEvaluator functionEvaluator;
- try {
- if
(transformExpression.startsWith(GroovyFunctionEvaluator.getGroovyExpressionPrefix()))
{
- functionEvaluator = new GroovyFunctionEvaluator(transformExpression);
- } else {
- functionEvaluator = new InbuiltFunctionEvaluator(transformExpression);
- }
- } catch (Exception e) {
- throw new IllegalStateException(
- "Could not construct FunctionEvaluator for transformFunction: " +
transformExpression, e);
+ if
(transformExpression.startsWith(GroovyFunctionEvaluator.getGroovyExpressionPrefix()))
{
+ return new GroovyFunctionEvaluator(transformExpression);
+ } else {
+ return new InbuiltFunctionEvaluator(transformExpression);
}
- return functionEvaluator;
}
private static String getDefaultMapKeysTransformExpression(String
mapColumnName) {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluator.java
b/pinot-core/src/main/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluator.java
index 11368c6..bacc800 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluator.java
@@ -20,9 +20,7 @@ package org.apache.pinot.core.data.function;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
-import org.apache.commons.lang3.math.NumberUtils;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
@@ -55,8 +53,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
private final ExecutableNode _rootNode;
private final List<String> _arguments;
- public InbuiltFunctionEvaluator(String functionExpression)
- throws Exception {
+ public InbuiltFunctionEvaluator(String functionExpression) {
_arguments = new ArrayList<>();
ExpressionContext expression =
QueryContextConverterUtils.getExpression(functionExpression);
Preconditions
@@ -65,11 +62,9 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
_rootNode = planExecution(expression.getFunction());
}
- private ExecutableNode planExecution(FunctionContext function)
- throws Exception {
+ private FunctionExecutionNode planExecution(FunctionContext function) {
List<ExpressionContext> arguments = function.getArguments();
int numArguments = arguments.size();
- Class<?>[] argumentTypes = new Class<?>[numArguments];
ExecutableNode[] childNodes = new ExecutableNode[numArguments];
for (int i = 0; i < numArguments; i++) {
ExpressionContext argument = arguments.get(i);
@@ -90,13 +85,9 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
throw new IllegalStateException();
}
childNodes[i] = childNode;
- argumentTypes[i] = childNode.getReturnType();
}
- String functionName = function.getFunctionName();
- FunctionInfo functionInfo =
FunctionRegistry.getFunctionByName(functionName);
- Preconditions.checkState(functionInfo != null &&
functionInfo.isApplicable(argumentTypes),
- "Failed to find function of name: %s with argument types: %s",
functionName, Arrays.toString(argumentTypes));
+ FunctionInfo functionInfo =
FunctionRegistry.getFunctionByName(function.getFunctionName());
return new FunctionExecutionNode(functionInfo, childNodes);
}
@@ -105,6 +96,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
return _arguments;
}
+ @Override
public Object evaluate(GenericRow row) {
return _rootNode.execute(row);
}
@@ -112,63 +104,47 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
private interface ExecutableNode {
Object execute(GenericRow row);
-
- Class<?> getReturnType();
}
private static class FunctionExecutionNode implements ExecutableNode {
- FunctionInvoker _functionInvoker;
- ExecutableNode[] _argumentProviders;
- Object[] _argInputs;
+ final FunctionInvoker _functionInvoker;
+ final ExecutableNode[] _argumentNodes;
+ final Object[] _arguments;
- public FunctionExecutionNode(FunctionInfo functionInfo, ExecutableNode[]
argumentProviders)
- throws Exception {
+ FunctionExecutionNode(FunctionInfo functionInfo, ExecutableNode[]
argumentNodes) {
_functionInvoker = new FunctionInvoker(functionInfo);
- _argumentProviders = argumentProviders;
- _argInputs = new Object[_argumentProviders.length];
+ _argumentNodes = argumentNodes;
+ _arguments = new Object[_argumentNodes.length];
}
+ @Override
public Object execute(GenericRow row) {
- for (int i = 0; i < _argumentProviders.length; i++) {
- _argInputs[i] = _argumentProviders[i].execute(row);
+ int numArguments = _argumentNodes.length;
+ for (int i = 0; i < numArguments; i++) {
+ _arguments[i] = _argumentNodes[i].execute(row);
}
- return _functionInvoker.process(_argInputs);
- }
-
- public Class<?> getReturnType() {
- return _functionInvoker.getReturnType();
+ _functionInvoker.convertTypes(_arguments);
+ return _functionInvoker.invoke(_arguments);
}
}
private static class ConstantExecutionNode implements ExecutableNode {
- private Object _value;
- private Class<?> _returnType;
-
- public ConstantExecutionNode(String value) {
- if (NumberUtils.isCreatable(value)) {
- _value = NumberUtils.createNumber(value);
- _returnType = Number.class;
- } else {
- _value = value;
- _returnType = String.class;
- }
- }
+ final String _value;
- @Override
- public Object execute(GenericRow row) {
- return _value;
+ ConstantExecutionNode(String value) {
+ _value = value;
}
@Override
- public Class<?> getReturnType() {
- return _returnType;
+ public String execute(GenericRow row) {
+ return _value;
}
}
private static class ColumnExecutionNode implements ExecutableNode {
- private String _column;
+ final String _column;
- public ColumnExecutionNode(String column) {
+ ColumnExecutionNode(String column) {
_column = column;
}
@@ -176,10 +152,5 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
public Object execute(GenericRow row) {
return row.getValue(_column);
}
-
- @Override
- public Class<?> getReturnType() {
- return Object.class;
- }
}
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/DataTypeTransformer.java
b/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/DataTypeTransformer.java
index ccdda33..4ab665c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/DataTypeTransformer.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/data/recordtransformer/DataTypeTransformer.java
@@ -27,6 +27,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.data.readers.GenericRow;
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java
index 8b2c77d..958b570 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java
@@ -19,43 +19,51 @@
package org.apache.pinot.core.operator.transform.function;
import com.google.common.base.Preconditions;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
+import org.apache.pinot.common.function.FunctionUtils;
+import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.core.common.DataSource;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.plan.DocIdSetPlanNode;
-import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+/**
+ * Wrapper transform function on the annotated scalar function.
+ */
public class ScalarTransformFunctionWrapper extends BaseTransformFunction {
+ private final String _name;
+ private final FunctionInvoker _functionInvoker;
+
+ private Object[] _arguments;
+ private int _numNonLiteralArguments;
+ private int[] _nonLiteralIndices;
+ private TransformFunction[] _nonLiteralFunctions;
+ private Object[][] _nonLiteralValues;
+ private TransformResultMetadata _resultMetadata;
- private FunctionInvoker _functionInvoker;
- private String _name;
- private Object[] _args;
- private List<Integer> _nonLiteralArgIndices;
- private List<FieldSpec.DataType> _nonLiteralArgType;
- private List<TransformFunction> _nonLiteralTransformFunction;
- private TransformResultMetadata _transformResultMetadata;
- private String[] _stringResult;
- private int[] _integerResult;
- private float[] _floatResult;
- private double[] _doubleResult;
- private long[] _longResult;
+ private int[] _intResults;
+ private float[] _floatResults;
+ private double[] _doubleResults;
+ private long[] _longResults;
+ private String[] _stringResults;
+ private byte[][] _bytesResults;
- public ScalarTransformFunctionWrapper(String functionName, FunctionInfo info)
- throws Exception {
- _nonLiteralArgIndices = new ArrayList<>();
- _nonLiteralArgType = new ArrayList<>();
- _nonLiteralTransformFunction = new ArrayList<>();
- _name = functionName;
- _functionInvoker = new FunctionInvoker(info);
+ public ScalarTransformFunctionWrapper(FunctionInfo functionInfo) {
+ _name = functionInfo.getMethod().getName();
+ _functionInvoker = new FunctionInvoker(functionInfo);
+ Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ int numParameters = parameterClasses.length;
+ for (int i = 0; i < numParameters; i++) {
+ Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported
parameter class: %s for method: %s",
+ parameterClasses[i], functionInfo.getMethod());
+ }
}
@Override
@@ -65,235 +73,184 @@ public class ScalarTransformFunctionWrapper extends
BaseTransformFunction {
@Override
public void init(List<TransformFunction> arguments, Map<String, DataSource>
dataSourceMap) {
- Integer numArguments = arguments.size();
- Preconditions.checkArgument(numArguments ==
_functionInvoker.getParameterTypes().length,
- "The number of arguments are not same for scalar function and
transform function: %s", getName());
+ int numArguments = arguments.size();
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ Preconditions.checkArgument(numArguments == parameterTypes.length,
+ "Wrong number of arguments for method: %s, expected: %s, actual: %s",
_functionInvoker.getMethod(),
+ parameterTypes.length, numArguments);
- _args = new Object[numArguments];
+ _arguments = new Object[numArguments];
+ _nonLiteralIndices = new int[numArguments];
+ _nonLiteralFunctions = new TransformFunction[numArguments];
for (int i = 0; i < numArguments; i++) {
- TransformFunction function = arguments.get(i);
- if (function instanceof LiteralTransformFunction) {
- String literal = ((LiteralTransformFunction) function).getLiteral();
- Class paramType = _functionInvoker.getParameterTypes()[i];
- switch (paramType.getTypeName()) {
- case "java.lang.Integer":
- _args[i] = Integer.parseInt(literal);
- break;
- case "java.lang.Long":
- _args[i] = Long.valueOf(literal);
- break;
- case "java.lang.Float":
- _args[i] = Float.valueOf(literal);
- break;
- case "java.lang.Double":
- _args[i] = Double.valueOf(literal);
- break;
- case "java.lang.String":
- _args[i] = literal;
- break;
- default:
- throw new RuntimeException(
- "Unsupported data type " + paramType.getTypeName() + "for
transform function " + getName());
- }
+ TransformFunction transformFunction = arguments.get(i);
+ if (transformFunction instanceof LiteralTransformFunction) {
+ String literal = ((LiteralTransformFunction)
transformFunction).getLiteral();
+ _arguments[i] = parameterTypes[i].convert(literal,
PinotDataType.STRING);
} else {
- _nonLiteralArgIndices.add(i);
- _nonLiteralTransformFunction.add(function);
- Class paramType = _functionInvoker.getParameterTypes()[i];
-
- switch (paramType.getTypeName()) {
- case "java.lang.Integer":
- _nonLiteralArgType.add(FieldSpec.DataType.INT);
- break;
- case "java.lang.Long":
- _nonLiteralArgType.add(FieldSpec.DataType.LONG);
- break;
- case "java.lang.Float":
- _nonLiteralArgType.add(FieldSpec.DataType.FLOAT);
- break;
- case "java.lang.Double":
- _nonLiteralArgType.add(FieldSpec.DataType.DOUBLE);
- break;
- case "java.lang.String":
- _nonLiteralArgType.add(FieldSpec.DataType.STRING);
- break;
- default:
- throw new RuntimeException(
- "Unsupported data type " + paramType.getTypeName() + "for
transform function " + getName());
- }
+ _nonLiteralIndices[_numNonLiteralArguments] = i;
+ _nonLiteralFunctions[_numNonLiteralArguments] = transformFunction;
+ _numNonLiteralArguments++;
}
}
+ _nonLiteralValues = new Object[_numNonLiteralArguments][];
- Class returnType = _functionInvoker.getReturnType();
- switch (returnType.getTypeName()) {
- case "java.lang.Integer":
- _transformResultMetadata = INT_SV_NO_DICTIONARY_METADATA;
- break;
- case "java.lang.Long":
- _transformResultMetadata = LONG_SV_NO_DICTIONARY_METADATA;
- break;
- case "java.lang.Float":
- case "java.lang.Double":
- _transformResultMetadata = DOUBLE_SV_NO_DICTIONARY_METADATA;
- break;
- case "java.lang.Boolean":
- case "java.lang.String":
- _transformResultMetadata = STRING_SV_NO_DICTIONARY_METADATA;
- break;
- default:
- throw new RuntimeException(
- "Unsupported data type " + returnType.getTypeName() + "for
transform function " + getName());
+ DataType resultDataType =
FunctionUtils.getDataType(_functionInvoker.getResultClass());
+ // Handle unrecognized result class with STRING
+ if (resultDataType == null) {
+ resultDataType = DataType.STRING;
}
+ _resultMetadata = new TransformResultMetadata(resultDataType, true, false);
}
@Override
public TransformResultMetadata getResultMetadata() {
- return _transformResultMetadata;
+ return _resultMetadata;
}
- @SuppressWarnings("Duplicates")
@Override
public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
- if (_integerResult == null) {
- _integerResult = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ if (_resultMetadata.getDataType() != DataType.INT) {
+ return super.transformToIntValuesSV(projectionBlock);
}
+ if (_intResults == null) {
+ _intResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ }
+ getNonLiteralValues(projectionBlock);
int length = projectionBlock.getNumDocs();
- int numNonLiteralArgs = _nonLiteralArgIndices.size();
- Object[][] nonLiteralBlockValues = new Object[numNonLiteralArgs][];
-
- transformNonLiteralArgsToValues(projectionBlock, numNonLiteralArgs,
nonLiteralBlockValues);
-
- //now invoke the actual function
for (int i = 0; i < length; i++) {
- for (int k = 0; k < numNonLiteralArgs; k++) {
- _args[_nonLiteralArgIndices.get(k)] = nonLiteralBlockValues[k][i];
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
}
- _integerResult[i] = (Integer) _functionInvoker.process(_args);
+ _intResults[i] = (int) _functionInvoker.invoke(_arguments);
}
- return _integerResult;
+ return _intResults;
}
- @SuppressWarnings("Duplicates")
@Override
public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
- if (_longResult == null) {
- _longResult = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ if (_resultMetadata.getDataType() != DataType.LONG) {
+ return super.transformToLongValuesSV(projectionBlock);
+ }
+ if (_longResults == null) {
+ _longResults = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
}
+ getNonLiteralValues(projectionBlock);
int length = projectionBlock.getNumDocs();
- int numNonLiteralArgs = _nonLiteralArgIndices.size();
- Object[][] nonLiteralBlockValues = new Object[numNonLiteralArgs][];
-
- transformNonLiteralArgsToValues(projectionBlock, numNonLiteralArgs,
nonLiteralBlockValues);
-
- //now invoke the actual function
for (int i = 0; i < length; i++) {
- for (int k = 0; k < numNonLiteralArgs; k++) {
- _args[_nonLiteralArgIndices.get(k)] = nonLiteralBlockValues[k][i];
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
}
- _longResult[i] = (Long) _functionInvoker.process(_args);
+ _longResults[i] = (long) _functionInvoker.invoke(_arguments);
}
- return _longResult;
+ return _longResults;
}
- @SuppressWarnings("Duplicates")
+ /**
+ * Evaluates the scalar function row by row when its declared result type is FLOAT. For any other result type the
+ * call is delegated to the superclass implementation.
+ */
@Override
public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
- if (_floatResult == null) {
- _floatResult = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ if (_resultMetadata.getDataType() != DataType.FLOAT) {
+ return super.transformToFloatValuesSV(projectionBlock);
+ }
+ if (_floatResults == null) {
+ _floatResults = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
}
+ // Load every non-literal argument column for this block; per-row values are then copied into _arguments.
+ getNonLiteralValues(projectionBlock);
int length = projectionBlock.getNumDocs();
- int numNonLiteralArgs = _nonLiteralArgIndices.size();
- Object[][] nonLiteralBlockValues = new Object[numNonLiteralArgs][];
-
- transformNonLiteralArgsToValues(projectionBlock, numNonLiteralArgs, nonLiteralBlockValues);
-
- //now invoke the actual function
for (int i = 0; i < length; i++) {
- for (int k = 0; k < numNonLiteralArgs; k++) {
- _args[_nonLiteralArgIndices.get(k)] = nonLiteralBlockValues[k][i];
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
}
- _floatResult[i] = (Float) _functionInvoker.process(_args);
+ _floatResults[i] = (float) _functionInvoker.invoke(_arguments);
}
- return _floatResult;
+ return _floatResults;
}
- @SuppressWarnings("Duplicates")
+ /**
+ * Evaluates the scalar function row by row when its declared result type is DOUBLE. For any other result type the
+ * call is delegated to the superclass implementation.
+ */
@Override
public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
- if (_doubleResult == null) {
- _doubleResult = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ if (_resultMetadata.getDataType() != DataType.DOUBLE) {
+ return super.transformToDoubleValuesSV(projectionBlock);
}
+ if (_doubleResults == null) {
+ _doubleResults = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ }
+ // Load every non-literal argument column for this block; per-row values are then copied into _arguments.
+ getNonLiteralValues(projectionBlock);
int length = projectionBlock.getNumDocs();
- int numNonLiteralArgs = _nonLiteralArgIndices.size();
- Object[][] nonLiteralBlockValues = new Object[numNonLiteralArgs][];
-
- transformNonLiteralArgsToValues(projectionBlock, numNonLiteralArgs, nonLiteralBlockValues);
-
- //now invoke the actual function
for (int i = 0; i < length; i++) {
- for (int k = 0; k < numNonLiteralArgs; k++) {
- _args[_nonLiteralArgIndices.get(k)] = nonLiteralBlockValues[k][i];
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
}
- _doubleResult[i] = (Double) _functionInvoker.process(_args);
+ _doubleResults[i] = (double) _functionInvoker.invoke(_arguments);
}
- return _doubleResult;
+ return _doubleResults;
}
- @SuppressWarnings("Duplicates")
+ /**
+ * Evaluates the scalar function row by row when its declared result type is STRING, storing each result's
+ * toString(). For any other result type the call is delegated to the superclass implementation.
+ */
@Override
public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
- if (_stringResult == null) {
- _stringResult = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ if (_resultMetadata.getDataType() != DataType.STRING) {
+ return super.transformToStringValuesSV(projectionBlock);
}
-
+ if (_stringResults == null) {
+ _stringResults = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ }
+ getNonLiteralValues(projectionBlock);
int length = projectionBlock.getNumDocs();
- int numNonLiteralArgs = _nonLiteralArgIndices.size();
- Object[][] nonLiteralBlockValues = new Object[numNonLiteralArgs][];
-
- transformNonLiteralArgsToValues(projectionBlock, numNonLiteralArgs, nonLiteralBlockValues);
-
- //now invoke the actual function
for (int i = 0; i < length; i++) {
- for (int k = 0; k < numNonLiteralArgs; k++) {
- _args[_nonLiteralArgIndices.get(k)] = nonLiteralBlockValues[k][i];
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
}
- _stringResult[i] = String.valueOf(_functionInvoker.process(_args));
+ // NOTE(review): unlike the removed String.valueOf(...), toString() throws NullPointerException if the function
+ // returns null - confirm that null results cannot occur here.
+ _stringResults[i] = _functionInvoker.invoke(_arguments).toString();
}
+ return _stringResults;
+ }
- return _stringResult;
+ /**
+ * Evaluates the scalar function row by row when its declared result type is BYTES. For any other result type the
+ * call is delegated to the superclass implementation.
+ */
+ @Override
+ public byte[][] transformToBytesValuesSV(ProjectionBlock projectionBlock) {
+ if (_resultMetadata.getDataType() != DataType.BYTES) {
+ return super.transformToBytesValuesSV(projectionBlock);
+ }
+ if (_bytesResults == null) {
+ _bytesResults = new byte[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ getNonLiteralValues(projectionBlock);
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numNonLiteralArguments; j++) {
+ _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i];
+ }
+ _bytesResults[i] = (byte[]) _functionInvoker.invoke(_arguments);
+ }
+ return _bytesResults;
}
- private void transformNonLiteralArgsToValues(ProjectionBlock projectionBlock, int numNonLiteralArgs,
- Object[][] nonLiteralBlockValues) {
- for (int i = 0; i < numNonLiteralArgs; i++) {
- TransformFunction transformFunc = _nonLiteralTransformFunction.get(i);
- FieldSpec.DataType returnType = _nonLiteralArgType.get(i);
- switch (returnType) {
- case STRING:
- nonLiteralBlockValues[i] = transformFunc.transformToStringValuesSV(projectionBlock);
+ /**
+ * Helper method to fetch values for the non-literal transform functions based on the parameter types.
+ * <p>For each non-literal argument, the block values are read from its transform function using the type of the
+ * corresponding function parameter (primitive arrays are boxed) and stored in {@code _nonLiteralValues[i]}.
+ *
+ * @throws IllegalStateException if a parameter type has no matching transform accessor
+ */
+ private void getNonLiteralValues(ProjectionBlock projectionBlock) {
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ for (int i = 0; i < _numNonLiteralArguments; i++) {
+ int index = _nonLiteralIndices[i];
+ TransformFunction transformFunction = _nonLiteralFunctions[i];
+ switch (parameterTypes[index]) {
+ case INTEGER:
+ _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToIntValuesSV(projectionBlock));
break;
- case INT:
- int[] values = transformFunc.transformToIntValuesSV(projectionBlock);
- nonLiteralBlockValues[i] = Arrays.stream(values).boxed().toArray(Integer[]::new);
+ case LONG:
+ _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToLongValuesSV(projectionBlock));
+ break;
+ case FLOAT:
+ _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToFloatValuesSV(projectionBlock));
break;
case DOUBLE:
- double[] doubleValues = transformFunc.transformToDoubleValuesSV(projectionBlock);
- nonLiteralBlockValues[i] = Arrays.stream(doubleValues).boxed().toArray(Double[]::new);
+ _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToDoubleValuesSV(projectionBlock));
break;
- case FLOAT:
- float[] floatValues = transformFunc.transformToFloatValuesSV(projectionBlock);
- Float[] floatObjectValues = new Float[floatValues.length];
- for (int j = 0; j < floatValues.length; j++) {
- floatObjectValues[j] = floatValues[j];
- }
- nonLiteralBlockValues[i] = floatObjectValues;
+ case STRING:
+ _nonLiteralValues[i] = transformFunction.transformToStringValuesSV(projectionBlock);
break;
- case LONG:
- long[] longValues = transformFunc.transformToLongValuesSV(projectionBlock);
- nonLiteralBlockValues[i] = Arrays.stream(longValues).boxed().toArray(Long[]::new);
+ case BYTES:
+ _nonLiteralValues[i] = transformFunction.transformToBytesValuesSV(projectionBlock);
break;
default:
- throw new RuntimeException(
- "Unsupported return data type " + returnType + "for transform function " + getName());
+ // Keep the diagnostic context the previous implementation reported
+ throw new IllegalStateException(
+ "Unsupported parameter type: " + parameterTypes[index] + " for transform function: " + getName());
}
}
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index a4f9548..e307a4b 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -27,15 +27,15 @@ import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.core.common.DataSource;
-import org.apache.pinot.core.geospatial.transform.function.StContainsFunction;
-import org.apache.pinot.core.geospatial.transform.function.StDistanceFunction;
import org.apache.pinot.core.geospatial.transform.function.StAreaFunction;
import org.apache.pinot.core.geospatial.transform.function.StAsBinaryFunction;
+import org.apache.pinot.core.geospatial.transform.function.StAsTextFunction;
+import org.apache.pinot.core.geospatial.transform.function.StContainsFunction;
+import org.apache.pinot.core.geospatial.transform.function.StDistanceFunction;
import org.apache.pinot.core.geospatial.transform.function.StEqualsFunction;
import
org.apache.pinot.core.geospatial.transform.function.StGeogFromTextFunction;
import
org.apache.pinot.core.geospatial.transform.function.StGeogFromWKBFunction;
import
org.apache.pinot.core.geospatial.transform.function.StGeomFromTextFunction;
-import org.apache.pinot.core.geospatial.transform.function.StAsTextFunction;
import
org.apache.pinot.core.geospatial.transform.function.StGeomFromWKBFunction;
import
org.apache.pinot.core.geospatial.transform.function.StGeometryTypeFunction;
import org.apache.pinot.core.geospatial.transform.function.StPointFunction;
@@ -176,12 +176,7 @@ public class TransformFunctionFactory {
if (functionInfo == null) {
throw new BadQueryRequestException("Unsupported transform
function: " + functionName);
}
- try {
- transformFunction = new
ScalarTransformFunctionWrapper(functionName, functionInfo);
- } catch (Exception e) {
- throw new RuntimeException("Caught exception while constructing
scalar transform function: " + functionName,
- e);
- }
+ transformFunction = new ScalarTransformFunctionWrapper(functionInfo);
}
List<ExpressionContext> arguments = function.getArguments();
List<TransformFunction> transformFunctionArguments = new
ArrayList<>(arguments.size());
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
new file mode 100644
index 0000000..9e6d46a
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.postaggregation;
+
+import com.google.common.base.Preconditions;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.function.FunctionInvoker;
+import org.apache.pinot.common.function.FunctionRegistry;
+import org.apache.pinot.common.function.FunctionUtils;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.common.utils.PinotDataType;
+
+
+/**
+ * Post-aggregation function on the annotated scalar function.
+ * <p>Looks up a scalar function in {@link FunctionRegistry} by name and invokes it on aggregation results, converting
+ * each argument from its declared {@link ColumnDataType} to the type of the matching function parameter.
+ */
+public class PostAggregationFunction {
+ private final FunctionInvoker _functionInvoker;
+ // One entry per argument, derived from the ColumnDataTypes passed to the constructor
+ private final PinotDataType[] _argumentTypes;
+ private final ColumnDataType _resultType;
+
+ /**
+ * @param functionName name of the scalar function to look up in {@link FunctionRegistry}
+ * @param argumentTypes data types of the arguments that will be passed to {@link #invoke(Object[])}
+ * @throws IllegalArgumentException if the function is not registered, or the number of argument types does not
+ * match the number of function parameters
+ */
+ public PostAggregationFunction(String functionName, ColumnDataType[] argumentTypes) {
+ FunctionInfo functionInfo = FunctionRegistry.getFunctionByName(functionName);
+ Preconditions.checkArgument(functionInfo != null, "Unsupported function: %s", functionName);
+ _functionInvoker = new FunctionInvoker(functionInfo);
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ int numArguments = argumentTypes.length;
+ Preconditions.checkArgument(numArguments == parameterTypes.length,
+ "Wrong number of arguments for method: %s, expected: %s, actual: %s", functionInfo.getMethod(),
+ parameterTypes.length, numArguments);
+ _argumentTypes = new PinotDataType[numArguments];
+ for (int i = 0; i < numArguments; i++) {
+ _argumentTypes[i] = PinotDataType.getPinotDataType(argumentTypes[i]);
+ }
+ ColumnDataType resultType = FunctionUtils.getColumnDataType(_functionInvoker.getResultClass());
+ // Handle unrecognized result class with STRING
+ _resultType = resultType != null ? resultType : ColumnDataType.STRING;
+ }
+
+ /**
+ * Returns the ColumnDataType of the result.
+ */
+ public ColumnDataType getResultType() {
+ return _resultType;
+ }
+
+ /**
+ * Invoke the function with the given arguments.
+ * NOTE: The passed in arguments could be modified during the type conversion.
+ *
+ * @param arguments argument values, one per argument type given to the constructor
+ * @return the function result, converted to its String form when the result type is STRING (null is returned as is)
+ * @throws IllegalArgumentException if the number of arguments does not match the number of argument types
+ */
+ public Object invoke(Object[] arguments) {
+ int numArguments = _argumentTypes.length;
+ // Validate up front instead of failing later with an index error on the parameter/argument type arrays
+ Preconditions.checkArgument(arguments.length == numArguments, "Wrong number of arguments, expected: %s, actual: %s",
+ numArguments, arguments.length);
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ for (int i = 0; i < numArguments; i++) {
+ PinotDataType parameterType = parameterTypes[i];
+ PinotDataType argumentType = _argumentTypes[i];
+ if (parameterType != argumentType) {
+ arguments[i] = parameterType.convert(arguments[i], argumentType);
+ }
+ }
+ Object result = _functionInvoker.invoke(arguments);
+ // Unrecognized result classes are reported as STRING, so return their String form; avoid NPE on null results
+ return _resultType == ColumnDataType.STRING && result != null ? result.toString() : result;
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java
index dabef53..df122c8 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java
@@ -35,41 +35,45 @@ import org.testng.annotations.Test;
*/
public class InbuiltFunctionsTest {
- @Test(dataProvider = "dateTimeFunctionsTestDataProvider")
- public void testDateTimeTransformFunctions(String transformFunction, List<String> arguments, GenericRow row,
- Object result)
- throws Exception {
- InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(transformFunction);
- Assert.assertEquals(evaluator.getArguments(), arguments);
- Assert.assertEquals(evaluator.evaluate(row), result);
+ /**
+ * Builds an InbuiltFunctionEvaluator for the expression, then asserts the argument names it reports and the value it
+ * produces for the given row.
+ */
+ private void testFunction(String functionExpression, List<String> expectedArguments, GenericRow row,
+ Object expectedResult) {
+ InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(functionExpression);
+ Assert.assertEquals(evaluator.getArguments(), expectedArguments);
+ Assert.assertEquals(evaluator.evaluate(row), expectedResult);
}
- @DataProvider(name = "dateTimeFunctionsTestDataProvider")
+ @Test(dataProvider = "dateTimeFunctionsDataProvider")
+ public void testDateTimeFunctions(String functionExpression, List<String> expectedArguments, GenericRow row,
+ Object expectedResult) {
+ testFunction(functionExpression, expectedArguments, row, expectedResult);
+ }
+
+ @DataProvider(name = "dateTimeFunctionsDataProvider")
public Object[][] dateTimeFunctionsDataProvider() {
List<Object[]> inputs = new ArrayList<>();
+
// round epoch millis to nearest 15 minutes
GenericRow row0_0 = new GenericRow();
row0_0.putValue("timestamp", 1578685189000L);
// round to 15 minutes, but keep in milliseconds: Fri Jan 10 2020 19:39:49
becomes Fri Jan 10 2020 19:30:00
- inputs.add(new Object[]{"round(timestamp, 900000)", Lists.newArrayList(
- "timestamp"), row0_0, 1578684600000L});
+ inputs.add(new Object[]{"round(timestamp, 900000)",
Lists.newArrayList("timestamp"), row0_0, 1578684600000L});
- // toEpochSeconds
+ // toEpochSeconds (with type conversion)
GenericRow row1_0 = new GenericRow();
- row1_0.putValue("timestamp", 1578685189000L);
+ row1_0.putValue("timestamp", 1578685189000.0);
inputs.add(new Object[]{"toEpochSeconds(timestamp)",
Lists.newArrayList("timestamp"), row1_0, 1578685189L});
- // toEpochSeconds w/ rounding
+ // toEpochSeconds w/ rounding (with type conversion)
GenericRow row1_1 = new GenericRow();
- row1_1.putValue("timestamp", 1578685189000L);
+ row1_1.putValue("timestamp", "1578685189000");
inputs.add(
new Object[]{"toEpochSecondsRounded(timestamp, 10)",
Lists.newArrayList("timestamp"), row1_1, 1578685180L});
- // toEpochSeconds w/ bucketing
+ // toEpochSeconds w/ bucketing (with underscore in function name)
GenericRow row1_2 = new GenericRow();
row1_2.putValue("timestamp", 1578685189000L);
- inputs
- .add(new Object[]{"toEpochSecondsBucket(timestamp, 10)",
Lists.newArrayList("timestamp"), row1_2, 157868518L});
+ inputs.add(
+ new Object[]{"to_epoch_seconds_bucket(timestamp, 10)",
Lists.newArrayList("timestamp"), row1_2, 157868518L});
// toEpochMinutes
GenericRow row2_0 = new GenericRow();
@@ -215,15 +219,13 @@ public class InbuiltFunctionsTest {
return inputs.toArray(new Object[0][]);
}
- @Test(dataProvider = "jsonFunctionDataProvider")
- public void testJsonFunctions(String transformFunction, List<String> arguments, GenericRow row, Object result)
- throws Exception {
- InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(transformFunction);
- Assert.assertEquals(evaluator.getArguments(), arguments);
- Assert.assertEquals(evaluator.evaluate(row), result);
+ /**
+ * Runs each JSON function case from jsonFunctionsDataProvider through testFunction.
+ */
+ @Test(dataProvider = "jsonFunctionsDataProvider")
+ public void testJsonFunctions(String functionExpression, List<String> expectedArguments, GenericRow row,
+ Object expectedResult) {
+ testFunction(functionExpression, expectedArguments, row, expectedResult);
}
- @DataProvider(name = "jsonFunctionDataProvider")
+ @DataProvider(name = "jsonFunctionsDataProvider")
public Object[][] jsonFunctionsDataProvider()
throws IOException {
List<Object[]> inputs = new ArrayList<>();
@@ -262,4 +264,37 @@ public class InbuiltFunctionsTest {
return inputs.toArray(new Object[0][]);
}
+
+ /**
+ * Runs each arithmetic function case from arithmeticFunctionsDataProvider through testFunction.
+ */
+ @Test(dataProvider = "arithmeticFunctionsDataProvider")
+ public void testArithmeticFunctions(String functionExpression, List<String> expectedArguments, GenericRow row,
+ Object expectedResult) {
+ testFunction(functionExpression, expectedArguments, row, expectedResult);
+ }
+
+ /**
+ * Arithmetic cases whose column values use a different Java type per row, so that argument type conversion is
+ * exercised; all expected results are doubles.
+ */
+ @DataProvider(name = "arithmeticFunctionsDataProvider")
+ public Object[][] arithmeticFunctionsDataProvider() {
+ List<Object[]> inputs = new ArrayList<>();
+
+ // byte and char inputs
+ GenericRow row0 = new GenericRow();
+ row0.putValue("a", (byte) 1);
+ row0.putValue("b", (char) 2);
+ inputs.add(new Object[]{"plus(a, b)", Lists.newArrayList("a", "b"), row0, 3.0});
+
+ // short and int inputs
+ GenericRow row1 = new GenericRow();
+ row1.putValue("a", (short) 3);
+ row1.putValue("b", 4);
+ inputs.add(new Object[]{"minus(a, b)", Lists.newArrayList("a", "b"), row1, -1.0});
+
+ // long and float inputs
+ GenericRow row2 = new GenericRow();
+ row2.putValue("a", 5L);
+ row2.putValue("b", 6f);
+ inputs.add(new Object[]{"times(a, b)", Lists.newArrayList("a", "b"), row2, 30.0});
+
+ // double and numeric String inputs
+ GenericRow row3 = new GenericRow();
+ row3.putValue("a", 7.0);
+ row3.putValue("b", "8");
+ inputs.add(new Object[]{"divide(a, b)", Lists.newArrayList("a", "b"), row3, 0.875});
+
+ return inputs.toArray(new Object[0][]);
+ }
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
index 911f107..394d285 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
@@ -44,7 +44,7 @@ public class ScalarTransformFunctionWrapperTest extends
BaseTransformFunctionTes
@Test
public void testStringUpperTransformFunction() {
ExpressionContext expression =
- QueryContextConverterUtils.getExpression(String.format("upper(%s)",
STRING_ALPHANUM_SV_COLUMN));
+ QueryContextConverterUtils.getExpression(String.format("UPPER(%s)",
STRING_ALPHANUM_SV_COLUMN));
TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof
ScalarTransformFunctionWrapper);
Assert.assertEquals(transformFunction.getName(), "upper");
@@ -58,7 +58,7 @@ public class ScalarTransformFunctionWrapperTest extends
BaseTransformFunctionTes
@Test
public void testStringReverseTransformFunction() {
ExpressionContext expression =
- QueryContextConverterUtils.getExpression(String.format("reverse(%s)",
STRING_ALPHANUM_SV_COLUMN));
+ QueryContextConverterUtils.getExpression(String.format("rEvErSe(%s)",
STRING_ALPHANUM_SV_COLUMN));
TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof
ScalarTransformFunctionWrapper);
Assert.assertEquals(transformFunction.getName(), "reverse");
@@ -72,7 +72,7 @@ public class ScalarTransformFunctionWrapperTest extends
BaseTransformFunctionTes
@Test
public void testStringSubStrTransformFunction() {
ExpressionContext expression =
- QueryContextConverterUtils.getExpression(String.format("substr(%s, 0,
2)", STRING_ALPHANUM_SV_COLUMN));
+ QueryContextConverterUtils.getExpression(String.format("sub_str(%s, 0,
2)", STRING_ALPHANUM_SV_COLUMN));
TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof
ScalarTransformFunctionWrapper);
Assert.assertEquals(transformFunction.getName(), "substr");
@@ -83,7 +83,7 @@ public class ScalarTransformFunctionWrapperTest extends
BaseTransformFunctionTes
testTransformFunction(transformFunction, expectedValues);
expression =
- QueryContextConverterUtils.getExpression(String.format("substr(%s, 2,
-1)", STRING_ALPHANUM_SV_COLUMN));
+ QueryContextConverterUtils.getExpression(String.format("substr(%s,
'2', '-1')", STRING_ALPHANUM_SV_COLUMN));
transformFunction = TransformFunctionFactory.get(expression,
_dataSourceMap);
Assert.assertTrue(transformFunction instanceof
ScalarTransformFunctionWrapper);
Assert.assertEquals(transformFunction.getName(), "substr");
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
new file mode 100644
index 0000000..1c0dedd
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.postaggregation;
+
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.geospatial.GeometryUtils;
+import org.apache.pinot.core.geospatial.serde.GeometrySerializer;
+import org.locationtech.jts.geom.Coordinate;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+/**
+ * Tests for {@link PostAggregationFunction}: function lookup by name, argument type conversion and result type.
+ */
+public class PostAggregationFunctionTest {
+
+ @Test
+ public void testPostAggregationFunction() {
+ // Plus
+ PostAggregationFunction function =
+ new PostAggregationFunction("plus", new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG});
+ assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
+ assertEquals(function.invoke(new Object[]{1, 2L}), 3.0);
+
+ // Minus
+ function = new PostAggregationFunction("MINUS", new ColumnDataType[]{ColumnDataType.FLOAT, ColumnDataType.DOUBLE});
+ assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
+ assertEquals(function.invoke(new Object[]{3f, 4.0}), -1.0);
+
+ // Times
+ function = new PostAggregationFunction("tImEs", new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT});
+ assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
+ assertEquals(function.invoke(new Object[]{"5", 6}), 30.0);
+
+ // Reverse: the argument is declared as LONG, so pass a Long value to actually exercise the LONG conversion
+ function = new PostAggregationFunction("reverse", new ColumnDataType[]{ColumnDataType.LONG});
+ assertEquals(function.getResultType(), ColumnDataType.STRING);
+ assertEquals(function.invoke(new Object[]{1234567890L}), "0987654321");
+
+ // ST_AsText
+ function = new PostAggregationFunction("ST_AsText", new ColumnDataType[]{ColumnDataType.BYTES});
+ assertEquals(function.getResultType(), ColumnDataType.STRING);
+ assertEquals(function.invoke(
+ new Object[]{GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(10, 20)))}),
+ "POINT (10 20)");
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]