This is an automated email from the ASF dual-hosted git repository.

stoty pushed a commit to branch 5.1
in repository https://gitbox.apache.org/repos/asf/phoenix.git

commit f58c09905ec032df6147a0548a043649a66a5d36
Author: chenglei <cheng...@apache.org>
AuthorDate: Wed Jul 28 12:47:54 2021 +0800

    PHOENIX-6498 Fix incorrect Correlated Exists Subquery rewrite when Subquery is aggregate
---
 .../apache/phoenix/end2end/join/SubqueryIT.java    | 102 +++++++++
 .../end2end/join/SubqueryUsingSortMergeJoinIT.java | 104 ++++++++-
 .../apache/phoenix/compile/SubqueryRewriter.java   | 248 ++++++++++++++++-----
 .../apache/phoenix/compile/QueryCompilerTest.java  | 118 +++++++++-
 4 files changed, 510 insertions(+), 62 deletions(-)

diff --git 
a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java 
b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
index 85828df..285cd03 100644
--- a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
+++ b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
@@ -813,6 +813,108 @@ public class SubqueryIT extends BaseJoinIT {
     }
 
     @Test
+    public void testCorrelatedExistsSubqueryBug6498() throws Exception {
+        Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
+        final Connection conn = DriverManager.getConnection(getUrl(), props);
+        String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME);
+        String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME);
+        try {
+            String query = "SELECT \"order_id\", name FROM " + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " group by q.\"customer_id\" having count(\"order_id\") > 
1)";
+            PreparedStatement statement = conn.prepareStatement(query);
+            ResultSet rs = statement.executeQuery();
+            assertFalse(rs.next());
+
+            query = "SELECT \"order_id\", name FROM " + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " group by q.\"customer_id\" having count(\"order_id\") >= 
1) order by \"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertEquals(rs.getString(2), "T1");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertEquals(rs.getString(2), "T6");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000003");
+            assertEquals(rs.getString(2), "T2");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertEquals(rs.getString(2), "T6");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertEquals(rs.getString(2), "T3");
+            assertFalse(rs.next());
+
+            query = "SELECT \"order_id\", name FROM " + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " and q.price <= 150 group by q.\"customer_id\" having 
count(\"order_id\") >= 1)"+
+                    " or o.quantity = 5000 order by \"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertEquals(rs.getString(2), "T1");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertEquals(rs.getString(2), "T3");
+            assertFalse(rs.next());
+
+            query = "SELECT \"order_id\" FROM " + tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000004' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") >= 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000003");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertFalse(rs.next());
+
+            query = "SELECT \"order_id\" FROM " + tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000003' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") >= 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertFalse(rs.next());
+
+            query = "SELECT \"order_id\" FROM " + tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000003' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") > 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertFalse(rs.next());
+        } finally {
+            conn.close();
+        }
+    }
+
+    @Test
     public void testAnyAllComparisonSubquery() throws Exception {
         Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
         Connection conn = DriverManager.getConnection(getUrl(), props);
diff --git 
a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
 
b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
index 33bab44..b1f56ed 100644
--- 
a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
+++ 
b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
@@ -356,7 +356,7 @@ public class SubqueryUsingSortMergeJoinIT extends 
BaseJoinIT {
             conn.close();
         }
     }
-    
+
     @Test
     public void testExistsSubquery() throws Exception {
         Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
@@ -602,6 +602,108 @@ public class SubqueryUsingSortMergeJoinIT extends 
BaseJoinIT {
     }
 
     @Test
+    public void testCorrelatedExistsSubqueryBug6498() throws Exception {
+        Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
+        final Connection conn = DriverManager.getConnection(getUrl(), props);
+        String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME);
+        String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME);
+        try {
+            String query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", 
name FROM " + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " group by q.\"customer_id\" having count(\"order_id\") > 
1)";
+            PreparedStatement statement = conn.prepareStatement(query);
+            ResultSet rs = statement.executeQuery();
+            assertFalse(rs.next());
+
+            query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM 
" + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " group by q.\"customer_id\" having count(\"order_id\") >= 
1) order by \"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertEquals(rs.getString(2), "T1");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertEquals(rs.getString(2), "T6");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000003");
+            assertEquals(rs.getString(2), "T2");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertEquals(rs.getString(2), "T6");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertEquals(rs.getString(2), "T3");
+            assertFalse(rs.next());
+
+            query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM 
" + tableName4 +
+                    " o JOIN " + tableName1 +
+                    " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " +
+                    "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = 
q.\"item_id\"" +
+                    " and q.price <= 150 group by q.\"customer_id\" having 
count(\"order_id\") >= 1)"+
+                    " or o.quantity = 5000 order by \"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertEquals(rs.getString(2), "T1");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertEquals(rs.getString(2), "T3");
+            assertFalse(rs.next());
+
+            query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + 
tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000004' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") >= 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000003");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertFalse(rs.next());
+
+            query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + 
tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000003' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") >= 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000001");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000002");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000004");
+            assertTrue (rs.next());
+            assertEquals(rs.getString(1), "000000000000005");
+            assertFalse(rs.next());
+
+            query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + 
tableName4 +
+                    " o WHERE exists (SELECT 1 FROM " + tableName4 +
+                    " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != 
'000000000000003' GROUP BY \"order_id\"" +
+                    " having count(\"customer_id\") > 1) order by  
\"order_id\"";
+            statement = conn.prepareStatement(query);
+            rs = statement.executeQuery();
+            assertFalse(rs.next());
+        } finally {
+            conn.close();
+        }
+    }
+
+    @Test
     public void testAnyAllComparisonSubquery() throws Exception {
         Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
         Connection conn = DriverManager.getConnection(getUrl(), props);
diff --git 
a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java 
b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
index ab9649e..beeac30 100644
--- 
a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
+++ 
b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
@@ -35,6 +35,7 @@ import org.apache.phoenix.parse.ArrayAnyComparisonNode;
 import org.apache.phoenix.parse.ColumnParseNode;
 import org.apache.phoenix.parse.ComparisonParseNode;
 import org.apache.phoenix.parse.CompoundParseNode;
+import org.apache.phoenix.parse.DerivedTableNode;
 import org.apache.phoenix.parse.ExistsParseNode;
 import org.apache.phoenix.parse.HintNode;
 import org.apache.phoenix.parse.InParseNode;
@@ -56,7 +57,7 @@ import org.apache.phoenix.schema.TableNotFoundException;
 
 import org.apache.phoenix.thirdparty.com.google.common.collect.Lists;
 
-/*
+/**
  * Class for rewriting where-clause sub-queries into join queries.
  * 
  * If the where-clause sub-query is one of those top-node conditions (being 
@@ -70,7 +71,7 @@ import 
org.apache.phoenix.thirdparty.com.google.common.collect.Lists;
 public class SubqueryRewriter extends ParseNodeRewriter {
     private static final ParseNodeFactory NODE_FACTORY = new 
ParseNodeFactory();
     
-    private final ColumnResolver resolver;
+    private final ColumnResolver columnResolver;
     private final PhoenixConnection connection;
     private TableNode tableNode;
     private ParseNode topNode;
@@ -89,7 +90,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
     }
     
     protected SubqueryRewriter(SelectStatement select, ColumnResolver 
resolver, PhoenixConnection connection) {
-        this.resolver = resolver;
+        this.columnResolver = resolver;
         this.connection = connection;
         this.tableNode = select.getFrom();
         this.topNode = null;
@@ -194,7 +195,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
 
         JoinConditionExtractor joinConditionExtractor = new 
JoinConditionExtractor(
                 subquerySelectStatementToUse,
-                resolver,
+                columnResolver,
                 connection,
                 subqueryTableTempAlias);
 
@@ -228,7 +229,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
              * It is an Correlated subquery.
              */
             List<AliasedNode> extractedAdditionalSelectAliasNodes =
-                    joinConditionExtractor.getAdditionalSelectNodes();
+                    
joinConditionExtractor.getAdditionalSubselectSelectAliasedNodes();
             extractedSelectAliasNodeCount = 
extractedAdditionalSelectAliasNodes.size();
             newSubquerySelectAliasedNodes = Lists.<AliasedNode> 
newArrayListWithExpectedSize(
                     oldSubqueryAliasedNodes.size() + 1 +
@@ -239,10 +240,11 @@ public class SubqueryRewriter extends ParseNodeRewriter {
                     LiteralParseNode.ONE));
             this.addNewAliasedNodes(newSubquerySelectAliasedNodes, 
oldSubqueryAliasedNodes);
             
newSubquerySelectAliasedNodes.addAll(extractedAdditionalSelectAliasNodes);
-            extractedJoinConditionParseNode = 
joinConditionExtractor.getJoinCondition();
+            extractedJoinConditionParseNode =
+                joinConditionExtractor.getJoinConditionParseNode();
 
             boolean isAggregate = subquerySelectStatementToUse.isAggregate();
-            if(!isAggregate) {
+            if (!isAggregate) {
                 subquerySelectStatementToUse =
                         NODE_FACTORY.select(
                                 subquerySelectStatementToUse,
@@ -274,7 +276,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
                 subqueryTableTempAlias,
                 extractedJoinConditionParseNode,
                 extractedSelectAliasNodeCount);
-        TableNode rhsTableNode = NODE_FACTORY.derivedTable(
+        DerivedTableNode subqueryDerivedTableNode = NODE_FACTORY.derivedTable(
                 subqueryTableTempAlias,
                 subquerySelectStatementToUse);
         JoinType joinType = isTopNode ?
@@ -291,45 +293,167 @@ public class SubqueryRewriter extends ParseNodeRewriter {
         tableNode = NODE_FACTORY.join(
                 joinType,
                 tableNode,
-                rhsTableNode,
+                subqueryDerivedTableNode,
                 joinOnConditionParseNode,
                 false);
 
         return resultWhereParseNode;
     }
 
+    /**
+     * <pre>
+     * {@code
+     * Rewrite an EXISTS subquery: add LIMIT 1 if non-correlated, else rewrite to a semi/anti/left join.
+     *
+     * 1.If the {@link ExistsParseNode} is a non-correlated subquery, then just add LIMIT 1.
+     *    an example is:
+     *    SELECT item_id, name FROM item i WHERE exists
+     *    (SELECT 1 FROM order o  where o.price > 8)
+     *
+     *    The above sql would be rewritten as:
+     *    SELECT ITEM_ID,NAME FROM item I  WHERE  EXISTS
+     *    (SELECT 1 FROM ORDER_TABLE O  WHERE O.PRICE > 8 LIMIT 1)
+     *
+     *   another example is:
+     *   SELECT item_id, name FROM item i WHERE exists
+     *   (SELECT 1 FROM order o  where o.price > 8 group by 
o.customer_id,o.item_id having count(order_id) > 1)
+     *    or i.discount1 > 10
+     *
+     *    The above sql would be rewritten as:
+     *    SELECT ITEM_ID,NAME FROM item I  WHERE
+     *    ( EXISTS (SELECT 1 FROM ORDER_TABLE O  WHERE O.PRICE > 8 GROUP BY 
O.CUSTOMER_ID,O.ITEM_ID HAVING  COUNT(ORDER_ID) > 1 LIMIT 1)
+     *    OR I.DISCOUNT1 > 10)
+     *
+     * 2.If the {@link ExistsParseNode} is a correlated subquery and is the only node in the where clause or
+     *   is an ANDed part of the where clause, then we rewrite the Exists Subquery to a semi/anti join:
+     *   an example is:
+     *    SELECT item_id, name FROM item i WHERE exists
+     *    (SELECT 1 FROM order o where o.price = i.price and o.quantity = 5 )
+     *
+     *    The above sql would be rewritten as:
+     *    SELECT ITEM_ID,NAME FROM item I  Semi JOIN
+     *    (SELECT DISTINCT 1 $3,O.PRICE $2 FROM ORDER_TABLE O  WHERE 
O.QUANTITY = 5) $1
+     *    ON ($1.$2 = I.PRICE)
+     *
+     *   another example, with an aggregate function and GROUP BY, is:
+     *   SELECT item_id, name FROM item i WHERE exists
+     *   (SELECT 1 FROM order o  where o.item_id = i.item_id group by 
customer_id having count(order_id) > 1)
+     *
+     *    The above sql would be rewritten as:
+     *     SELECT ITEM_ID,NAME FROM item I  Semi JOIN
+     *     (SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM order O  GROUP BY 
O.ITEM_ID,CUSTOMER_ID HAVING  COUNT(ORDER_ID) > 1) $1
+     *     ON ($1.$2 = I.ITEM_ID)
+     *
+     * 3.If the {@link ExistsParseNode} is a correlated subquery and is an ORed part of the where clause,
+     *   then we rewrite the Exists Subquery to a left join.
+     *   an example is:
+     *   SELECT item_id, name FROM item i WHERE exists
+     *   (SELECT 1 FROM order o  where o.item_id = i.item_id group by 
customer_id having count(order_id) > 1)
+     *   or i.discount1 > 10
+     *
+     *    The above sql would be rewritten as:
+     *    SELECT ITEM_ID,NAME FROM item I  Left JOIN
+     *    (SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM order O  GROUP BY 
O.ITEM_ID,CUSTOMER_ID HAVING  COUNT(ORDER_ID) > 1) $1
+     *    ON ($1.$2 = I.ITEM_ID) WHERE ($1.$3 IS NOT NULL  OR I.DISCOUNT1 > 10)
+     * }
+     * </pre>
+     */
     @Override
-    public ParseNode visitLeave(ExistsParseNode node, List<ParseNode> l) 
throws SQLException {
-        boolean isTopNode = topNode == node;
+    public ParseNode visitLeave(
+        ExistsParseNode existsParseNode,
+        List<ParseNode> childParseNodes) throws SQLException {
+
+        boolean isTopNode = topNode == existsParseNode;
         if (isTopNode) {
             topNode = null;
         }
         
-        SubqueryParseNode subqueryNode = (SubqueryParseNode) l.get(0);
-        SelectStatement subquery = 
fixSubqueryStatement(subqueryNode.getSelectNode());
-        String rhsTableAlias = ParseNodeFactory.createTempAlias();
-        JoinConditionExtractor conditionExtractor = new 
JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias);
-        ParseNode where = subquery.getWhere() == null ? null : 
subquery.getWhere().accept(conditionExtractor);
-        if (where == subquery.getWhere()) { // non-correlated EXISTS subquery, 
add LIMIT 1
-            subquery = NODE_FACTORY.select(subquery, 
NODE_FACTORY.limit(NODE_FACTORY.literal(1)));
-            subqueryNode = NODE_FACTORY.subquery(subquery, false);
-            node = NODE_FACTORY.exists(subqueryNode, node.isNegate());
-            return super.visitLeave(node, Collections.<ParseNode> 
singletonList(subqueryNode));
-        }
-        
-        List<AliasedNode> additionalSelectNodes = 
conditionExtractor.getAdditionalSelectNodes();
-        List<AliasedNode> selectNodes = 
Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1);
-        
selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), 
LiteralParseNode.ONE));
-        selectNodes.addAll(additionalSelectNodes);
-        
-        subquery = NODE_FACTORY.select(subquery, true, selectNodes, where);
-        ParseNode onNode = conditionExtractor.getJoinCondition();
-        TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, 
subquery);
-        JoinType joinType = isTopNode ? (node.isNegate() ? JoinType.Anti : 
JoinType.Semi) : JoinType.Left;
-        ParseNode ret = isTopNode ? null : 
NODE_FACTORY.isNull(NODE_FACTORY.column(NODE_FACTORY.table(null, 
rhsTableAlias), selectNodes.get(0).getAlias(), null), !node.isNegate());
-        tableNode = NODE_FACTORY.join(joinType, tableNode, rhsTable, onNode, 
false);
+        SubqueryParseNode subqueryParseNode = (SubqueryParseNode) 
childParseNodes.get(0);
+        SelectStatement subquerySelectStatementToUse =
+                fixSubqueryStatement(subqueryParseNode.getSelectNode());
+        String subqueryTableTempAlias = ParseNodeFactory.createTempAlias();
+        JoinConditionExtractor joinConditionExtractor =
+                new JoinConditionExtractor(
+                        subquerySelectStatementToUse,
+                        columnResolver,
+                        connection,
+                        subqueryTableTempAlias);
+        ParseNode whereParseNodeAfterExtract =
+                subquerySelectStatementToUse.getWhere() == null ?
+                null :
+                
subquerySelectStatementToUse.getWhere().accept(joinConditionExtractor);
+        if (whereParseNodeAfterExtract == 
subquerySelectStatementToUse.getWhere()) {
+            /**
+             * It is a non-correlated EXISTS subquery, so just add LIMIT 1.
+             */
+            subquerySelectStatementToUse =
+                    NODE_FACTORY.select(
+                            subquerySelectStatementToUse,
+                            NODE_FACTORY.limit(NODE_FACTORY.literal(1)));
+            subqueryParseNode = 
NODE_FACTORY.subquery(subquerySelectStatementToUse, false);
+            existsParseNode = NODE_FACTORY.exists(subqueryParseNode, 
existsParseNode.isNegate());
+            return super.visitLeave(
+                    existsParseNode,
+                    Collections.<ParseNode>singletonList(subqueryParseNode));
+        }
+
+        List<AliasedNode> extractedAdditionalSelectAliasNodes =
+                
joinConditionExtractor.getAdditionalSubselectSelectAliasedNodes();
+        List<AliasedNode> newSubquerySelectAliasedNodes = 
Lists.newArrayListWithExpectedSize(
+                extractedAdditionalSelectAliasNodes.size() + 1);
+        /**
+         * Replace the original select list with a constant 1 plus the extracted correlated nodes.
+         */
+        newSubquerySelectAliasedNodes.add(
+                NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), 
LiteralParseNode.ONE));
+        
newSubquerySelectAliasedNodes.addAll(extractedAdditionalSelectAliasNodes);
         
-        return ret;
+        boolean isAggregate = subquerySelectStatementToUse.isAggregate();
+        if (!isAggregate) {
+            subquerySelectStatementToUse = NODE_FACTORY.select(
+                    subquerySelectStatementToUse,
+                    true,
+                    newSubquerySelectAliasedNodes,
+                    whereParseNodeAfterExtract);
+        } else {
+            /**
+             * If the subquery is an aggregate, the subquery-side nodes extracted from the correlated
+             * join condition must be added to both the GROUP BY clause and the select list.
+             */
+            List<ParseNode> newGroupByParseNodes = 
this.createNewGroupByParseNodes(
+                    extractedAdditionalSelectAliasNodes,
+                    subquerySelectStatementToUse);
+
+            subquerySelectStatementToUse = NODE_FACTORY.select(
+                    subquerySelectStatementToUse,
+                    true,
+                    newSubquerySelectAliasedNodes,
+                    whereParseNodeAfterExtract,
+                    newGroupByParseNodes,
+                    true);
+        }
+        ParseNode joinOnConditionParseNode = 
joinConditionExtractor.getJoinConditionParseNode();
+        DerivedTableNode subqueryDerivedTableNode = NODE_FACTORY.derivedTable(
+                subqueryTableTempAlias,
+                subquerySelectStatementToUse);
+        JoinType joinType = isTopNode ?
+                (existsParseNode.isNegate() ? JoinType.Anti : JoinType.Semi) :
+                 JoinType.Left;
+        ParseNode resultWhereParseNode = isTopNode ?
+                        null :
+                        NODE_FACTORY.isNull(
+                                NODE_FACTORY.column(
+                                        NODE_FACTORY.table(null, 
subqueryTableTempAlias),
+                                        
newSubquerySelectAliasedNodes.get(0).getAlias(),
+                                        null),
+                                !existsParseNode.isNegate());
+        tableNode = NODE_FACTORY.join(
+                joinType,
+                tableNode,
+                subqueryDerivedTableNode,
+                joinOnConditionParseNode,
+                false);
+        return resultWhereParseNode;
     }
 
     @Override
@@ -347,7 +471,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
         SubqueryParseNode subqueryNode = (SubqueryParseNode) secondChild;
         SelectStatement subquery = 
fixSubqueryStatement(subqueryNode.getSelectNode());
         String rhsTableAlias = ParseNodeFactory.createTempAlias();
-        JoinConditionExtractor conditionExtractor = new 
JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias);
+        JoinConditionExtractor conditionExtractor = new 
JoinConditionExtractor(subquery, columnResolver, connection, rhsTableAlias);
         ParseNode where = subquery.getWhere() == null ? null : 
subquery.getWhere().accept(conditionExtractor);
         if (where == subquery.getWhere()) { // non-correlated comparison 
subquery, add LIMIT 2, expectSingleRow = true
             subquery = NODE_FACTORY.select(subquery, 
NODE_FACTORY.limit(NODE_FACTORY.literal(2)));
@@ -371,8 +495,10 @@ public class SubqueryRewriter extends ParseNodeRewriter {
             rhsNode = NODE_FACTORY.rowValueConstructor(nodes);
         }
         
-        List<AliasedNode> additionalSelectNodes = 
conditionExtractor.getAdditionalSelectNodes();
-        List<AliasedNode> selectNodes = 
Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1);        
+        List<AliasedNode> additionalSelectNodes =
+            conditionExtractor.getAdditionalSubselectSelectAliasedNodes();
+        List<AliasedNode> selectNodes =
+            Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 
1);
         
selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), 
rhsNode));
         selectNodes.addAll(additionalSelectNodes);
         
@@ -385,7 +511,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
             subquery = NODE_FACTORY.select(subquery, subquery.isDistinct(), 
selectNodes, where, groupbyNodes, true);
         }
         
-        ParseNode onNode = conditionExtractor.getJoinCondition();
+        ParseNode onNode = conditionExtractor.getJoinConditionParseNode();
         TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, 
subquery);
         JoinType joinType = isTopNode ? JoinType.Inner : JoinType.Left;
         ParseNode ret = NODE_FACTORY.comparison(node.getFilterOp(), l.get(0), 
NODE_FACTORY.column(NODE_FACTORY.table(null, rhsTableAlias), 
selectNodes.get(0).getAlias(), null));
@@ -428,7 +554,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
         SubqueryParseNode subqueryNode = (SubqueryParseNode) firstChild;
         SelectStatement subquery = 
fixSubqueryStatement(subqueryNode.getSelectNode());
         String rhsTableAlias = ParseNodeFactory.createTempAlias();
-        JoinConditionExtractor conditionExtractor = new 
JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias);
+        JoinConditionExtractor conditionExtractor = new 
JoinConditionExtractor(subquery, columnResolver, connection, rhsTableAlias);
         ParseNode where = subquery.getWhere() == null ? null : 
subquery.getWhere().accept(conditionExtractor);
         if (where == subquery.getWhere()) { // non-correlated any/all 
comparison subquery
             return l;
@@ -457,7 +583,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
             rhsNode = 
NODE_FACTORY.function(DistinctValueAggregateFunction.NAME, 
Collections.singletonList(rhsNode));
         }
         
-        List<AliasedNode> additionalSelectNodes = 
conditionExtractor.getAdditionalSelectNodes();
+        List<AliasedNode> additionalSelectNodes = 
conditionExtractor.getAdditionalSubselectSelectAliasedNodes();
         List<AliasedNode> selectNodes = 
Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1);        
         
selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), 
rhsNode));
         selectNodes.addAll(additionalSelectNodes);
@@ -489,7 +615,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
                     Collections.<SelectStatement> emptyList(), 
subquery.getUdfParseNodes());
         }
         
-        ParseNode onNode = conditionExtractor.getJoinCondition();
+        ParseNode onNode = conditionExtractor.getJoinConditionParseNode();
         TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, 
subquery);
         JoinType joinType = isTopNode ? JoinType.Inner : JoinType.Left;
         tableNode = NODE_FACTORY.join(joinType, tableNode, rhsTable, onNode, 
false);
@@ -623,8 +749,8 @@ public class SubqueryRewriter extends ParseNodeRewriter {
     private static class JoinConditionExtractor extends 
AndRewriterBooleanParseNodeVisitor {
         private final TableName tableName;
         private ColumnResolveVisitor columnResolveVisitor;
-        private List<AliasedNode> additionalSelectNodes;
-        private List<ParseNode> joinConditions;
+        private List<AliasedNode> additionalSubselectSelectAliasedNodes;
+        private List<ParseNode> joinConditionParseNodes;
         
         public JoinConditionExtractor(SelectStatement subquery, ColumnResolver 
outerResolver, 
                 PhoenixConnection connection, String tableAlias) throws 
SQLException {
@@ -632,22 +758,24 @@ public class SubqueryRewriter extends ParseNodeRewriter {
             this.tableName = NODE_FACTORY.table(null, tableAlias);
             ColumnResolver localResolver = 
FromCompiler.getResolverForQuery(subquery, connection);
             this.columnResolveVisitor = new 
ColumnResolveVisitor(localResolver, outerResolver);
-            this.additionalSelectNodes = Lists.<AliasedNode> newArrayList();
-            this.joinConditions = Lists.<ParseNode> newArrayList();
+            this.additionalSubselectSelectAliasedNodes = 
Lists.<AliasedNode>newArrayList();
+            this.joinConditionParseNodes = Lists.<ParseNode>newArrayList();
         }
         
-        public List<AliasedNode> getAdditionalSelectNodes() {
-            return this.additionalSelectNodes;
+        public List<AliasedNode> getAdditionalSubselectSelectAliasedNodes() { // local side of each extracted correlation, under a temp alias, for the caller to append to the subquery select list
+            return this.additionalSubselectSelectAliasedNodes;
         }
         
-        public ParseNode getJoinCondition() {
-            if (this.joinConditions.isEmpty())
+        public ParseNode getJoinConditionParseNode() { // ON condition built from the extracted correlation equalities; null when none were extracted
+            if (this.joinConditionParseNodes.isEmpty()) {
                 return null;
-            
-            if (this.joinConditions.size() == 1)
-                return this.joinConditions.get(0);
-            
-            return NODE_FACTORY.and(this.joinConditions);            
+            }
+
+            if (this.joinConditionParseNodes.size() == 1) {
+                return this.joinConditionParseNodes.get(0);
+            }
+            // more than one equality: combine them into a single AND node
+            return NODE_FACTORY.and(this.joinConditionParseNodes);
         }
 
         @Override
@@ -680,16 +808,18 @@ public class SubqueryRewriter extends ParseNodeRewriter {
             }
             if (lhsType == ColumnResolveVisitor.ColumnResolveType.LOCAL && 
rhsType == ColumnResolveVisitor.ColumnResolveType.OUTER) {
                 String alias = ParseNodeFactory.createTempAlias();
-                this.additionalSelectNodes.add(NODE_FACTORY.aliasedNode(alias, 
node.getLHS()));
+                this.additionalSubselectSelectAliasedNodes.add(
+                  NODE_FACTORY.aliasedNode(alias, node.getLHS()));
                 ParseNode lhsNode = NODE_FACTORY.column(tableName, alias, 
null);
-                this.joinConditions.add(NODE_FACTORY.equal(lhsNode, 
node.getRHS()));
+                this.joinConditionParseNodes.add(NODE_FACTORY.equal(lhsNode, 
node.getRHS()));
                 return null;
             }        
             if (lhsType == ColumnResolveVisitor.ColumnResolveType.OUTER && 
rhsType == ColumnResolveVisitor.ColumnResolveType.LOCAL) {
                 String alias = ParseNodeFactory.createTempAlias();
-                this.additionalSelectNodes.add(NODE_FACTORY.aliasedNode(alias, 
node.getRHS()));
+                this.additionalSubselectSelectAliasedNodes.add(
+                  NODE_FACTORY.aliasedNode(alias, node.getRHS()));
                 ParseNode rhsNode = NODE_FACTORY.column(tableName, alias, 
null);
-                this.joinConditions.add(NODE_FACTORY.equal(node.getLHS(), 
rhsNode));
+                
this.joinConditionParseNodes.add(NODE_FACTORY.equal(node.getLHS(), rhsNode));
                 return null;
             }
             
diff --git 
a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java 
b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
index 17c369d..6714cee 100644
--- 
a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
+++ 
b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
@@ -6549,7 +6549,7 @@ public class QueryCompilerTest extends 
BaseConnectionlessQueryTest {
 
             }
 
-          //test Correlated subquery with AggregateFunction with groupBy and 
is ORed part of the where clause.
+            //test Correlated subquery with AggregateFunction with groupBy and 
is ORed part of the where clause.
             sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
i.item_id IN "+
                     "(SELECT max(item_id) FROM " + orderTableName + " o  where 
o.price = i.price group by o.customer_id) or i.discount1 > 10 ORDER BY name";
             queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
@@ -6759,7 +6759,8 @@ public class QueryCompilerTest extends 
BaseConnectionlessQueryTest {
            scanPlan=(ScanPlan)(hashJoinPlan.getDelegate());
            TestUtil.assertSelectStatement(
                    scanPlan.getStatement(),
-                   "SELECT A.AID FROM " + tableName1 + "  WHERE (AGE > (SELECT 
 MAX(CODE) FROM " + tableName2 + " C  WHERE C.BID >= 1 LIMIT 2) AND (AGE >= 11 
AND AGE <= 33)) ORDER BY A.AID");
+                   "SELECT A.AID FROM " + tableName1 +
+                   "  WHERE (AGE > (SELECT  MAX(CODE) FROM " + tableName2 + " 
C  WHERE C.BID >= 1 LIMIT 2) AND (AGE >= 11 AND AGE <= 33)) ORDER BY A.AID");
            subPlans = hashJoinPlan.getSubPlans();
            assertTrue(subPlans.length == 2);
            assertTrue(subPlans[0] instanceof WhereClauseSubPlan);
@@ -6788,4 +6789,117 @@ public class QueryCompilerTest extends 
BaseConnectionlessQueryTest {
             conn.close();
         }
     }
+
+    @Test
+    public void testExistsSubqueryBug6498() throws Exception { // PHOENIX-6498: EXISTS subquery rewrite, incl. aggregate (GROUP BY/HAVING) subqueries
+        Connection conn = null;
+        try {
+            conn = DriverManager.getConnection(getUrl());
+            String itemTableName = "item_table";
+            String sql ="create table " + itemTableName +
+                "   (item_id varchar not null primary key, " +
+                "    name varchar, " +
+                "    price integer, " +
+                "    discount1 integer, " +
+                "    discount2 integer, " +
+                "    supplier_id varchar, " +
+                "    description varchar)";
+            conn.createStatement().execute(sql);
+
+            String orderTableName = "order_table";
+            sql = "create table " + orderTableName +
+                "   (order_id varchar not null primary key, " +
+                "    customer_id varchar, " +
+                "    item_id varchar, " +
+                "    price integer, " +
+                "    quantity integer, " +
+                "    date timestamp)";
+            conn.createStatement().execute(sql);
+
+            //test simple Correlated subquery
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                 "(SELECT 1 FROM " + orderTableName + " o  where o.price = 
i.price and o.quantity = 5 ) ORDER BY name";
+            QueryPlan queryPlan = 
TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            // the correlated EXISTS is rewritten into a Semi JOIN on the extracted equality
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  Semi JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.PRICE $2 FROM ORDER_TABLE O  
WHERE O.QUANTITY = 5) $1 "+
+                    "ON ($1.$2 = I.PRICE) ORDER BY NAME");
+
+            //test Correlated subquery with AggregateFunction and groupBy
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                 "(SELECT 1 FROM " + orderTableName + " o  where o.item_id = 
i.item_id group by customer_id having count(order_id) > 1) " +
+                 "ORDER BY name";
+            queryPlan = TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  Semi JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM ORDER_TABLE O  
GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING  COUNT(ORDER_ID) > 1) $1 " +
+                    "ON ($1.$2 = I.ITEM_ID) ORDER BY NAME");
+
+            //for Correlated subquery, the extracted join condition must be 
equal expression.
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o  where o.price = 
i.price or o.quantity > 1 group by o.customer_id) ORDER BY name";
+            try {
+                queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+                fail();
+            } catch(SQLFeatureNotSupportedException exception) {
+                // expected: the correlation is ORed, so no equi-join condition can be extracted
+            }
+
+            //test Correlated subquery with AggregateFunction with groupBy and 
is ORed part of the where clause.
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                 "(SELECT 1 FROM " + orderTableName + " o  where o.item_id = 
i.item_id group by customer_id having count(order_id) > 1) "+
+                 " or i.discount1 > 10 ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  Left JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM ORDER_TABLE O  
GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING  COUNT(ORDER_ID) > 1) $1 " +
+                    "ON ($1.$2 = I.ITEM_ID) WHERE ($1.$3 IS NOT NULL  OR 
I.DISCOUNT1 > 10) ORDER BY NAME");
+
+            // test NonCorrelated subquery
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o  where o.price > 
8) ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            // a non-correlated EXISTS stays a subquery; the rewrite only appends LIMIT 1
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  WHERE  EXISTS 
(SELECT 1 FROM ORDER_TABLE O  WHERE O.PRICE > 8 LIMIT 1) ORDER BY NAME");
+
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                 "(SELECT 1 FROM " + orderTableName + " o  where o.price > 8 
group by o.customer_id,o.item_id having count(order_id) > 1)" +
+                 " ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  WHERE  EXISTS "+
+                    "(SELECT 1 FROM ORDER_TABLE O  WHERE O.PRICE > 8 GROUP BY 
O.CUSTOMER_ID,O.ITEM_ID HAVING  COUNT(ORDER_ID) > 1 LIMIT 1)" +
+                    " ORDER BY NAME");
+
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE 
exists "+
+                 "(SELECT 1 FROM " + orderTableName + " o  where o.price > 8 
group by o.customer_id,o.item_id having count(order_id) > 1)" +
+                 " or i.discount1 > 10 ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I  WHERE " +
+                    "( EXISTS (SELECT 1 FROM ORDER_TABLE O  WHERE O.PRICE > 8 
GROUP BY O.CUSTOMER_ID,O.ITEM_ID HAVING  COUNT(ORDER_ID) > 1 LIMIT 1)" +
+                    " OR I.DISCOUNT1 > 10) ORDER BY NAME");
+        } finally {
+            if (conn != null) { conn.close(); }
+        }
+    }
+
 }

Reply via email to