This is an automated email from the ASF dual-hosted git repository. chenyz pushed a commit to branch udaf in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 2bb60b23bccdb4b464f8f18034da656913be1ca2 Author: Chen YZ <[email protected]> AuthorDate: Thu Dec 5 20:59:53 2024 +0800 save --- .../apache/iotdb/udf/AggregateFunctionExample.java | 129 +++++++++++++++++++ .../config/AggregateFunctionConfig.java} | 18 ++- .../udf/api/relational/AggregateFunction.java | 93 +++++++++++++- .../relational/aggregation/AccumulatorFactory.java | 37 +++++- .../UserDefinedAggregateFunctionAccumulator.java | 108 ++++++++++++++++ .../GroupedUserDefinedAggregateAccumulator.java | 137 +++++++++++++++++++++ .../aggregation/grouped/array/StateBigArray.java | 61 +++++++++ .../relational/analyzer/ExpressionTreeUtils.java | 5 +- .../relational/metadata/TableMetadataImpl.java | 25 +++- .../relational/planner/optimizations/Util.java | 2 - .../TableBuiltinAggregationFunction.java | 4 +- .../commons/udf/utils/UDFDataTypeTransformer.java | 43 ++++++- 12 files changed, 647 insertions(+), 15 deletions(-) diff --git a/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java b/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java new file mode 100644 index 00000000000..a4e2a45c0d1 --- /dev/null +++ b/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.udf; + +import org.apache.iotdb.udf.api.State; +import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig; +import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters; +import org.apache.iotdb.udf.api.exception.UDFException; +import org.apache.iotdb.udf.api.exception.UDFParameterNotValidException; +import org.apache.iotdb.udf.api.relational.AggregateFunction; +import org.apache.iotdb.udf.api.relational.access.Record; +import org.apache.iotdb.udf.api.type.Type; +import org.apache.iotdb.udf.api.utils.ResultValue; + +import org.apache.tsfile.block.column.Column; + +import java.nio.ByteBuffer; + +/** + * This is an internal example of the AggregateFunction implementation. + * + * <p>CREATE DATABASE test; + * + * <p>USE test; + * + * <p>CREATE TABLE t1(device_id STRING ID, s1 TEXT MEASUREMENT, s2 INT32 MEASUREMENT); + * + * <p>INSERT INTO t1(time, device_id, s1, s2) VALUES (1, 'd1', 'a', 1), (2, 'd1', null, 2), (3, + * 'd2', 'c', null); + * + * <p>CREATE FUNCTION my_count AS 'org.apache.iotdb.udf.AggregateFunctionExample'; + * + * <p>SHOW FUNCTIONS; + * + * <p>SELECT time, device_id, my_count(s1) as s1_count, my_count(s2) as s2_count FROM t1 group by + * device_id; + * + * <p>SELECT time, my_count(s1) as s1_count, my_count(s2) as s2_count FROM t1; + */ +public class AggregateFunctionExample implements AggregateFunction { + + static class CountState implements State { + + long count; + + @Override + public void reset() { + count = 0; + } + + @Override + public byte[] serialize() { + ByteBuffer buffer = ByteBuffer.allocate(Double.BYTES); + buffer.putLong(count); + return buffer.array(); + } + + @Override + public void deserialize(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.wrap(bytes); + count = buffer.getLong(); + } + } + + @Override + public void validate(FunctionParameters parameters) throws UDFException { + if (parameters.getChildExpressionsSize() != 1) { + throw new UDFParameterNotValidException("Only one parameter is required."); + } + } + + @Override + public void beforeStart(FunctionParameters parameters, AggregateFunctionConfig configurations) { + configurations.setOutputDataType(Type.INT64); + } + + @Override + public State createState() { + return new CountState(); + } + + @Override + public void addInput(State state, Column[] columns) { + CountState countState = (CountState) state; + for (int i = 0; i < columns[0].getPositionCount(); i++) { + if (!columns[0].isNull(i)) { + countState.count++; + } + } + } + + @Override + public void addInput(State state, Record input) { + CountState countState = (CountState) state; + if (!input.isNull(0)) { + countState.count++; + } + } + + @Override + public void combineState(State state, State rhs) { + CountState countState = (CountState) state; + CountState rhsCountState = (CountState) rhs; + countState.count += rhsCountState.count; + } + + @Override + public void outputFinal(State state, ResultValue resultValue) { + CountState countState = (CountState) state; + resultValue.setLong(countState.count); + } +} diff --git a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java similarity index 63% copy from iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java copy to iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java index 24942afd010..9b830a54d08 100644 --- a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java +++ b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java @@ -17,6 +17,20 @@ * under the License. */ -package org.apache.iotdb.udf.api.relational; +package org.apache.iotdb.udf.api.customizer.config; -public interface AggregateFunction extends SQLFunction {} +import org.apache.iotdb.udf.api.type.Type; + +public class AggregateFunctionConfig extends UDFConfigurations { + + /** + * Set the output data type of the scalar function. + * + * @param outputDataType the output data type of the scalar function + * @return this + */ + public AggregateFunctionConfig setOutputDataType(Type outputDataType) { + this.outputDataType = outputDataType; + return this; + } +} diff --git a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java index 24942afd010..e3dada2e2c6 100644 --- a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java +++ b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java @@ -19,4 +19,95 @@ package org.apache.iotdb.udf.api.relational; -public interface AggregateFunction extends SQLFunction {} +import org.apache.iotdb.udf.api.State; +import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig; +import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters; +import org.apache.iotdb.udf.api.exception.UDFException; +import org.apache.iotdb.udf.api.relational.access.Record; +import org.apache.iotdb.udf.api.utils.ResultValue; + +import org.apache.tsfile.block.column.Column; + +public interface AggregateFunction extends SQLFunction { + + /** + * This method is used to validate {@linkplain FunctionParameters}. + * + * @param parameters parameters used to validate + * @throws UDFException if any parameter is not valid + */ + void validate(FunctionParameters parameters) throws UDFException; + + /** + * This method is mainly used to initialize {@linkplain AggregateFunction} and set the output data + * type. In this method, the user need to do the following things: + * + * <ul> + * <li>Use {@linkplain FunctionParameters} to get input data types and infer output data type. + * <li>Use {@linkplain FunctionParameters} to get necessary attributes. + * <li>Set the output data type in {@linkplain AggregateFunctionConfig}. + * </ul> + * + * <p>This method is called after the AggregateFunction is instantiated and before the beginning + * of the transformation process. + * + * @param parameters used to parse the input parameters entered by the user + * @param configurations used to set the required properties in the ScalarFunction + */ + void beforeStart(FunctionParameters parameters, AggregateFunctionConfig configurations); + + /** Create and initialize state. You may bind some resource in this method. */ + State createState(); + + /** + * Batch update state with data columns. You shall iterate columns and update state with raw + * values TODO:should delete this interface + * + * @param state state to be updated + * @param columns input columns from IoTDB TsBlock, time column is always the last column, the + * remaining columns are their parameter value columns + */ + void addInput(State state, Column[] columns); + + /** + * Batch update state with data columns. You shall iterate columns and update state with raw + * values + * + * @param state state to be updated + * @param input input columns from IoTDB TsBlock, time column is always the last column, the + * remaining columns are their parameter value columns + */ + void addInput(State state, Record input); + + /** + * Merge two state in execution engine. + * + * @param state current state + * @param rhs right-hand-side state to be merged + */ + void combineState(State state, State rhs); + + /** + * Calculate output value from final state + * + * @param state final state + * @param resultValue used to collect output data points + */ + void outputFinal(State state, ResultValue resultValue); + + /** + * This method is optional Remove partial state from current state. Implement this method to + * enable sliding window feature. + * + * @param state current state + * @param removed state to be removed + */ + default void removeState(State state, State removed) { + throw new UnsupportedOperationException(getClass().getName()); + } + + /** This method is mainly used to release the resources used in the SQLFunction. */ + default void beforeDestroy() { + // do nothing + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index 125fa756302..17ded3b0a5c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -20,6 +20,8 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; import org.apache.iotdb.common.rpc.thrift.TAggregationType; +import org.apache.iotdb.commons.udf.utils.TableUDFUtils; +import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator; @@ -35,13 +37,18 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggr import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig; +import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters; +import org.apache.iotdb.udf.api.relational.AggregateFunction; import org.apache.tsfile.enums.TSDataType; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkState; import static org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction.FIRST_BY; @@ -60,7 +67,7 @@ public class AccumulatorFactory { String timeColumnName) { if (aggregationType == TAggregationType.UDAF) { // If UDAF accumulator receives raw input, it needs to check input's attribute - throw new UnsupportedOperationException(); + return createUDAFAccumulator(functionName, inputDataTypes, inputAttributes); } else if ((LAST_BY.getFunctionName().equals(functionName) || FIRST_BY.getFunctionName().equals(functionName)) && inputExpressions.size() > 1) { @@ -99,13 +106,39 @@ public class AccumulatorFactory { boolean ascending) { if (aggregationType == TAggregationType.UDAF) { // If UDAF accumulator receives raw input, it needs to check input's attribute - throw new UnsupportedOperationException(); + return createGroupedUDAFAccumulator(functionName, inputDataTypes, inputAttributes); } else { return createBuiltinGroupedAccumulator( aggregationType, inputDataTypes, inputExpressions, inputAttributes, ascending); } } + private static TableAccumulator createUDAFAccumulator( + String functionName, List<TSDataType> inputDataTypes, Map<String, String> inputAttributes) { + AggregateFunction aggregateFunction = TableUDFUtils.getAggregateFunction(functionName); + FunctionParameters functionParameters = + new FunctionParameters( + UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes); + AggregateFunctionConfig config = new AggregateFunctionConfig(); + aggregateFunction.beforeStart(functionParameters, config); + return new UserDefinedAggregateFunctionAccumulator(aggregateFunction); + } + + private static GroupedAccumulator createGroupedUDAFAccumulator( + String functionName, List<TSDataType> inputDataTypes, Map<String, String> inputAttributes) { + AggregateFunction aggregateFunction = TableUDFUtils.getAggregateFunction(functionName); + FunctionParameters functionParameters = + new FunctionParameters( + UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes); + AggregateFunctionConfig config = new AggregateFunctionConfig(); + aggregateFunction.beforeStart(functionParameters, config); + return new GroupedUserDefinedAggregateAccumulator( + aggregateFunction, + inputDataTypes.stream() + .map(UDFDataTypeTransformer::transformTSDataTypeToReadType) + .collect(Collectors.toList())); + } + private static GroupedAccumulator createBuiltinGroupedAccumulator( TAggregationType aggregationType, List<TSDataType> inputDataTypes, diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java new file mode 100644 index 00000000000..9c3f0c78168 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.udf.api.State; +import org.apache.iotdb.udf.api.relational.AggregateFunction; +import org.apache.iotdb.udf.api.utils.ResultValue; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; + +import static com.google.common.base.Preconditions.checkArgument; + +public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(UserDefinedAggregateFunctionAccumulator.class); + private final AggregateFunction aggregateFunction; + private final State state; + + public UserDefinedAggregateFunctionAccumulator(AggregateFunction aggregateFunction) { + this.aggregateFunction = aggregateFunction; + this.state = aggregateFunction.createState(); + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new UserDefinedAggregateFunctionAccumulator(aggregateFunction); + } + + @Override + public void addInput(Column[] arguments) { + aggregateFunction.addInput(state, arguments); + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output of UDAF should be BinaryColumn"); + State otherState = aggregateFunction.createState(); + Binary otherStateBinary = argument.getBinary(0); + otherState.deserialize(otherStateBinary.getValues()); + + aggregateFunction.combineState(state, otherState); + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output of UDAF should be BinaryColumn"); + byte[] bytes = state.serialize(); + columnBuilder.writeBinary(new Binary(bytes)); + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + ResultValue resultValue = new ResultValue(columnBuilder); + aggregateFunction.outputFinal(state, resultValue); + } + + @Override + public boolean hasFinalResult() { + // TODO + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + // TODO + } + + @Override + public void reset() { + state.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java new file mode 100644 index 00000000000..659bbd0a01a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.commons.udf.access.RecordIterator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.ObjectBigArray; +import org.apache.iotdb.udf.api.State; +import org.apache.iotdb.udf.api.relational.AggregateFunction; +import org.apache.iotdb.udf.api.utils.ResultValue; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.read.common.type.Type; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; + +import java.util.Arrays; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedUserDefinedAggregateAccumulator.class); + private final AggregateFunction aggregateFunction; + private final ObjectBigArray<State> stateArray; + private final List<Type> inputDataTypes; + + public GroupedUserDefinedAggregateAccumulator( + AggregateFunction aggregateFunction, List<Type> inputDataTypes) { + this.aggregateFunction = aggregateFunction; + this.stateArray = new ObjectBigArray<>(); + this.inputDataTypes = inputDataTypes; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public void setGroupCount(long groupCount) { + stateArray.ensureCapacity(groupCount); + } + + private State getOrCreateState(int groupId) { + State state = stateArray.get(groupId); + if (state == null) { + state = aggregateFunction.createState(); + stateArray.set(groupId, state); + } + return state; + } + + @Override + public void addInput(int[] groupIds, Column[] arguments) { + RecordIterator iterator = + new RecordIterator( + Arrays.asList(arguments), inputDataTypes, arguments[0].getPositionCount()); + int index = 0; + while (iterator.hasNext()) { + int groupId = groupIds[index++]; + State state = getOrCreateState(groupId); + if (state == null) { + state = aggregateFunction.createState(); + stateArray.set(groupId, state); + } + aggregateFunction.addInput(state, iterator.next()); + } + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output of UDAF should be BinaryColumn"); + + for (int i = 0; i < groupIds.length; i++) { + if (!argument.isNull(i)) { + State otherState = aggregateFunction.createState(); + Binary otherStateBinary = argument.getBinary(i); + otherState.deserialize(otherStateBinary.getValues()); + aggregateFunction.combineState(getOrCreateState(groupIds[i]), otherState); + } + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output of UDAF should be BinaryColumn"); + if (stateArray.get(groupId) == null) { + columnBuilder.writeBinary(new Binary(new byte[0])); + } else { + byte[] bytes = stateArray.get(groupId).serialize(); + columnBuilder.writeBinary(new Binary(bytes)); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + ResultValue resultValue = new ResultValue(columnBuilder); + aggregateFunction.outputFinal(getOrCreateState(groupId), resultValue); + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + stateArray.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java new file mode 100644 index 00000000000..e4e70e49552 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array; + +import org.apache.iotdb.udf.api.State; + +import org.apache.tsfile.block.column.Column; + +import java.util.function.Supplier; + +public class StateBigArray { + + private final ObjectBigArray<State> array; + private final Supplier<State> stateSupplier; + + public StateBigArray(Supplier<State> stateSupplier) { + this.array = new ObjectBigArray<>(); + this.stateSupplier = stateSupplier; + } + + public State get(long index) { + return array.get(index); + } + + public void set(long index, State value) { + array.set(index, value); + } + + public void update(long index, Column[] arguments) { + State state = array.get(index); + if (state == null) { + state = stateSupplier.get(); + array.set(index, state); + } + } + + public void ensureCapacity(long length) { + array.ensureCapacity(length); + } + + public void reset() { + array.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java index f786bb1ceff..f8cec46dfd1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java @@ -20,6 +20,7 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; +import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultExpressionTraversalVisitor; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; @@ -92,8 +93,8 @@ public final class ExpressionTreeUtils { } static boolean isAggregationFunction(String functionName) { - // TODO consider UDAF return TableBuiltinAggregationFunction.getBuiltInAggregateFunctionName() - .contains(functionName.toLowerCase()); + .contains(functionName.toLowerCase()) + || TableUDFUtils.isAggregateFunction(functionName); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 26e68ab5cad..e89931c8d88 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -49,8 +49,10 @@ import org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundExceptio import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature; import org.apache.iotdb.db.schemaengine.table.DataNodeTableCache; import org.apache.iotdb.db.utils.constant.SqlConstant; +import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig; import org.apache.iotdb.udf.api.customizer.config.ScalarFunctionConfig; import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters; +import org.apache.iotdb.udf.api.relational.AggregateFunction; import org.apache.iotdb.udf.api.relational.ScalarFunction; import org.apache.tsfile.file.metadata.IDeviceID; @@ -650,10 +652,26 @@ public class TableMetadataImpl implements Metadata { } finally { scalarFunction.beforeDestroy(); } + } else if (TableUDFUtils.isAggregateFunction(functionName)) { + AggregateFunction aggregateFunction = TableUDFUtils.getAggregateFunction(functionName); + FunctionParameters functionParameters = + new FunctionParameters( + argumentTypes.stream() + .map(UDFDataTypeTransformer::transformReadTypeToUDFDataType) + .collect(Collectors.toList()), + Collections.emptyMap()); + try { + aggregateFunction.validate(functionParameters); + AggregateFunctionConfig config = new AggregateFunctionConfig(); + aggregateFunction.beforeStart(functionParameters, config); + return UDFDataTypeTransformer.transformUDFDataTypeToReadType(config.getOutputDataType()); + } catch (Exception e) { + throw new SemanticException("Invalid function parameters: " + e.getMessage()); + } finally { + aggregateFunction.beforeDestroy(); + } } - // TODO UDAF - throw new SemanticException("Unknown function: " + functionName); } @@ -661,7 +679,8 @@ public class TableMetadataImpl implements Metadata { public boolean isAggregationFunction( final SessionInfo session, final String functionName, final AccessControl accessControl) { return TableBuiltinAggregationFunction.getBuiltInAggregateFunctionName() - .contains(functionName.toLowerCase(Locale.ENGLISH)); + .contains(functionName.toLowerCase(Locale.ENGLISH)) + || TableUDFUtils.isAggregateFunction(functionName); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java index 844744dbe79..8fff6f2eb23 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java @@ -62,8 +62,6 @@ public class Util { resolvedFunction.getSignature().getArgumentTypes()); Symbol intermediateSymbol = symbolAllocator.newSymbol(resolvedFunction.getSignature().getName(), intermediateType); - // TODO put symbol and its type to TypeProvide or later process: add all map contents of - // SymbolAllocator to the TypeProvider checkState( !originalAggregation.getOrderingScheme().isPresent(), "Aggregate with ORDER BY does not support partial aggregation"); diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 49db27110c7..4cd046dfeb6 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import static org.apache.tsfile.read.common.type.BlobType.BLOB; import static org.apache.tsfile.read.common.type.DoubleType.DOUBLE; import static org.apache.tsfile.read.common.type.LongType.INT64; @@ -103,7 +104,8 @@ public enum TableBuiltinAggregationFunction { case "min": return originalArgumentTypes.get(0); default: - throw new IllegalArgumentException("Invalid Aggregation function: " + name); + // default is UDAF + return BLOB; } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java index 1cae52ea5b9..c7ca2301953 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java @@ -22,11 +22,15 @@ import org.apache.iotdb.udf.api.type.Type; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.type.BinaryType; +import org.apache.tsfile.read.common.type.BlobType; import org.apache.tsfile.read.common.type.BooleanType; +import org.apache.tsfile.read.common.type.DateType; import org.apache.tsfile.read.common.type.DoubleType; import org.apache.tsfile.read.common.type.FloatType; import org.apache.tsfile.read.common.type.IntType; import org.apache.tsfile.read.common.type.LongType; +import org.apache.tsfile.read.common.type.StringType; +import org.apache.tsfile.read.common.type.TimestampType; import java.util.List; import java.util.stream.Collectors; @@ -121,19 +125,54 @@ public class UDFDataTypeTransformer { case BOOLEAN: return BooleanType.BOOLEAN; case INT32: - case DATE: return IntType.INT32; + case DATE: + return DateType.DATE; case INT64: - case TIMESTAMP: return LongType.INT64; + case TIMESTAMP: + return TimestampType.TIMESTAMP; case FLOAT: return FloatType.FLOAT; case DOUBLE: return DoubleType.DOUBLE; case TEXT: + return BinaryType.TEXT; case BLOB: + return BlobType.BLOB; case STRING: + return StringType.STRING; + default: + throw new IllegalArgumentException("Invalid input: " + type); + } + } + + public static org.apache.tsfile.read.common.type.Type transformTSDataTypeToReadType( + TSDataType type) { + if (type == null) { + return null; + } + switch (type) { + case BOOLEAN: + return BooleanType.BOOLEAN; + case INT32: + return IntType.INT32; + case DATE: + return DateType.DATE; + case INT64: + return LongType.INT64; + case TIMESTAMP: + return TimestampType.TIMESTAMP; + case FLOAT: + return FloatType.FLOAT; + case DOUBLE: + return DoubleType.DOUBLE; + case TEXT: return BinaryType.TEXT; + case BLOB: + return BlobType.BLOB; + case STRING: + return StringType.STRING; default: throw new IllegalArgumentException("Invalid input: " + type); }
