This is an automated email from the ASF dual-hosted git repository. Wei-hao-Li pushed a commit to branch IoTDBLocal in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 0d161fe98abcb87d9d526fc772eb9d326bec32af Author: Weihao Li <[email protected]> AuthorDate: Fri Jun 26 00:30:21 2026 +0800 fix beforeStart of UDSF+UDAF Signed-off-by: Weihao Li <[email protected]> --- .../relational/it/db/it/udf/IoTDBLocalIT.java | 6 -- .../relational/aggregation/AccumulatorFactory.java | 17 ------ .../UserDefinedAggregateFunctionAccumulator.java | 33 ++++++++++- .../GroupedUserDefinedAggregateAccumulator.java | 20 +++++++ .../udf/UserDefineScalarFunctionTransformer.java | 69 +++++++++++----------- 5 files changed, 87 insertions(+), 58 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java index d6fb2339ae2..c50d360ba66 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBLocalIT.java @@ -78,12 +78,6 @@ public class IoTDBLocalIT { "CLEAR ATTRIBUTE CACHE", }; - public static void main(String[] args) { - for (String sql : SETUP_SQLS) { - System.out.println(sql + ";"); - } - } - @BeforeClass public static void setUp() throws Exception { EnvFactory.getEnv().getConfig().getCommonConfig().setEnforceStrongPassword(false); diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java index d2b40c742d5..f9d35cf67f5 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -67,7 +67,6 @@ import org.apache.iotdb.commons.queryengine.plan.udf.TableUDFUtils; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.udf.api.IoTDBLocal; import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments; -import org.apache.iotdb.udf.api.exception.UDFException; import org.apache.iotdb.udf.api.relational.AggregateFunction; import com.google.common.collect.ImmutableList; @@ -287,14 +286,6 @@ public class AccumulatorFactory { FunctionArguments functionArguments = new FunctionArguments( UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes); - try { - aggregateFunction.beforeStart(functionArguments, ioTDBLocal); - } catch (UDFException e) { - throw new RuntimeException( - "Error occurs when starting user-defined aggregate function " - + aggregateFunction.getClass().getName(), - e); - } return new UserDefinedAggregateFunctionAccumulator( aggregateFunction.analyze(functionArguments), aggregateFunction, @@ -313,14 +304,6 @@ public class AccumulatorFactory { FunctionArguments functionArguments = new FunctionArguments( UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes); - try { - aggregateFunction.beforeStart(functionArguments, ioTDBLocal); - } catch (UDFException e) { - throw new RuntimeException( - "Error occurs when starting user-defined aggregate function " - + aggregateFunction.getClass().getName(), - e); - } return new GroupedUserDefinedAggregateAccumulator( aggregateFunction, functionArguments, diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java index 0e1d38f54e4..7bedcbbb030 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java @@ -24,6 +24,7 @@ import org.apache.iotdb.udf.api.IoTDBLocal; import org.apache.iotdb.udf.api.State; import org.apache.iotdb.udf.api.customizer.analysis.AggregateFunctionAnalysis; import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments; +import org.apache.iotdb.udf.api.exception.UDFException; import org.apache.iotdb.udf.api.relational.AggregateFunction; import org.apache.iotdb.udf.api.utils.ResultValue; @@ -52,6 +53,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator private final List<Type> inputDataTypes; private final State state; private final IoTDBLocal ioTDBLocal; + private boolean init; public UserDefinedAggregateFunctionAccumulator( AggregateFunctionAnalysis analysis, @@ -59,6 +61,16 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator FunctionArguments functionArguments, List<Type> inputDataTypes, IoTDBLocal ioTDBLocal) { + this(analysis, aggregateFunction, functionArguments, inputDataTypes, ioTDBLocal, false); + } + + private UserDefinedAggregateFunctionAccumulator( + AggregateFunctionAnalysis analysis, + AggregateFunction aggregateFunction, + FunctionArguments functionArguments, + List<Type> inputDataTypes, + IoTDBLocal ioTDBLocal, + boolean init) { checkArgument(ioTDBLocal != null, "IoTDBLocal must not be null for UDAF"); this.analysis = analysis; this.aggregateFunction = aggregateFunction; @@ -66,6 +78,22 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator this.inputDataTypes = inputDataTypes; this.state = aggregateFunction.createState(); this.ioTDBLocal = ioTDBLocal; + this.init = init; + } + + private void initIfNeeded() { + if (init) { + return; + } + init = true; + try { + aggregateFunction.beforeStart(functionArguments, ioTDBLocal); + } catch (UDFException e) { + throw new RuntimeException( + "Error occurs when starting user-defined aggregate function " + + aggregateFunction.getClass().getName(), + e); + } } @Override @@ -76,11 +104,12 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator @Override public TableAccumulator copy() { return new UserDefinedAggregateFunctionAccumulator( - analysis, aggregateFunction, functionArguments, inputDataTypes, ioTDBLocal); + analysis, aggregateFunction, functionArguments, inputDataTypes, ioTDBLocal, true); } @Override public void addInput(Column[] arguments, AggregationMask mask) { + initIfNeeded(); RecordIterator iterator = mask.isSelectAll() ? new RecordIterator( @@ -93,6 +122,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator @Override public void addIntermediate(Column argument) { + initIfNeeded(); checkArgument( argument instanceof BinaryColumn || (argument instanceof RunLengthEncodedColumn @@ -118,6 +148,7 @@ public class UserDefinedAggregateFunctionAccumulator implements TableAccumulator @Override public void evaluateFinal(ColumnBuilder columnBuilder) { + initIfNeeded(); ResultValue resultValue = new ResultValue(columnBuilder); aggregateFunction.outputFinal(state, resultValue, ioTDBLocal); } diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java index 5cffef01105..a530ce4f7de 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java @@ -27,6 +27,7 @@ import org.apache.iotdb.calc.i18n.CalcMessages; import org.apache.iotdb.udf.api.IoTDBLocal; import org.apache.iotdb.udf.api.State; import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments; +import org.apache.iotdb.udf.api.exception.UDFException; import org.apache.iotdb.udf.api.relational.AggregateFunction; import org.apache.iotdb.udf.api.utils.ResultValue; @@ -53,6 +54,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato private final ObjectBigArray<State> stateArray; private final List<Type> inputDataTypes; private final IoTDBLocal ioTDBLocal; + private boolean init = false; public GroupedUserDefinedAggregateAccumulator( AggregateFunction aggregateFunction, @@ -67,6 +69,21 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato this.ioTDBLocal = ioTDBLocal; } + private void initIfNeeded() { + if (init) { + return; + } + init = true; + try { + aggregateFunction.beforeStart(functionArguments, ioTDBLocal); + } catch (UDFException e) { + throw new RuntimeException( + "Error occurs when starting user-defined aggregate function " + + aggregateFunction.getClass().getName(), + e); + } + } + @Override public long getEstimatedSize() { return INSTANCE_SIZE; @@ -88,6 +105,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato @Override public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + initIfNeeded(); RecordIterator iterator = mask.isSelectAll() ? new RecordIterator( @@ -115,6 +133,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato @Override public void addIntermediate(int[] groupIds, Column argument) { + initIfNeeded(); checkArgument( argument instanceof BinaryColumn || (argument instanceof RunLengthEncodedColumn @@ -146,6 +165,7 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato @Override public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + initIfNeeded(); ResultValue resultValue = new ResultValue(columnBuilder); aggregateFunction.outputFinal(getOrCreateState(groupId), resultValue, ioTDBLocal); } diff --git a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java index 07d92d967f8..6955dd56663 100644 --- a/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java +++ b/iotdb-core/calc-commons/src/main/java/org/apache/iotdb/calc/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java @@ -37,11 +37,15 @@ import org.apache.tsfile.read.common.type.Type; import java.util.List; import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; + public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer { private final ScalarFunction scalarFunction; + private final FunctionArguments parameters; private final List<Type> inputTypes; private final IoTDBLocal ioTDBLocal; + private boolean init = false; public UserDefineScalarFunctionTransformer( Type returnType, @@ -51,15 +55,35 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer ColumnTransformerBuilder.Context context) { super(returnType, childrenTransformers); this.scalarFunction = scalarFunction; + this.parameters = parameters; this.ioTDBLocal = createIoTDBLocal(context); this.inputTypes = childrenTransformers.stream().map(ColumnTransformer::getType).collect(Collectors.toList()); + } + + private static IoTDBLocal createIoTDBLocal(ColumnTransformerBuilder.Context context) { + IoTDBLocalFactory factory = context.getIoTDBLocalFactory(); + String fragmentInstanceId = context.getFragmentInstanceId(); + String outerGlobalQueryId = context.getOuterGlobalQueryId(); + long outerLocalQueryId = context.getOuterLocalQueryId(); + checkArgument(factory != null, "IoTDBLocalFactory must not be null for UDF execution"); + checkArgument( + fragmentInstanceId != null, "fragmentInstanceId must not be null for UDF execution"); + checkArgument( + outerGlobalQueryId != null, "outerGlobalQueryId must not be null for UDF execution"); + checkArgument( + outerLocalQueryId >= 0, "outerLocalQueryId must not be negative for UDF execution"); + return factory.create( + context.getSessionInfo(), fragmentInstanceId, outerLocalQueryId, outerGlobalQueryId); + } + + private void initIfNeeded() { + if (init) { + return; + } + init = true; try { - if (ioTDBLocal != null) { - scalarFunction.beforeStart(parameters, ioTDBLocal); - } else { - scalarFunction.beforeStart(parameters); - } + scalarFunction.beforeStart(parameters, ioTDBLocal); } catch (UDFException e) { throw new RuntimeException( "Error occurs when starting user-defined scalar function " @@ -68,31 +92,14 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer } } - private static IoTDBLocal createIoTDBLocal(ColumnTransformerBuilder.Context context) { - IoTDBLocalFactory factory = context.getIoTDBLocalFactory(); - if (factory == null - || context.getFragmentInstanceId() == null - || context.getOuterGlobalQueryId() == null - || context.getOuterLocalQueryId() < 0) { - return null; - } - return factory.create( - context.getSessionInfo(), - context.getFragmentInstanceId(), - context.getOuterLocalQueryId(), - context.getOuterGlobalQueryId()); - } - @Override protected void doTransform( List<Column> childrenColumns, ColumnBuilder builder, int positionCount) { + initIfNeeded(); RecordIterator iterator = new RecordIterator(childrenColumns, inputTypes, positionCount); while (iterator.hasNext()) { try { - Object result = - ioTDBLocal != null - ? scalarFunction.evaluate(iterator.next(), ioTDBLocal) - : scalarFunction.evaluate(iterator.next()); + Object result = scalarFunction.evaluate(iterator.next(), ioTDBLocal); if (result == null) { builder.appendNull(); } else { @@ -110,6 +117,7 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer @Override protected void doTransform( List<Column> childrenColumns, ColumnBuilder builder, int positionCount, boolean[] selection) { + initIfNeeded(); RecordIterator iterator = new RecordIterator(childrenColumns, inputTypes, positionCount); int i = 0; while (iterator.hasNext()) { @@ -119,10 +127,7 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer builder.appendNull(); continue; } - Object result = - ioTDBLocal != null - ? scalarFunction.evaluate(input, ioTDBLocal) - : scalarFunction.evaluate(input); + Object result = scalarFunction.evaluate(input, ioTDBLocal); if (result == null) { builder.appendNull(); } else { @@ -140,12 +145,8 @@ public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer @Override public void close() { super.close(); - if (ioTDBLocal != null) { - ioTDBLocal.close(); - scalarFunction.beforeDestroy(ioTDBLocal); - } else { - scalarFunction.beforeDestroy(); - } + scalarFunction.beforeDestroy(ioTDBLocal); + ioTDBLocal.close(); } @Override
