This is an automated email from the ASF dual-hosted git repository.

chenyz pushed a commit to branch udaf
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 2bb60b23bccdb4b464f8f18034da656913be1ca2
Author: Chen YZ <[email protected]>
AuthorDate: Thu Dec 5 20:59:53 2024 +0800

    save
---
 .../apache/iotdb/udf/AggregateFunctionExample.java | 129 +++++++++++++++++++
 .../config/AggregateFunctionConfig.java}           |  18 ++-
 .../udf/api/relational/AggregateFunction.java      |  93 +++++++++++++-
 .../relational/aggregation/AccumulatorFactory.java |  37 +++++-
 .../UserDefinedAggregateFunctionAccumulator.java   | 108 ++++++++++++++++
 .../GroupedUserDefinedAggregateAccumulator.java    | 137 +++++++++++++++++++++
 .../aggregation/grouped/array/StateBigArray.java   |  61 +++++++++
 .../relational/analyzer/ExpressionTreeUtils.java   |   5 +-
 .../relational/metadata/TableMetadataImpl.java     |  25 +++-
 .../relational/planner/optimizations/Util.java     |   2 -
 .../TableBuiltinAggregationFunction.java           |   4 +-
 .../commons/udf/utils/UDFDataTypeTransformer.java  |  43 ++++++-
 12 files changed, 647 insertions(+), 15 deletions(-)

diff --git 
a/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java 
b/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java
new file mode 100644
index 00000000000..a4e2a45c0d1
--- /dev/null
+++ 
b/example/udf/src/main/java/org/apache/iotdb/udf/AggregateFunctionExample.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf;
+
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.exception.UDFParameterNotValidException;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+
+import java.nio.ByteBuffer;
+
+/**
+ * This is an internal example of the AggregateFunction implementation.
+ *
+ * <p>CREATE DATABASE test;
+ *
+ * <p>USE test;
+ *
+ * <p>CREATE TABLE t1(device_id STRING ID, s1 TEXT MEASUREMENT, s2 INT32 
MEASUREMENT);
+ *
+ * <p>INSERT INTO t1(time, device_id, s1, s2) VALUES (1, 'd1', 'a', 1), (2, 
'd1', null, 2), (3,
+ * 'd2', 'c', null);
+ *
+ * <p>CREATE FUNCTION my_count AS 
'org.apache.iotdb.udf.AggregateFunctionExample';
+ *
+ * <p>SHOW FUNCTIONS;
+ *
+ * <p>SELECT device_id, my_count(s1) as s1_count, my_count(s2) as s2_count FROM t1 group by
+ * device_id;
+ *
+ * <p>SELECT my_count(s1) as s1_count, my_count(s2) as s2_count FROM t1;
+ */
+public class AggregateFunctionExample implements AggregateFunction {
+
+  /** Aggregation state that holds the number of non-null values counted so far. */
+  static class CountState implements State {
+
+    long count;
+
+    @Override
+    public void reset() {
+      count = 0;
+    }
+
+    /** Serializes the counter as a single 8-byte long. */
+    @Override
+    public byte[] serialize() {
+      // count is a long, so size the buffer with Long.BYTES rather than Double.BYTES
+      ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES);
+      buffer.putLong(count);
+      return buffer.array();
+    }
+
+    /** Restores the counter from bytes produced by {@link #serialize()}. */
+    @Override
+    public void deserialize(byte[] bytes) {
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+      count = buffer.getLong();
+    }
+  }
+
+  @Override
+  public void validate(FunctionParameters parameters) throws UDFException {
+    if (parameters.getChildExpressionsSize() != 1) {
+      throw new UDFParameterNotValidException("Only one parameter is 
required.");
+    }
+  }
+
+  @Override
+  public void beforeStart(FunctionParameters parameters, 
AggregateFunctionConfig configurations) {
+    configurations.setOutputDataType(Type.INT64);
+  }
+
+  @Override
+  public State createState() {
+    return new CountState();
+  }
+
+  @Override
+  public void addInput(State state, Column[] columns) {
+    CountState countState = (CountState) state;
+    for (int i = 0; i < columns[0].getPositionCount(); i++) {
+      if (!columns[0].isNull(i)) {
+        countState.count++;
+      }
+    }
+  }
+
+  @Override
+  public void addInput(State state, Record input) {
+    CountState countState = (CountState) state;
+    if (!input.isNull(0)) {
+      countState.count++;
+    }
+  }
+
+  @Override
+  public void combineState(State state, State rhs) {
+    CountState countState = (CountState) state;
+    CountState rhsCountState = (CountState) rhs;
+    countState.count += rhsCountState.count;
+  }
+
+  @Override
+  public void outputFinal(State state, ResultValue resultValue) {
+    CountState countState = (CountState) state;
+    resultValue.setLong(countState.count);
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java
similarity index 63%
copy from 
iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
copy to 
iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java
index 24942afd010..9b830a54d08 100644
--- 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/AggregateFunctionConfig.java
@@ -17,6 +17,20 @@
  * under the License.
  */
 
-package org.apache.iotdb.udf.api.relational;
+package org.apache.iotdb.udf.api.customizer.config;
 
-public interface AggregateFunction extends SQLFunction {}
+import org.apache.iotdb.udf.api.type.Type;
+
+public class AggregateFunctionConfig extends UDFConfigurations {
+
+  /**
+   * Set the output data type of the aggregate function.
+   *
+   * @param outputDataType the output data type of the aggregate function
+   * @return this
+   */
+  public AggregateFunctionConfig setOutputDataType(Type outputDataType) {
+    this.outputDataType = outputDataType;
+    return this;
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
index 24942afd010..e3dada2e2c6 100644
--- 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java
@@ -19,4 +19,95 @@
 
 package org.apache.iotdb.udf.api.relational;
 
-public interface AggregateFunction extends SQLFunction {}
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+
+public interface AggregateFunction extends SQLFunction {
+
+  /**
+   * This method is used to validate {@linkplain FunctionParameters}.
+   *
+   * @param parameters parameters used to validate
+   * @throws UDFException if any parameter is not valid
+   */
+  void validate(FunctionParameters parameters) throws UDFException;
+
+  /**
+   * This method is mainly used to initialize {@linkplain AggregateFunction} 
and set the output data
+   * type. In this method, the user need to do the following things:
+   *
+   * <ul>
+   *   <li>Use {@linkplain FunctionParameters} to get input data types and 
infer output data type.
+   *   <li>Use {@linkplain FunctionParameters} to get necessary attributes.
+   *   <li>Set the output data type in {@linkplain AggregateFunctionConfig}.
+   * </ul>
+   *
+   * <p>This method is called after the AggregateFunction is instantiated and 
before the beginning
+   * of the transformation process.
+   *
+   * @param parameters used to parse the input parameters entered by the user
+   * @param configurations used to set the required properties in the AggregateFunction
+   */
+  void beforeStart(FunctionParameters parameters, AggregateFunctionConfig 
configurations);
+
+  /** Create and initialize state. You may bind some resource in this method. 
*/
+  State createState();
+
+  /**
+   * Batch update state with data columns. You shall iterate columns and 
update state with raw
+   * values. TODO: this interface should be deleted.
+   *
+   * @param state state to be updated
+   * @param columns input columns from IoTDB TsBlock, time column is always 
the last column, the
+   *     remaining columns are their parameter value columns
+   */
+  void addInput(State state, Column[] columns);
+
+  /**
+   * Update state with a single input row. You shall read the raw values from the record and
+   * update state with them.
+   *
+   * @param state state to be updated
+   * @param input one input row from IoTDB TsBlock, time column is always the last column, the
+   *     remaining columns are the parameter value columns
+   */
+  void addInput(State state, Record input);
+
+  /**
+   * Merge two state in execution engine.
+   *
+   * @param state current state
+   * @param rhs right-hand-side state to be merged
+   */
+  void combineState(State state, State rhs);
+
+  /**
+   * Calculate output value from final state
+   *
+   * @param state final state
+   * @param resultValue used to collect output data points
+   */
+  void outputFinal(State state, ResultValue resultValue);
+
+  /**
+   * This method is optional. Remove partial state from current state. Implement this method to
+   * enable sliding window feature.
+   *
+   * @param state current state
+   * @param removed state to be removed
+   */
+  default void removeState(State state, State removed) {
+    throw new UnsupportedOperationException(getClass().getName());
+  }
+
+  /** This method is mainly used to release the resources used in the 
SQLFunction. */
+  default void beforeDestroy() {
+    // do nothing
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index 125fa756302..17ded3b0a5c 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -20,6 +20,8 @@
 package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
 
 import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+import org.apache.iotdb.commons.udf.utils.TableUDFUtils;
+import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
 import 
org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator;
@@ -35,13 +37,18 @@ import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggr
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
 
 import org.apache.tsfile.enums.TSDataType;
 
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkState;
 import static 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction.FIRST_BY;
@@ -60,7 +67,7 @@ public class AccumulatorFactory {
       String timeColumnName) {
     if (aggregationType == TAggregationType.UDAF) {
       // If UDAF accumulator receives raw input, it needs to check input's 
attribute
-      throw new UnsupportedOperationException();
+      return createUDAFAccumulator(functionName, inputDataTypes, 
inputAttributes);
     } else if ((LAST_BY.getFunctionName().equals(functionName)
             || FIRST_BY.getFunctionName().equals(functionName))
         && inputExpressions.size() > 1) {
@@ -99,13 +106,39 @@ public class AccumulatorFactory {
       boolean ascending) {
     if (aggregationType == TAggregationType.UDAF) {
       // If UDAF accumulator receives raw input, it needs to check input's 
attribute
-      throw new UnsupportedOperationException();
+      return createGroupedUDAFAccumulator(functionName, inputDataTypes, 
inputAttributes);
     } else {
       return createBuiltinGroupedAccumulator(
           aggregationType, inputDataTypes, inputExpressions, inputAttributes, 
ascending);
     }
   }
 
+  private static TableAccumulator createUDAFAccumulator(
+      String functionName, List<TSDataType> inputDataTypes, Map<String, 
String> inputAttributes) {
+    AggregateFunction aggregateFunction = 
TableUDFUtils.getAggregateFunction(functionName);
+    FunctionParameters functionParameters =
+        new FunctionParameters(
+            UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), 
inputAttributes);
+    AggregateFunctionConfig config = new AggregateFunctionConfig();
+    aggregateFunction.beforeStart(functionParameters, config);
+    return new UserDefinedAggregateFunctionAccumulator(aggregateFunction);
+  }
+
+  private static GroupedAccumulator createGroupedUDAFAccumulator(
+      String functionName, List<TSDataType> inputDataTypes, Map<String, 
String> inputAttributes) {
+    AggregateFunction aggregateFunction = 
TableUDFUtils.getAggregateFunction(functionName);
+    FunctionParameters functionParameters =
+        new FunctionParameters(
+            UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), 
inputAttributes);
+    AggregateFunctionConfig config = new AggregateFunctionConfig();
+    aggregateFunction.beforeStart(functionParameters, config);
+    return new GroupedUserDefinedAggregateAccumulator(
+        aggregateFunction,
+        inputDataTypes.stream()
+            .map(UDFDataTypeTransformer::transformTSDataTypeToReadType)
+            .collect(Collectors.toList()));
+  }
+
   private static GroupedAccumulator createBuiltinGroupedAccumulator(
       TAggregationType aggregationType,
       List<TSDataType> inputDataTypes,
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
new file mode 100644
index 00000000000..9c3f0c78168
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class UserDefinedAggregateFunctionAccumulator implements 
TableAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      
RamUsageEstimator.shallowSizeOfInstance(UserDefinedAggregateFunctionAccumulator.class);
+  private final AggregateFunction aggregateFunction;
+  private final State state;
+
+  public UserDefinedAggregateFunctionAccumulator(AggregateFunction 
aggregateFunction) {
+    this.aggregateFunction = aggregateFunction;
+    this.state = aggregateFunction.createState();
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public TableAccumulator copy() {
+    return new UserDefinedAggregateFunctionAccumulator(aggregateFunction);
+  }
+
+  @Override
+  public void addInput(Column[] arguments) {
+    aggregateFunction.addInput(state, arguments);
+  }
+
+  /**
+   * Merges serialized partial states into this accumulator's state.
+   *
+   * <p>Every non-null position of the intermediate column is deserialized into a fresh state and
+   * combined, so no partial state is dropped when the column carries more than one position.
+   *
+   * @param argument binary column whose non-null values are serialized states
+   */
+  @Override
+  public void addIntermediate(Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
+        "intermediate input and output of UDAF should be BinaryColumn");
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      if (argument.isNull(i)) {
+        // no partial state was produced for this position
+        continue;
+      }
+      State otherState = aggregateFunction.createState();
+      otherState.deserialize(argument.getBinary(i).getValues());
+      aggregateFunction.combineState(state, otherState);
+    }
+  }
+
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output of UDAF should be BinaryColumn");
+    byte[] bytes = state.serialize();
+    columnBuilder.writeBinary(new Binary(bytes));
+  }
+
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    ResultValue resultValue = new ResultValue(columnBuilder);
+    aggregateFunction.outputFinal(state, resultValue);
+  }
+
+  @Override
+  public boolean hasFinalResult() {
+    // TODO
+    return false;
+  }
+
+  @Override
+  public void addStatistics(Statistics[] statistics) {
+    // TODO
+  }
+
+  @Override
+  public void reset() {
+    state.reset();
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
new file mode 100644
index 00000000000..659bbd0a01a
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.commons.udf.access.RecordIterator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.ObjectBigArray;
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.read.common.type.Type;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class GroupedUserDefinedAggregateAccumulator implements 
GroupedAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      
RamUsageEstimator.shallowSizeOfInstance(GroupedUserDefinedAggregateAccumulator.class);
+  private final AggregateFunction aggregateFunction;
+  private final ObjectBigArray<State> stateArray;
+  private final List<Type> inputDataTypes;
+
+  public GroupedUserDefinedAggregateAccumulator(
+      AggregateFunction aggregateFunction, List<Type> inputDataTypes) {
+    this.aggregateFunction = aggregateFunction;
+    this.stateArray = new ObjectBigArray<>();
+    this.inputDataTypes = inputDataTypes;
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public void setGroupCount(long groupCount) {
+    stateArray.ensureCapacity(groupCount);
+  }
+
+  private State getOrCreateState(int groupId) {
+    State state = stateArray.get(groupId);
+    if (state == null) {
+      state = aggregateFunction.createState();
+      stateArray.set(groupId, state);
+    }
+    return state;
+  }
+
+  /**
+   * Feeds every input row into the state of the group it belongs to.
+   *
+   * @param groupIds group id of each row, indexed by row position
+   * @param arguments input columns; their rows are iterated as records
+   */
+  @Override
+  public void addInput(int[] groupIds, Column[] arguments) {
+    RecordIterator iterator =
+        new RecordIterator(
+            Arrays.asList(arguments), inputDataTypes, arguments[0].getPositionCount());
+    int index = 0;
+    while (iterator.hasNext()) {
+      // getOrCreateState never returns null, so no further null handling is needed here
+      State state = getOrCreateState(groupIds[index++]);
+      aggregateFunction.addInput(state, iterator.next());
+    }
+  }
+
+  @Override
+  public void addIntermediate(int[] groupIds, Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof 
BinaryColumn),
+        "intermediate input and output of UDAF should be BinaryColumn");
+
+    for (int i = 0; i < groupIds.length; i++) {
+      if (!argument.isNull(i)) {
+        State otherState = aggregateFunction.createState();
+        Binary otherStateBinary = argument.getBinary(i);
+        otherState.deserialize(otherStateBinary.getValues());
+        aggregateFunction.combineState(getOrCreateState(groupIds[i]), 
otherState);
+      }
+    }
+  }
+
+  /**
+   * Writes the serialized partial state of one group, or null when the group has no state yet.
+   *
+   * @param groupId id of the group whose state is emitted
+   * @param columnBuilder target builder; must be a {@code BinaryColumnBuilder}
+   */
+  @Override
+  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output of UDAF should be BinaryColumn");
+    State state = stateArray.get(groupId);
+    if (state == null) {
+      // addIntermediate skips null positions, whereas an empty binary would be passed on to
+      // State#deserialize and fail there
+      columnBuilder.appendNull();
+    } else {
+      columnBuilder.writeBinary(new Binary(state.serialize()));
+    }
+  }
+
+  @Override
+  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
+    ResultValue resultValue = new ResultValue(columnBuilder);
+    aggregateFunction.outputFinal(getOrCreateState(groupId), resultValue);
+  }
+
+  @Override
+  public void prepareFinal() {}
+
+  @Override
+  public void reset() {
+    stateArray.reset();
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java
new file mode 100644
index 00000000000..e4e70e49552
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/StateBigArray.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array;
+
+import org.apache.iotdb.udf.api.State;
+
+import org.apache.tsfile.block.column.Column;
+
+import java.util.function.Supplier;
+
+public class StateBigArray {
+
+  private final ObjectBigArray<State> array;
+  private final Supplier<State> stateSupplier;
+
+  public StateBigArray(Supplier<State> stateSupplier) {
+    this.array = new ObjectBigArray<>();
+    this.stateSupplier = stateSupplier;
+  }
+
+  public State get(long index) {
+    return array.get(index);
+  }
+
+  public void set(long index, State value) {
+    array.set(index, value);
+  }
+
+  public void update(long index, Column[] arguments) {
+    State state = array.get(index);
+    if (state == null) {
+      state = stateSupplier.get();
+      array.set(index, state);
+    }
+  }
+
+  public void ensureCapacity(long length) {
+    array.ensureCapacity(length);
+  }
+
+  public void reset() {
+    array.reset();
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java
index f786bb1ceff..f8cec46dfd1 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java
@@ -20,6 +20,7 @@
 package org.apache.iotdb.db.queryengine.plan.relational.analyzer;
 
 import 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction;
+import org.apache.iotdb.commons.udf.utils.TableUDFUtils;
 import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultExpressionTraversalVisitor;
 import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
@@ -92,8 +93,8 @@ public final class ExpressionTreeUtils {
   }
 
   static boolean isAggregationFunction(String functionName) {
-    // TODO consider UDAF
     return TableBuiltinAggregationFunction.getBuiltInAggregateFunctionName()
-        .contains(functionName.toLowerCase());
+            .contains(functionName.toLowerCase())
+        || TableUDFUtils.isAggregateFunction(functionName);
   }
 }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
index 26e68ab5cad..e89931c8d88 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
@@ -49,8 +49,10 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundExceptio
 import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature;
 import org.apache.iotdb.db.schemaengine.table.DataNodeTableCache;
 import org.apache.iotdb.db.utils.constant.SqlConstant;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
 import org.apache.iotdb.udf.api.customizer.config.ScalarFunctionConfig;
 import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
 import org.apache.iotdb.udf.api.relational.ScalarFunction;
 
 import org.apache.tsfile.file.metadata.IDeviceID;
@@ -650,10 +652,26 @@ public class TableMetadataImpl implements Metadata {
       } finally {
         scalarFunction.beforeDestroy();
       }
+    } else if (TableUDFUtils.isAggregateFunction(functionName)) {
+      AggregateFunction aggregateFunction = 
TableUDFUtils.getAggregateFunction(functionName);
+      FunctionParameters functionParameters =
+          new FunctionParameters(
+              argumentTypes.stream()
+                  .map(UDFDataTypeTransformer::transformReadTypeToUDFDataType)
+                  .collect(Collectors.toList()),
+              Collections.emptyMap());
+      try {
+        aggregateFunction.validate(functionParameters);
+        AggregateFunctionConfig config = new AggregateFunctionConfig();
+        aggregateFunction.beforeStart(functionParameters, config);
+        return 
UDFDataTypeTransformer.transformUDFDataTypeToReadType(config.getOutputDataType());
+      } catch (Exception e) {
+        throw new SemanticException("Invalid function parameters: " + 
e.getMessage());
+      } finally {
+        aggregateFunction.beforeDestroy();
+      }
     }
 
-    // TODO UDAF
-
     throw new SemanticException("Unknown function: " + functionName);
   }
 
@@ -661,7 +679,8 @@ public class TableMetadataImpl implements Metadata {
   public boolean isAggregationFunction(
       final SessionInfo session, final String functionName, final 
AccessControl accessControl) {
     return TableBuiltinAggregationFunction.getBuiltInAggregateFunctionName()
-        .contains(functionName.toLowerCase(Locale.ENGLISH));
+            .contains(functionName.toLowerCase(Locale.ENGLISH))
+        || TableUDFUtils.isAggregateFunction(functionName);
   }
 
   @Override
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
index 844744dbe79..8fff6f2eb23 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
@@ -62,8 +62,6 @@ public class Util {
               resolvedFunction.getSignature().getArgumentTypes());
       Symbol intermediateSymbol =
           symbolAllocator.newSymbol(resolvedFunction.getSignature().getName(), 
intermediateType);
-      // TODO put symbol and its type to TypeProvide or later process: add all 
map contents of
-      // SymbolAllocator to the TypeProvider
       checkState(
           !originalAggregation.getOrderingScheme().isPresent(),
           "Aggregate with ORDER BY does not support partial aggregation");
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
index 49db27110c7..4cd046dfeb6 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
+++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
@@ -31,6 +31,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.apache.tsfile.read.common.type.BlobType.BLOB;
 import static org.apache.tsfile.read.common.type.DoubleType.DOUBLE;
 import static org.apache.tsfile.read.common.type.LongType.INT64;
 
@@ -103,7 +104,8 @@ public enum TableBuiltinAggregationFunction {
       case "min":
         return originalArgumentTypes.get(0);
       default:
-        throw new IllegalArgumentException("Invalid Aggregation function: " + 
name);
+        // default is UDAF
+        return BLOB;
     }
   }
 
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java
index 1cae52ea5b9..c7ca2301953 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java
+++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java
@@ -22,11 +22,15 @@ import org.apache.iotdb.udf.api.type.Type;
 
 import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.read.common.type.BinaryType;
+import org.apache.tsfile.read.common.type.BlobType;
 import org.apache.tsfile.read.common.type.BooleanType;
+import org.apache.tsfile.read.common.type.DateType;
 import org.apache.tsfile.read.common.type.DoubleType;
 import org.apache.tsfile.read.common.type.FloatType;
 import org.apache.tsfile.read.common.type.IntType;
 import org.apache.tsfile.read.common.type.LongType;
+import org.apache.tsfile.read.common.type.StringType;
+import org.apache.tsfile.read.common.type.TimestampType;
 
 import java.util.List;
 import java.util.stream.Collectors;
@@ -121,19 +125,54 @@ public class UDFDataTypeTransformer {
       case BOOLEAN:
         return BooleanType.BOOLEAN;
       case INT32:
-      case DATE:
         return IntType.INT32;
+      case DATE:
+        return DateType.DATE;
       case INT64:
-      case TIMESTAMP:
         return LongType.INT64;
+      case TIMESTAMP:
+        return TimestampType.TIMESTAMP;
       case FLOAT:
         return FloatType.FLOAT;
       case DOUBLE:
         return DoubleType.DOUBLE;
       case TEXT:
+        return BinaryType.TEXT;
       case BLOB:
+        return BlobType.BLOB;
       case STRING:
+        return StringType.STRING;
+      default:
+        throw new IllegalArgumentException("Invalid input: " + type);
+    }
+  }
+
+  public static org.apache.tsfile.read.common.type.Type 
transformTSDataTypeToReadType(
+      TSDataType type) {
+    if (type == null) {
+      return null;
+    }
+    switch (type) {
+      case BOOLEAN:
+        return BooleanType.BOOLEAN;
+      case INT32:
+        return IntType.INT32;
+      case DATE:
+        return DateType.DATE;
+      case INT64:
+        return LongType.INT64;
+      case TIMESTAMP:
+        return TimestampType.TIMESTAMP;
+      case FLOAT:
+        return FloatType.FLOAT;
+      case DOUBLE:
+        return DoubleType.DOUBLE;
+      case TEXT:
         return BinaryType.TEXT;
+      case BLOB:
+        return BlobType.BLOB;
+      case STRING:
+        return StringType.STRING;
       default:
         throw new IllegalArgumentException("Invalid input: " + type);
     }


Reply via email to