This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch beyyes/agg_plan_device_cross_region
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 7801327e9f50ef81d4e9cb214be5244fabb1cb37
Author: Beyyes <[email protected]>
AuthorDate: Thu Feb 29 21:53:38 2024 +0800

    fix
---
 .../execution/aggregation/Accumulator.java         |  5 +-
 .../aggregation/MaxMinByBaseAccumulator.java       |  5 ++
 .../execution/aggregation/MaxTimeAccumulator.java  |  5 ++
 .../execution/aggregation/SumAccumulator.java      |  5 ++
 .../process/AggregationMergeSortOperator.java      | 81 +++++++++-------------
 .../db/queryengine/plan/analyze/Analysis.java      | 10 +++
 .../plan/planner/distribution/SourceRewriter.java  | 63 +++++++----------
 .../node/process/AggregationMergeSortNode.java     | 18 ++++-
 8 files changed, 99 insertions(+), 93 deletions(-)

diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Accumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Accumulator.java
index 7d765aa5857..871dd0e17a9 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Accumulator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Accumulator.java
@@ -86,8 +86,5 @@ public interface Accumulator {
 
   TSDataType getFinalType();
 
-  default int getPartialResultSize() {
-    throw new UnsupportedOperationException(
-        "This type of accumulator does not support getPartialResultSize!");
-  }
+  int getPartialResultSize();
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxMinByBaseAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxMinByBaseAccumulator.java
index a4ea49be634..2bd603d01ac 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxMinByBaseAccumulator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxMinByBaseAccumulator.java
@@ -158,6 +158,11 @@ public abstract class MaxMinByBaseAccumulator implements Accumulator {
     return xDataType;
   }
 
+  @Override
+  public int getPartialResultSize() {
+    return 1;
+  }
+
   private void addIntInput(Column[] column, BitMap bitMap) {
     int count = column[0].getPositionCount();
     for (int i = 0; i < count; i++) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxTimeAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxTimeAccumulator.java
index 81634bbddfd..edf63dab360 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxTimeAccumulator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxTimeAccumulator.java
@@ -118,6 +118,11 @@ public class MaxTimeAccumulator implements Accumulator {
     return TSDataType.INT64;
   }
 
+  @Override
+  public int getPartialResultSize() {
+    return 1;
+  }
+
   protected void updateMaxTime(long curTime) {
     initResult = true;
     maxTime = Math.max(maxTime, curTime);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/SumAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/SumAccumulator.java
index e1d99ef7f76..d25f05242f6 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/SumAccumulator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/SumAccumulator.java
@@ -148,6 +148,11 @@ public class SumAccumulator implements Accumulator {
     return TSDataType.DOUBLE;
   }
 
+  @Override
+  public int getPartialResultSize() {
+    return 1;
+  }
+
   private void addIntInput(Column[] column, BitMap bitMap) {
     int count = column[0].getPositionCount();
     for (int i = 0; i < count; i++) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AggregationMergeSortOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AggregationMergeSortOperator.java
index 331784d1451..4b0d959c882 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AggregationMergeSortOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AggregationMergeSortOperator.java
@@ -20,31 +20,26 @@
 package org.apache.iotdb.db.queryengine.execution.operator.process;
 
 import org.apache.iotdb.db.queryengine.execution.aggregation.Accumulator;
-import org.apache.iotdb.db.queryengine.execution.aggregation.Aggregator;
 import org.apache.iotdb.db.queryengine.execution.operator.Operator;
 import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
 import 
org.apache.iotdb.db.queryengine.execution.operator.process.join.merge.TimeComparator;
+import org.apache.iotdb.tsfile.common.conf.TSFileDescriptor;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
-import org.apache.iotdb.tsfile.read.common.TimeRange;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
 import org.apache.iotdb.tsfile.read.common.block.column.Column;
 import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder;
 import org.apache.iotdb.tsfile.read.common.block.column.TimeColumnBuilder;
+import org.apache.iotdb.tsfile.utils.Binary;
 
 import com.google.common.util.concurrent.ListenableFuture;
-import org.apache.iotdb.tsfile.utils.Binary;
-import org.checkerframework.checker.units.qual.C;
 
 import java.util.ArrayList;
 import java.util.Comparator;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
 import static com.google.common.util.concurrent.Futures.successfulAsList;
-import static 
org.apache.iotdb.db.queryengine.execution.aggregation.AccumulatorFactory.createAccumulator;
 
 public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
 
@@ -57,21 +52,13 @@ public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
 
   private boolean finished;
 
-  private Map<String, List<Aggregator>> aggMap;
-
   private final TimeComparator timeComparator;
 
   private final Comparator<Binary> deviceComparator;
 
-  private boolean currentFinished;
-
-  private Binary currentDevice;
-
   private long currentTime;
 
-  private int[] readIndex;
-
-  List<Integer> newAggregationIdx;
+  private final int[] readIndex;
 
   public AggregationMergeSortOperator(
       OperatorContext operatorContext,
@@ -87,7 +74,7 @@ public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
     this.accumulators = accumulators;
     this.timeComparator = timeComparator;
     this.deviceComparator = deviceComparator;
-    readIndex = new int[inputTsBlocks.length];
+    this.readIndex = new int[inputTsBlocks.length];
   }
 
   @Override
@@ -95,6 +82,7 @@ public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
     long startTime = System.nanoTime();
     long maxRuntime = 
operatorContext.getMaxRunTime().roundTo(TimeUnit.NANOSECONDS);
 
+    // init all element in inputTsBlocks
     if (!prepareInput()) {
       return null;
     }
@@ -104,73 +92,68 @@ public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
     ColumnBuilder[] valueColumnBuilders = 
tsBlockBuilder.getValueColumnBuilders();
 
     while (true) {
-      currentDevice = null;
+      Binary currentDevice = null;
+      boolean hashChildFinished = false;
 
       for (int idx = 0; idx < inputTsBlocks.length; idx++) {
         TsBlock tsBlock = inputTsBlocks[idx];
-        if (!noMoreTsBlocks[idx] && tsBlock == null) {
-          return null;
+        if (noMoreTsBlocks[idx]) {
+          continue;
         }
 
-        if (readIndex[idx] >= tsBlock.getPositionCount()) {
+        if (tsBlock == null || readIndex[idx] >= tsBlock.getPositionCount()) {
+          hashChildFinished = true;
           inputTsBlocks[idx] = null;
+          readIndex[idx] = 0;
+          currentDevice = null;
+          break;
         }
 
+        // if group by time, columnIndex may be greater than 0
         Binary device = tsBlock.getColumn(0).getBinary(readIndex[idx]);
-
-        if (currentDevice == null) {
+        if (currentDevice == null || deviceComparator.compare(device, 
currentDevice) < 0) {
           currentDevice = device;
-        } else {
-          if (deviceComparator.compare(device, currentDevice) < 0) {
-            currentDevice = device;
-          }
         }
       }
 
-      if (currentDevice == null) {
+      if (hashChildFinished) {
         break;
       }
 
       for (int idx = 0; idx < inputTsBlocks.length; idx++) {
         TsBlock tsBlock = inputTsBlocks[idx];
-        if (tsBlock == null) {
+        if (noMoreTsBlocks[idx]) {
           continue;
         }
 
-        if (readIndex[idx] >= tsBlock.getPositionCount()) {
-          inputTsBlocks[idx] = null;
-        }
-
         Binary device = tsBlock.getColumn(0).getBinary(readIndex[idx]);
         if (device.equals(currentDevice)) {
           currentTime = tsBlock.getTimeColumn().getLong(readIndex[idx]);
           int cnt = 1;
-          for (int i = 0; i < accumulators.size(); i++) {
-            Accumulator accumulator = accumulators.get(i);
+          for (Accumulator accumulator : accumulators) {
             if (accumulator.getPartialResultSize() == 2) {
-              Column[] columns = new Column[2];
-              columns[0] = tsBlock.getColumn(cnt++);
-              columns[1] = tsBlock.getColumn(cnt++);
-              accumulator.addIntermediate(columns);
+              // TODO only has group by, use subColumn
+              accumulator.addIntermediate(
+                  new Column[] {
+                    tsBlock.getColumn(cnt++).subColumn(readIndex[idx]),
+                    tsBlock.getColumn(cnt++).subColumn(readIndex[idx])
+                  });
             } else {
-              Column[] columns = new Column[1];
-              columns[0] = tsBlock.getColumn(cnt++);
-              accumulator.addIntermediate(columns);
+              accumulator.addIntermediate(
+                  new Column[] 
{tsBlock.getColumn(cnt++).subColumn(readIndex[idx])});
             }
           }
-          readIndex[idx] ++;
-
-          accumulators.forEach(Accumulator::reset);
+          readIndex[idx]++;
         }
       }
 
       timeBuilder.writeLong(currentTime);
+      valueColumnBuilders[0].writeBinary(currentDevice);
       for (int i = 1; i < dataTypes.size(); i++) {
-        accumulators.get(i-1).outputFinal(valueColumnBuilders[i]);
+        accumulators.get(i - 1).outputFinal(valueColumnBuilders[i]);
       }
       tsBlockBuilder.declarePosition();
-
-      currentDevice = null;
+      accumulators.forEach(Accumulator::reset);
 
       if (System.nanoTime() - startTime > maxRuntime || 
tsBlockBuilder.isFull()) {
         break;
@@ -267,7 +250,7 @@ public class AggregationMergeSortOperator extends AbstractConsumeAllOperator {
 
   @Override
   public long calculateMaxReturnSize() {
-    return 0;
+    return 
TSFileDescriptor.getInstance().getConfig().getMaxTsBlockSizeInBytes();
   }
 
   @Override
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
index 3014cdec431..56dd5561b3c 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
@@ -176,6 +176,8 @@ public class Analysis {
 
   private boolean existDeviceCrossRegion;
 
+  private boolean useAggMergeSort;
+
   
/////////////////////////////////////////////////////////////////////////////////////////////////
   // Query Common Analysis (above DeviceView)
   
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -666,6 +668,14 @@ public class Analysis {
     this.existDeviceCrossRegion = true;
   }
 
+  public boolean isUseAggMergeSort() {
+    return this.useAggMergeSort;
+  }
+
+  public void setUseAggMergeSort() {
+    this.useAggMergeSort = true;
+  }
+
   public DeviceViewIntoPathDescriptor getDeviceViewIntoPathDescriptor() {
     return deviceViewIntoPathDescriptor;
   }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java
index 022345dd250..e88b1b431ed 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java
@@ -93,6 +93,7 @@ import static org.apache.iotdb.commons.partition.DataPartition.NOT_ASSIGNED;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.ExpressionTypeAnalyzer.analyzeExpression;
 import static 
org.apache.iotdb.db.queryengine.plan.planner.LogicalPlanBuilder.updateTypeProviderByPartialAggregation;
 import static org.apache.iotdb.db.utils.constant.SqlConstant.AVG;
+import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT_IF;
 import static org.apache.iotdb.db.utils.constant.SqlConstant.FIRST_VALUE;
 import static org.apache.iotdb.db.utils.constant.SqlConstant.TIME_DURATION;
 
@@ -197,11 +198,18 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
                   analysis.getPartitionInfo(outputDevice, 
context.getPartitionTimeFilter()));
       if (regionReplicaSets.size() > 1) {
         // specialProcess and existDeviceCrossRegion, use the old aggregation 
logic
-        analysis.setExistDeviceCrossRegion();
-        // TODO group by session, variation, count, count_if no not use old 
logic
-        // if (analysis.isDeviceViewSpecialProcess()) {
-        //  return processSpecialDeviceView(node, context);
-        // }
+        if (!analysis.isExistDeviceCrossRegion()) {
+          analysis.setExistDeviceCrossRegion();
+          // `group by session, variation, count; count_if` can not use 
AggMergeSort, it uses old
+          // aggregation logic
+          if (!analysis.hasGroupByParameter()
+              && 
!hasCountIfAggregation(analysis.getDeviceViewOutputExpressions())) {
+            analysis.setUseAggMergeSort();
+          }
+        }
+        if (analysis.isDeviceViewSpecialProcess() && 
!analysis.isUseAggMergeSort()) {
+          return processSpecialDeviceView(node, context);
+        }
       }
       deviceViewSplits.add(new DeviceViewSplit(outputDevice, child, 
regionReplicaSets));
       relatedDataRegions.addAll(regionReplicaSets);
@@ -227,10 +235,8 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
     }
 
     if (analysis.isExistDeviceCrossRegion() && 
analysis.isDeviceViewSpecialProcess()) {
-      // return processSpecialDeviceView(node, context);
-
-      // TODO 1. generate old and new measurement idx relationship
-      // TODO 2. generate new outputColumns for
+      // 1. generate old and new measurement idx relationship
+      // 2. generate new outputColumns
       // each subDeviceView
       Map<Integer, List<Integer>> newMeasurementIdxMap = new HashMap<>();
       List<String> newPartialOutputColumns = new ArrayList<>();
@@ -241,7 +247,6 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
       int i = 0, idxSum = 0;
       for (Expression expression : selectExpressions) {
         if (i == 0) {
-          // device
           newPartialOutputColumns.add(expression.getOutputSymbol());
           i++;
           idxSum++;
@@ -268,7 +273,6 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
         }
 
         newAggregationIdx[i] = actualPartialAggregationNames.size();
-        // TODO need update typeProvider?
         if (actualPartialAggregationNames.size() > 1) {
           newMeasurementIdxMap.put(i, Arrays.asList(idxSum++, idxSum++));
         } else {
@@ -289,7 +293,7 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
       for (PlanNode planNode : deviceViewNodeList) {
         DeviceViewNode deviceViewNode = (DeviceViewNode) planNode;
         deviceViewNode.setOutputColumnNames(newPartialOutputColumns);
-        transferAggregatorsRecursively2(planNode, context);
+        transferAggregatorsRecursively(planNode, context);
       }
 
       AggregationMergeSortNode mergeSortNode =
@@ -398,7 +402,7 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
     return outputAggregationNames;
   }
 
-  private void transferAggregatorsRecursively2(PlanNode planNode, 
DistributionPlanContext context) {
+  private void transferAggregatorsRecursively(PlanNode planNode, 
DistributionPlanContext context) {
     if (planNode instanceof SeriesAggregationSourceNode) {
       SeriesAggregationSourceNode scanSourceNode = 
(SeriesAggregationSourceNode) planNode;
       for (AggregationDescriptor descriptor : 
scanSourceNode.getAggregationDescriptorList()) {
@@ -408,38 +412,19 @@ public class SourceRewriter extends BaseSourceRewriter<DistributionPlanContext>
     }
 
     for (PlanNode child : planNode.getChildren()) {
-      transferAggregatorsRecursively2(child, context);
+      transferAggregatorsRecursively(child, context);
     }
   }
 
-  private void transferAggregatorsRecursively(PlanNode planNode) {
-    for (PlanNode child : planNode.getChildren()) {
-      transferAggregatorsRecursively(child);
-
-      if (child instanceof SeriesAggregationSourceNode) {
-        SeriesAggregationSourceNode scanSourceNode = 
(SeriesAggregationSourceNode) child;
-        List<AggregationDescriptor> newDescriptorList = new ArrayList<>();
-        for (AggregationDescriptor descriptor : 
scanSourceNode.getAggregationDescriptorList()) {
-          List<String> aggregationNames = 
descriptor.getActualAggregationNames(true);
-          for (String aggregationName : aggregationNames) {
-            newDescriptorList.add(
-                new AggregationDescriptor(
-                    aggregationName,
-                    AggregationStep.PARTIAL,
-                    descriptor.getInputExpressions(),
-                    descriptor.getInputAttributes()));
-          }
+  private boolean hasCountIfAggregation(Set<Expression> selectExpressions) {
+    for (Expression e : selectExpressions) {
+      if (e instanceof FunctionExpression) {
+        if (COUNT_IF.equalsIgnoreCase(((FunctionExpression) 
e).getFunctionName())) {
+          return true;
         }
-        scanSourceNode.setAggregationDescriptorList(newDescriptorList);
       }
     }
-  }
-
-  @Override
-  public List<PlanNode> visitAggregationMergeSort(
-      AggregationMergeSortNode node, DistributionPlanContext context) {
-    // TODO remove this method?
-    return null;
+    return false;
   }
 
   private List<PlanNode> processSpecialDeviceView(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AggregationMergeSortNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AggregationMergeSortNode.java
index dd7dc6cd404..39968fd3d7f 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AggregationMergeSortNode.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AggregationMergeSortNode.java
@@ -31,6 +31,7 @@ import java.io.DataOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -119,6 +120,10 @@ public class AggregationMergeSortNode extends MultiChildProcessNode {
     for (String column : outputColumns) {
       ReadWriteIOUtils.write(column, byteBuffer);
     }
+    ReadWriteIOUtils.write(selectExpressions.size(), byteBuffer);
+    for (Expression expression : selectExpressions) {
+      Expression.serialize(expression, byteBuffer);
+    }
   }
 
   @Override
@@ -129,6 +134,10 @@ public class AggregationMergeSortNode extends MultiChildProcessNode {
     for (String column : outputColumns) {
       ReadWriteIOUtils.write(column, stream);
     }
+    ReadWriteIOUtils.write(selectExpressions.size(), stream);
+    for (Expression expression : selectExpressions) {
+      Expression.serialize(expression, stream);
+    }
   }
 
   public static AggregationMergeSortNode deserialize(ByteBuffer byteBuffer) {
@@ -139,8 +148,15 @@ public class AggregationMergeSortNode extends MultiChildProcessNode {
       outputColumns.add(ReadWriteIOUtils.readString(byteBuffer));
       columnSize--;
     }
+    Set<Expression> expressions = new LinkedHashSet<>();
+    int expressionSize = ReadWriteIOUtils.readInt(byteBuffer);
+    while (expressionSize > 0) {
+      expressions.add(Expression.deserialize(byteBuffer));
+      expressionSize--;
+    }
     PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
-    return new AggregationMergeSortNode(planNodeId, orderByParameter, 
outputColumns, null, null);
+    return new AggregationMergeSortNode(
+        planNodeId, orderByParameter, outputColumns, expressions, null);
   }
 
   @Override

Reply via email to