This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d6a38c0a1 [multi-stage] Support SetOperations(UNION/INTERSECT/MINUS) compilation in query planner (#10535)
2d6a38c0a1 is described below

commit 2d6a38c0a13aecb055ea1d81efdc4cc8dfba197a
Author: Xiang Fu <[email protected]>
AuthorDate: Wed Apr 19 23:04:37 2023 -0700

    [multi-stage] Support SetOperations(UNION/INTERSECT/MINUS) compilation in query planner (#10535)
---
 .../apache/pinot/sql/parsers/CalciteSqlParser.java |  9 +-
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  | 30 +++++++
 .../query/planner/ExplainPlanStageVisitor.java     | 10 +++
 .../query/planner/logical/RelToStageConverter.java |  9 ++
 .../planner/logical/ShuffleRewriteVisitor.java     |  8 ++
 .../planner/physical/DispatchablePlanVisitor.java  |  8 ++
 .../colocated/GreedyShuffleRewriteContext.java     | 19 +++++
 .../GreedyShuffleRewritePreComputeVisitor.java     |  8 ++
 .../colocated/GreedyShuffleRewriteVisitor.java     | 73 +++++++++--------
 .../stage/DefaultPostOrderTraversalVisitor.java    |  6 ++
 .../pinot/query/planner/stage/SetOpNode.java       | 84 +++++++++++++++++++
 .../query/planner/stage/StageNodeSerDeUtils.java   |  2 +
 .../query/planner/stage/StageNodeVisitor.java      |  2 +
 .../pinot/query/QueryEnvironmentTestBase.java      |  5 ++
 .../src/test/resources/queries/SetOpPlans.json     | 95 ++++++++++++++++++++++
 .../query/runtime/plan/PhysicalPlanVisitor.java    | 11 ++-
 .../runtime/plan/ServerRequestPlanVisitor.java     |  6 ++
 17 files changed, 347 insertions(+), 38 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index c708c3930c..3a19108041 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -158,6 +158,9 @@ public class CalciteSqlParser {
         } else {
           tableNames.addAll(extractTableNamesFromNode(right));
         }
+      } else if ((fromNode instanceof SqlBasicCall)
+          && (((SqlBasicCall) fromNode).getOperator() instanceof 
SqlAsOperator)) {
+        tableNames.addAll(extractTableNamesFromNode(((SqlBasicCall) 
fromNode).getOperandList().get(0)));
       } else {
         tableNames.addAll(((SqlIdentifier) fromNode).names);
         tableNames.addAll(extractTableNamesFromNode(((SqlSelect) 
sqlNode).getWhere()));
@@ -178,12 +181,16 @@ public class CalciteSqlParser {
     } else if (sqlNode instanceof SqlWith) {
       List<SqlNode> withList = ((SqlWith) sqlNode).withList;
       Set<String> aliases = new HashSet<>();
-      for (SqlNode withItem: withList) {
+      for (SqlNode withItem : withList) {
         aliases.addAll(((SqlWithItem) withItem).name.names);
         tableNames.addAll(extractTableNamesFromNode(((SqlWithItem) 
withItem).query));
       }
       tableNames.addAll(extractTableNamesFromNode(((SqlWith) sqlNode).body));
       tableNames.removeAll(aliases);
+    } else if (sqlNode instanceof SqlSetOption) {
+      for (SqlNode node : ((SqlSetOption) sqlNode).getOperandList()) {
+        tableNames.addAll(extractTableNamesFromNode(node));
+      }
     } else if (sqlNode instanceof SqlExplain) {
       SqlExplain explain = (SqlExplain) sqlNode;
       tableNames.addAll(extractTableNamesFromNode(explain.getExplicandum()));
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index cc11d9cfb4..f45f611d50 100644
--- 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -3072,6 +3072,36 @@ public class CalciteSqlCompilerTest {
     Assert.assertEquals(tableNames.get(0), "tbl1");
     Assert.assertEquals(tableNames.get(1), "tbl2");
 
+    // query with UNION clause
+    query = "SELECT * FROM tbl1 UNION ALL SELECT * FROM tbl2 UNION ALL SELECT 
* FROM tbl3";
+    sqlNodeAndOptions = RequestUtils.parseQuery(query);
+    tableNames = 
CalciteSqlParser.extractTableNamesFromNode(sqlNodeAndOptions.getSqlNode());
+    Assert.assertEquals(tableNames.size(), 3);
+    Collections.sort(tableNames);
+    Assert.assertEquals(tableNames.get(0), "tbl1");
+    Assert.assertEquals(tableNames.get(1), "tbl2");
+    Assert.assertEquals(tableNames.get(2), "tbl3");
+
+    // query with UNION clause and table alias
+    query = "SELECT * FROM (SELECT * FROM tbl1) AS t1 UNION SELECT * FROM ( 
SELECT * FROM tbl2) AS t2";
+    sqlNodeAndOptions = RequestUtils.parseQuery(query);
+    tableNames = 
CalciteSqlParser.extractTableNamesFromNode(sqlNodeAndOptions.getSqlNode());
+    Assert.assertEquals(tableNames.size(), 2);
+    Collections.sort(tableNames);
+    Assert.assertEquals(tableNames.get(0), "tbl1");
+    Assert.assertEquals(tableNames.get(1), "tbl2");
+
+    // query with UNION clause and table alias using WITH clause
+    query = "WITH tmp1 AS (SELECT * FROM tbl1), \n"
+        + "tmp2 AS (SELECT * FROM tbl2) \n"
+        + "SELECT * FROM tmp1 UNION ALL SELECT * FROM tmp2";
+    sqlNodeAndOptions = RequestUtils.parseQuery(query);
+    tableNames = 
CalciteSqlParser.extractTableNamesFromNode(sqlNodeAndOptions.getSqlNode());
+    Assert.assertEquals(tableNames.size(), 2);
+    Collections.sort(tableNames);
+    Assert.assertEquals(tableNames.get(0), "tbl1");
+    Assert.assertEquals(tableNames.get(1), "tbl2");
+
     // query with aliases, JOIN, IN/NOT-IN, group-by
     query = "with tmp as (select col1, count(*) from tbl1 where condition1 = 
filter1 group by col1), "
         + "tmp2 as (select A.col1, B.col2 from tbl2 as A JOIN tbl3 AS B on 
A.key = B.key) "
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
index 1b4b46795d..8f02f7018b 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
@@ -29,6 +29,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -118,6 +119,15 @@ public class ExplainPlanStageVisitor implements 
StageNodeVisitor<StringBuilder,
     return visitSimpleNode(node, context);
   }
 
+  @Override
+  public StringBuilder visitSetOp(SetOpNode setOpNode, Context context) {
+    appendInfo(setOpNode, context).append('\n');
+    for (StageNode input : setOpNode.getInputs()) {
+      input.visit(this, context.next(false, context._host));
+    }
+    return context._builder;
+  }
+
   @Override
   public StringBuilder visitFilter(FilterNode node, Context context) {
     return visitSimpleNode(node, context);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
index 4b8c07ed1a..576f209e8c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
@@ -23,6 +23,7 @@ import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.JoinInfo;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.SetOp;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalJoin;
@@ -42,6 +43,7 @@ import org.apache.pinot.query.planner.stage.AggregateNode;
 import org.apache.pinot.query.planner.stage.FilterNode;
 import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.TableScanNode;
@@ -84,11 +86,18 @@ public final class RelToStageConverter {
       return convertLogicalValues((LogicalValues) node, currentStageId);
     } else if (node instanceof LogicalWindow) {
       return convertLogicalWindow((LogicalWindow) node, currentStageId);
+    } else if (node instanceof SetOp) {
+      return convertLogicalSetOp((SetOp) node, currentStageId);
     } else {
       throw new UnsupportedOperationException("Unsupported logical plan node: 
" + node);
     }
   }
 
+  private static StageNode convertLogicalSetOp(SetOp node, int currentStageId) 
{
+    return new SetOpNode(SetOpNode.SetOpType.fromObject(node), currentStageId, 
toDataSchema(node.getRowType()),
+        node.all);
+  }
+
   private static StageNode convertLogicalValues(LogicalValues node, int 
currentStageId) {
     return new ValueNode(currentStageId, toDataSchema(node.getRowType()), 
node.tuples);
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
index 12fa9d6cac..6f5376c1f5 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
@@ -32,6 +32,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -79,6 +80,13 @@ public class ShuffleRewriteVisitor implements 
StageNodeVisitor<Set<Integer>, Voi
     throw new UnsupportedOperationException("Window not yet supported!");
   }
 
+  @Override
+  public Set<Integer> visitSetOp(SetOpNode setOpNode, Void context) {
+    Set<Integer> newPartitionKeys = new HashSet<>();
+    setOpNode.getInputs().forEach(input -> 
newPartitionKeys.addAll(input.visit(this, context)));
+    return newPartitionKeys;
+  }
+
   @Override
   public Set<Integer> visitFilter(FilterNode node, Void context) {
     // filters don't change partition keys
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
index d7b1343141..ea4eaf7ac7 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
@@ -25,6 +25,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -82,6 +83,13 @@ public class DispatchablePlanVisitor implements 
StageNodeVisitor<Void, Dispatcha
     return null;
   }
 
+  @Override
+  public Void visitSetOp(SetOpNode setOpNode, DispatchablePlanContext context) 
{
+    setOpNode.getInputs().forEach(input -> input.visit(this, context));
+    getStageMetadata(setOpNode, context);
+    return null;
+  }
+
   @Override
   public Void visitFilter(FilterNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteContext.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteContext.java
index befcf958cb..23a752a991 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteContext.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteContext.java
@@ -35,6 +35,8 @@ class GreedyShuffleRewriteContext {
   private final Map<Integer, StageNode> _rootStageNode;
   private final Map<Integer, List<StageNode>> _leafNodes;
   private final Set<Integer> _joinStages;
+  private final Set<Integer> _setOpStages;
+
   /**
    * A map to track the partition keys for the input to the MailboxSendNode of 
a given stageId. This is needed
    * because the {@link GreedyShuffleRewriteVisitor} doesn't determine the 
distribution of the sender if the receiver
@@ -46,6 +48,7 @@ class GreedyShuffleRewriteContext {
     _rootStageNode = new HashMap<>();
     _leafNodes = new HashMap<>();
     _joinStages = new HashSet<>();
+    _setOpStages = new HashSet<>();
     _senderInputColocationKeys = new HashMap<>();
   }
 
@@ -92,6 +95,22 @@ class GreedyShuffleRewriteContext {
     return _joinStages.contains(stageId);
   }
 
+
+  /**
+   * {@link GreedyShuffleRewriteContext} allows checking whether a given 
stageId has a SetOpNode or not. During
+   * pre-computation, this method may be used to mark that the given stageId 
has a SetOpNode.
+   */
+  void markSetOpStage(Integer stageId) {
+    _setOpStages.add(stageId);
+  }
+
+  /**
+   * Returns true if the given stageId has a SetOpNode.
+   */
+  boolean isSetOpStage(Integer stageId) {
+    return _setOpStages.contains(stageId);
+  }
+
   /**
    * This returns the {@link Set<ColocationKey>} for the input to the {@link 
MailboxSendNode} of the given stageId.
    */
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewritePreComputeVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewritePreComputeVisitor.java
index d8e67bb598..d42fcfa4ce 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewritePreComputeVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewritePreComputeVisitor.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.planner.physical.colocated;
 import org.apache.pinot.query.planner.stage.DefaultPostOrderTraversalVisitor;
 import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.TableScanNode;
 
@@ -64,4 +65,11 @@ class GreedyShuffleRewritePreComputeVisitor
     context.addLeafNode(stageNode.getStageId(), stageNode);
     return 0;
   }
+
+  @Override
+  public Integer visitSetOp(SetOpNode setOpNode, GreedyShuffleRewriteContext 
context) {
+    super.visitSetOp(setOpNode, context);
+    context.markSetOpStage(setOpNode.getStageId());
+    return 0;
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
index 2ba7dedbbf..b5d77c1193 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.physical.colocated;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -39,6 +40,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -64,8 +66,7 @@ import org.slf4j.LoggerFactory;
  *
  * Also see: {@link ColocationKey} for its definition.
  */
-public class GreedyShuffleRewriteVisitor
-    implements StageNodeVisitor<Set<ColocationKey>, 
GreedyShuffleRewriteContext> {
+public class GreedyShuffleRewriteVisitor implements 
StageNodeVisitor<Set<ColocationKey>, GreedyShuffleRewriteContext> {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(GreedyShuffleRewriteVisitor.class);
 
   private final TableCache _tableCache;
@@ -114,8 +115,8 @@ public class GreedyShuffleRewriteVisitor
 
   @Override
   public Set<ColocationKey> visitJoin(JoinNode node, 
GreedyShuffleRewriteContext context) {
-    List<MailboxReceiveNode> innerLeafNodes = 
context.getLeafNodes(node.getStageId()).stream()
-        .map(x -> (MailboxReceiveNode) x).collect(Collectors.toList());
+    List<MailboxReceiveNode> innerLeafNodes =
+        context.getLeafNodes(node.getStageId()).stream().map(x -> 
(MailboxReceiveNode) x).collect(Collectors.toList());
     Preconditions.checkState(innerLeafNodes.size() == 2);
 
     // Multiple checks need to be made to ensure that shuffle can be skipped 
for a join.
@@ -123,8 +124,9 @@ public class GreedyShuffleRewriteVisitor
     boolean canColocate = canJoinBeColocated(node);
     // Step-2: Only if the servers assigned to both left and right nodes are 
equal and the servers assigned to the join
     //         stage are a superset of those servers, can we skip shuffles.
-    canColocate = canColocate && 
canServerAssignmentAllowShuffleSkip(node.getStageId(),
-        innerLeafNodes.get(0).getSenderStageId(), 
innerLeafNodes.get(1).getSenderStageId());
+    canColocate =
+        canColocate && canServerAssignmentAllowShuffleSkip(node.getStageId(), 
innerLeafNodes.get(0).getSenderStageId(),
+            innerLeafNodes.get(1).getSenderStageId());
     // Step-3: For both left/right MailboxReceiveNode/MailboxSendNode pairs, 
check whether the key partitioning can
     //         allow shuffle skip.
     canColocate = canColocate && 
partitionKeyConditionForJoin(innerLeafNodes.get(0),
@@ -136,8 +138,8 @@ public class GreedyShuffleRewriteVisitor
     canColocate = canColocate && checkPartitionScheme(innerLeafNodes.get(0), 
innerLeafNodes.get(1), context);
     if (canColocate) {
       // If shuffle can be skipped, reassign servers.
-      _stageMetadataMap.get(node.getStageId()).setServerInstances(
-          
_stageMetadataMap.get(innerLeafNodes.get(0).getSenderStageId()).getServerInstances());
+      _stageMetadataMap.get(node.getStageId())
+          
.setServerInstances(_stageMetadataMap.get(innerLeafNodes.get(0).getSenderStageId()).getServerInstances());
       _canSkipShuffleForJoin = true;
     }
 
@@ -148,8 +150,8 @@ public class GreedyShuffleRewriteVisitor
     Set<ColocationKey> colocationKeys = new HashSet<>(leftPKs);
 
     for (ColocationKey rightColocationKey : rightPks) {
-      ColocationKey newColocationKey = new 
ColocationKey(rightColocationKey.getNumPartitions(),
-          rightColocationKey.getHashAlgorithm());
+      ColocationKey newColocationKey =
+          new ColocationKey(rightColocationKey.getNumPartitions(), 
rightColocationKey.getHashAlgorithm());
       for (Integer index : rightColocationKey.getIndices()) {
         newColocationKey.addIndex(leftDataSchemaSize + index);
       }
@@ -167,18 +169,17 @@ public class GreedyShuffleRewriteVisitor
     if (!context.isJoinStage(node.getStageId())) {
       if (selector == null) {
         return new HashSet<>();
-      } else if (colocationKeyCondition(oldColocationKeys, selector)
-          && areServersSuperset(node.getStageId(), node.getSenderStageId())) {
+      } else if (colocationKeyCondition(oldColocationKeys, selector) && 
areServersSuperset(node.getStageId(),
+          node.getSenderStageId())) {
         node.setExchangeType(RelDistribution.Type.SINGLETON);
-        _stageMetadataMap.get(node.getStageId()).setServerInstances(
-            
_stageMetadataMap.get(node.getSenderStageId()).getServerInstances());
+        _stageMetadataMap.get(node.getStageId())
+            
.setServerInstances(_stageMetadataMap.get(node.getSenderStageId()).getServerInstances());
         return oldColocationKeys;
       }
       // This means we can't skip shuffle and there's a partitioning enforced 
by receiver.
       int numPartitions = 
_stageMetadataMap.get(node.getStageId()).getServerInstances().size();
-      List<ColocationKey> colocationKeys =
-          ((FieldSelectionKeySelector) selector).getColumnIndices().stream()
-              .map(x -> new ColocationKey(x, numPartitions, 
selector.hashAlgorithm())).collect(Collectors.toList());
+      List<ColocationKey> colocationKeys = ((FieldSelectionKeySelector) 
selector).getColumnIndices().stream()
+          .map(x -> new ColocationKey(x, numPartitions, 
selector.hashAlgorithm())).collect(Collectors.toList());
       return new HashSet<>(colocationKeys);
     }
     // If the current stage is a join-stage then we already know whether 
shuffle can be skipped.
@@ -193,9 +194,8 @@ public class GreedyShuffleRewriteVisitor
     }
     // This means we can't skip shuffle and there's a partitioning enforced by 
receiver.
     int numPartitions = 
_stageMetadataMap.get(node.getStageId()).getServerInstances().size();
-    List<ColocationKey> colocationKeys =
-        ((FieldSelectionKeySelector) selector).getColumnIndices().stream()
-            .map(x -> new ColocationKey(x, numPartitions, 
selector.hashAlgorithm())).collect(Collectors.toList());
+    List<ColocationKey> colocationKeys = ((FieldSelectionKeySelector) 
selector).getColumnIndices().stream()
+        .map(x -> new ColocationKey(x, numPartitions, 
selector.hashAlgorithm())).collect(Collectors.toList());
     return new HashSet<>(colocationKeys);
   }
 
@@ -251,10 +251,14 @@ public class GreedyShuffleRewriteVisitor
     return node.getInputs().get(0).visit(this, context);
   }
 
+  @Override
+  public Set<ColocationKey> visitSetOp(SetOpNode setOpNode, 
GreedyShuffleRewriteContext context) {
+    return ImmutableSet.of();
+  }
+
   @Override
   public Set<ColocationKey> visitTableScan(TableScanNode node, 
GreedyShuffleRewriteContext context) {
-    TableConfig tableConfig =
-        _tableCache.getTableConfig(node.getTableName());
+    TableConfig tableConfig = _tableCache.getTableConfig(node.getTableName());
     if (tableConfig == null) {
       LOGGER.warn("Couldn't find tableConfig for {}", node.getTableName());
       return new HashSet<>();
@@ -294,8 +298,8 @@ public class GreedyShuffleRewriteVisitor
    * Checks if servers assigned to the receiver stage are a super-set of the 
sender stage.
    */
   private boolean areServersSuperset(int receiverStageId, int senderStageId) {
-    return 
_stageMetadataMap.get(receiverStageId).getServerInstances().containsAll(
-        _stageMetadataMap.get(senderStageId).getServerInstances());
+    return _stageMetadataMap.get(receiverStageId).getServerInstances()
+        
.containsAll(_stageMetadataMap.get(senderStageId).getServerInstances());
   }
 
   /*
@@ -308,8 +312,8 @@ public class GreedyShuffleRewriteVisitor
     List<VirtualServer> rightServerInstances = 
_stageMetadataMap.get(rightStageId).getServerInstances();
     List<VirtualServer> currentServerInstances = 
_stageMetadataMap.get(currentStageId).getServerInstances();
     return leftServerInstances.containsAll(rightServerInstances)
-        && leftServerInstances.size() == rightServerInstances.size()
-        && currentServerInstances.containsAll(leftServerInstances);
+        && leftServerInstances.size() == rightServerInstances.size() && 
currentServerInstances.containsAll(
+        leftServerInstances);
   }
 
   /**
@@ -323,9 +327,9 @@ public class GreedyShuffleRewriteVisitor
     Set<ColocationKey> colocationKeys = new HashSet<>();
     for (ColocationKey colocationKey : oldColocationKeys) {
       boolean shouldDrop = false;
-      ColocationKey newColocationKey
-          = new ColocationKey(colocationKey.getNumPartitions(), 
colocationKey.getHashAlgorithm());
-      for (Integer index: colocationKey.getIndices()) {
+      ColocationKey newColocationKey =
+          new ColocationKey(colocationKey.getNumPartitions(), 
colocationKey.getHashAlgorithm());
+      for (Integer index : colocationKey.getIndices()) {
         if (!oldToNewIndex.containsKey(index)) {
           shouldDrop = true;
           break;
@@ -355,8 +359,7 @@ public class GreedyShuffleRewriteVisitor
   }
 
   private static boolean partitionKeyConditionForJoin(MailboxReceiveNode 
mailboxReceiveNode,
-      MailboxSendNode mailboxSendNode,
-      GreedyShuffleRewriteContext context) {
+      MailboxSendNode mailboxSendNode, GreedyShuffleRewriteContext context) {
     // First check ColocationKeyCondition for the sender <--> 
sender.getInputs().get(0) pair
     Set<ColocationKey> oldColocationKeys = 
context.getColocationKeys(mailboxSendNode.getStageId());
     KeySelector<Object[], Object[]> selector = 
mailboxSendNode.getPartitionKeySelector();
@@ -387,10 +390,10 @@ public class GreedyShuffleRewriteVisitor
       GreedyShuffleRewriteContext context) {
     int leftSender = leftReceiveNode.getSenderStageId();
     int rightSender = rightReceiveNode.getSenderStageId();
-    ColocationKey leftPKey = 
getEquivalentSenderKey(context.getColocationKeys(leftSender),
-        leftReceiveNode.getPartitionKeySelector());
-    ColocationKey rightPKey = 
getEquivalentSenderKey(context.getColocationKeys(rightSender),
-        rightReceiveNode.getPartitionKeySelector());
+    ColocationKey leftPKey =
+        getEquivalentSenderKey(context.getColocationKeys(leftSender), 
leftReceiveNode.getPartitionKeySelector());
+    ColocationKey rightPKey =
+        getEquivalentSenderKey(context.getColocationKeys(rightSender), 
rightReceiveNode.getPartitionKeySelector());
     if (leftPKey.getNumPartitions() != rightPKey.getNumPartitions()) {
       return false;
     }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/DefaultPostOrderTraversalVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/DefaultPostOrderTraversalVisitor.java
index d41b0fe0bc..5e8f604bf0 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/DefaultPostOrderTraversalVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/DefaultPostOrderTraversalVisitor.java
@@ -84,4 +84,10 @@ public abstract class DefaultPostOrderTraversalVisitor<T, C> 
implements StageNod
     node.getInputs().get(0).visit(this, context);
     return process(node, context);
   }
+
+  @Override
+  public T visitSetOp(SetOpNode node, C context) {
+    node.getInputs().forEach(input -> input.visit(this, context));
+    return process(node, context);
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/SetOpNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/SetOpNode.java
new file mode 100644
index 0000000000..aa7003449f
--- /dev/null
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/SetOpNode.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.stage;
+
+import org.apache.calcite.rel.core.SetOp;
+import org.apache.calcite.rel.logical.LogicalIntersect;
+import org.apache.calcite.rel.logical.LogicalMinus;
+import org.apache.calcite.rel.logical.LogicalUnion;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
+
+
+/**
+ * Set operation node is used to represent UNION, INTERSECT, EXCEPT.
+ */
+public class SetOpNode extends AbstractStageNode {
+
+  @ProtoProperties
+  private SetOpType _setOpType;
+
+  @ProtoProperties
+  private boolean _all;
+
+  public SetOpNode(int stageId) {
+    super(stageId);
+  }
+
+  public SetOpNode(SetOpType setOpType, int stageId, DataSchema dataSchema, 
boolean all) {
+    super(stageId, dataSchema);
+    _setOpType = setOpType;
+    _all = all;
+  }
+
+  public SetOpType getSetOpType() {
+    return _setOpType;
+  }
+
+  public boolean isAll() {
+    return _all;
+  }
+
+  @Override
+  public String explain() {
+    return _setOpType.toString();
+  }
+
+  @Override
+  public <T, C> T visit(StageNodeVisitor<T, C> visitor, C context) {
+    return visitor.visitSetOp(this, context);
+  }
+
+  public enum SetOpType {
+    UNION, INTERSECT, MINUS;
+
+    public static SetOpType fromObject(SetOp setOp) {
+      if (setOp instanceof LogicalUnion) {
+        return UNION;
+      }
+      if (setOp instanceof LogicalIntersect) {
+        return INTERSECT;
+      }
+      if (setOp instanceof LogicalMinus) {
+        return MINUS;
+      }
+      throw new IllegalArgumentException("Unsupported set operation: " + 
setOp.getClass());
+    }
+  }
+}
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
index 709c085170..45a9b8c1df 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
@@ -85,6 +85,8 @@ public final class StageNodeSerDeUtils {
         return new ValueNode(stageId);
       case "WindowNode":
         return new WindowNode(stageId);
+      case "SetOpNode":
+        return new SetOpNode(stageId);
       default:
         throw new IllegalArgumentException("Unknown node name: " + nodeName);
     }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeVisitor.java
index 593cd336fc..78acf94c76 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeVisitor.java
@@ -56,4 +56,6 @@ public interface StageNodeVisitor<T, C> {
   T visitValue(ValueNode node, C context);
 
   T visitWindow(WindowNode node, C context);
+
+  T visitSetOp(SetOpNode setOpNode, C context);
 }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 4ee591c9cc..4a64e632aa 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -73,6 +73,11 @@ public class QueryEnvironmentTestBase {
   @DataProvider(name = "testQueryDataProvider")
   protected Object[][] provideQueries() {
     return new Object[][] {
+        new Object[]{"SELECT * FROM a UNION SELECT * FROM b"},
+        new Object[]{"SELECT * FROM a UNION ALL SELECT * FROM b"},
+        new Object[]{"SELECT * FROM a INTERSECT SELECT * FROM b"},
+        new Object[]{"SELECT * FROM a EXCEPT SELECT * FROM b"},
+        new Object[]{"SELECT * FROM a MINUS SELECT * FROM b"},
         new Object[]{"SELECT * FROM a ORDER BY col1 LIMIT 10"},
         new Object[]{"SELECT * FROM b ORDER BY col1, col2 DESC LIMIT 10"},
         new Object[]{"SELECT * FROM d"},
diff --git a/pinot-query-planner/src/test/resources/queries/SetOpPlans.json b/pinot-query-planner/src/test/resources/queries/SetOpPlans.json
new file mode 100644
index 0000000000..22bb56a78d
--- /dev/null
+++ b/pinot-query-planner/src/test/resources/queries/SetOpPlans.json
@@ -0,0 +1,95 @@
+{
+  "set_op_tests": {
+    "queries": [
+      {
+        "description": "UNION ALL from two tables",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a UNION ALL SELECT col1, col2 FROM b",
+        "output": [
+          "Execution Plan",
+          "\nLogicalUnion(all=[true])",
+          "\n  LogicalProject(col1=[$0], col2=[$1])",
+          "\n    LogicalTableScan(table=[[a]])",
+          "\n  LogicalProject(col1=[$0], col2=[$1])",
+          "\n    LogicalTableScan(table=[[b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "UNION ALL from three tables",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a UNION ALL SELECT col1, col2 FROM b UNION ALL SELECT col1, col2 FROM c",
+        "output": [
+          "Execution Plan",
+          "\nLogicalUnion(all=[true])",
+          "\n  LogicalUnion(all=[true])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[a]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[b]])",
+          "\n  LogicalProject(col1=[$0], col2=[$1])",
+          "\n    LogicalTableScan(table=[[c]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "UNION from three tables",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a UNION SELECT col1, col2 FROM b UNION SELECT col1, col2 FROM c",
+        "output": [
+          "Execution Plan",
+          "\nLogicalAggregate(group=[{0, 1}])",
+          "\n  LogicalExchange(distribution=[hash[0, 1]])",
+          "\n    LogicalAggregate(group=[{0, 1}])",
+          "\n      LogicalUnion(all=[true])",
+          "\n        LogicalUnion(all=[true])",
+          "\n          LogicalProject(col1=[$0], col2=[$1])",
+          "\n            LogicalTableScan(table=[[a]])",
+          "\n          LogicalProject(col1=[$0], col2=[$1])",
+          "\n            LogicalTableScan(table=[[b]])",
+          "\n        LogicalProject(col1=[$0], col2=[$1])",
+          "\n          LogicalTableScan(table=[[c]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "INTERSECT from three tables",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a INTERSECT SELECT col1, col2 FROM b INTERSECT SELECT col1, col2 FROM c",
+        "output": [
+          "Execution Plan",
+          "\nLogicalIntersect(all=[false])",
+          "\n  LogicalIntersect(all=[false])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[a]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[b]])",
+          "\n  LogicalProject(col1=[$0], col2=[$1])",
+          "\n    LogicalTableScan(table=[[c]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "EXCEPT from three tables",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a EXCEPT SELECT col1, col2 FROM b EXCEPT SELECT col1, col2 FROM c",
+        "output": [
+          "Execution Plan",
+          "\nLogicalMinus(all=[false])",
+          "\n  LogicalMinus(all=[false])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[a]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1])",
+          "\n      LogicalTableScan(table=[[b]])",
+          "\n  LogicalProject(col1=[$0], col2=[$1])",
+          "\n    LogicalTableScan(table=[[c]])",
+          "\n"
+        ]
+      }
+    ]
+  },
+  "exception_throwing_set_planning_tests": {
+    "queries": [
+      {
+        "description": "Incorrect selection list ",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col3 FROM a UNION ALL SELECT col1 FROM b",
+        "expectedException": "Error explain query plan for.*"
+      }
+    ]
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
index 1890f1306c..c323a6074d 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
@@ -24,6 +24,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -102,11 +103,17 @@ public class PhysicalPlanVisitor implements StageNodeVisitor<MultiStageOperator,
         node.getDataSchema(), node.getInputs().get(0).getDataSchema());
   }
 
+  @Override
+  public MultiStageOperator visitSetOp(SetOpNode setOpNode, PlanRequestContext context) {
+    throw new UnsupportedOperationException(
+        "Stage node of type SetOpNode: " + setOpNode.getSetOpType() + " is not supported!");
+  }
+
   @Override
   public MultiStageOperator visitFilter(FilterNode node, PlanRequestContext context) {
     MultiStageOperator nextOperator = node.getInputs().get(0).visit(this, context);
-    return new FilterOperator(context.getOpChainExecutionContext(),
-        nextOperator, node.getDataSchema(), node.getCondition());
+    return new FilterOperator(context.getOpChainExecutionContext(), nextOperator, node.getDataSchema(),
+        node.getCondition());
   }
 
   @Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
index 73dbc9acff..e7a6d5506c 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
@@ -43,6 +43,7 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.SetOpNode;
 import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.StageNodeVisitor;
@@ -185,6 +186,11 @@ public class ServerRequestPlanVisitor implements StageNodeVisitor<Void, ServerPl
     throw new UnsupportedOperationException("Window not yet supported!");
   }
 
+  @Override
+  public Void visitSetOp(SetOpNode setOpNode, ServerPlanRequestContext context) {
+    throw new UnsupportedOperationException("SetOp not yet supported!");
+  }
+
   @Override
   public Void visitFilter(FilterNode node, ServerPlanRequestContext context) {
     visitChildren(node, context);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to