This is an automated email from the ASF dual-hosted git repository.
sunlan pushed a commit to branch GROOVY_4_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git
The following commit(s) were added to refs/heads/GROOVY_4_0_X by this push:
new fe9486ad9e GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause (#2273)
fe9486ad9e is described below
commit fe9486ad9e7acb229753a83156680abbdfffbab1
Author: Daniel Sun <[email protected]>
AuthorDate: Mon Jul 28 04:29:58 2025 +0900
GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause (#2273)
(cherry picked from commit e58a9e1712725016dd2035bc9916ddeedef8d8e6)
---
.../org/apache/groovy/ginq/dsl/GinqAstBuilder.java | 31 ++-
.../ginq/provider/collection/GinqAstWalker.groovy | 38 +++-
.../test/org/apache/groovy/ginq/GinqTest.groovy | 253 +++++++++++++++++++++
3 files changed, 302 insertions(+), 20 deletions(-)
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index f88b690c81..beb5e39220 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -53,11 +53,13 @@ import org.codehaus.groovy.syntax.Types;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
+
/**
* Build the AST for GINQ
*
@@ -376,12 +378,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
public void visitBinaryExpression(BinaryExpression expression) {
super.visitBinaryExpression(expression);
- final int opType = expression.getOperation().getType();
- if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) {
- if (null != latestGinqExpression && isSelectMethodCallExpression(expression.getRightExpression())) {
+ final Integer opType = expression.getOperation().getType();
+ if (FILTER_BINARY_OP_SET.contains(opType)) {
+ if (null != latestGinqExpression) {
// use the nested ginq and clear it
- expression.setRightExpression(latestGinqExpression);
- latestGinqExpression = null;
+ if (isSelectMethodCallExpression(expression.getRightExpression())) {
+ expression.setRightExpression(latestGinqExpression);
+ latestGinqExpression = null;
+ } else if (isSelectMethodCallExpression(expression.getLeftExpression())) {
+ expression.setLeftExpression(latestGinqExpression);
+ latestGinqExpression = null;
+ }
}
}
}
@@ -464,8 +471,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
return sourceUnit;
}
- private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause";
+ private static final Set<Integer> FILTER_BINARY_OP_SET = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+ Types.KEYWORD_IN, Types.COMPARE_NOT_IN, Types.COMPARE_IDENTICAL, Types.COMPARE_NOT_IDENTICAL,
+ Types.COMPARE_EQUAL, Types.COMPARE_NOT_EQUAL, Types.COMPARE_LESS_THAN, Types.COMPARE_LESS_THAN_EQUAL,
+ Types.COMPARE_GREATER_THAN, Types.COMPARE_GREATER_THAN_EQUAL, Types.MATCH_REGEX)));
+ private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause";
private static final String KW_WITH = "with"; // reserved keyword
private static final String KW_FROM = "from";
private static final String KW_IN = "in";
@@ -483,10 +494,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
private static final String KW_OVER = "over";
private static final String KW_AS = "as";
private static final String KW_SHUTDOWN = "shutdown";
- private static final Set<String> KEYWORD_SET = new HashSet<>();
+ private static final Set<String> KEYWORD_SET;
static {
- KEYWORD_SET.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
+ Set<String> keywordSet = new HashSet<>();
+ keywordSet.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
KW_LIMIT, KW_OFFSET, KW_SELECT, KW_DISTINCT, KW_WITHINGROUP, KW_OVER, KW_AS, KW_SHUTDOWN));
- KEYWORD_SET.addAll(JoinExpression.JOIN_NAME_LIST);
+ keywordSet.addAll(JoinExpression.JOIN_NAME_LIST);
+ KEYWORD_SET = Collections.unmodifiableSet(keywordSet);
}
}
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 66bd5ecc88..cf7b879041 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -263,7 +263,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
if (expr instanceof MethodCallExpression) {
MethodCallExpression call = (MethodCallExpression) expr
- if (call.implicitThis && AGG_FUNCTION_NAME_LIST.contains(call.methodAsString)) {
+ if (call.implicitThis && AGG_FUNCTION_NAME_SET.contains(call.methodAsString)) {
def argumentCnt = ((ArgumentListExpression) call.getArguments()).getExpressions().size()
if (1 == argumentCnt || (FUNCTION_COUNT == call.methodAsString && 0 == argumentCnt)) {
return true
@@ -542,12 +542,19 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
}
if (expression instanceof BinaryExpression) {
- if (expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) {
- if (expression.rightExpression instanceof AbstractGinqExpression) {
- expression.rightExpression = callX(visit((AbstractGinqExpression) expression.rightExpression), "toList")
- return expression
- }
+ boolean containsGinqExpression = false
+ if (expression.leftExpression instanceof AbstractGinqExpression) {
+ expression.leftExpression = callSingleValue((AbstractGinqExpression) expression.leftExpression)
+ containsGinqExpression = true
+ }
+ if (expression.rightExpression instanceof AbstractGinqExpression) {
+ expression.rightExpression = expression.operation.type in IN_OP_SET
+ ? callToList((AbstractGinqExpression) expression.rightExpression)
+ : callSingleValue((AbstractGinqExpression) expression.rightExpression)
+ containsGinqExpression = true
}
+
+ if (containsGinqExpression) return expression
}
return expression.transformExpression(this)
@@ -560,6 +567,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
return whereMethodCallExpression
}
+ private MethodCallExpression callSingleValue(AbstractGinqExpression expression) {
+ return callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", visit(expression))
+ }
+
+ private MethodCallExpression callToList(AbstractGinqExpression expression) {
+ return callX(visit(expression), "toList")
+ }
+
@Override
MethodCallExpression visitGroupExpression(GroupExpression groupExpression) {
DataSourceExpression dataSourceExpression = groupExpression.dataSourceExpression
@@ -743,7 +758,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
def windowFunctionMethodCallExpression = (MethodCallExpression) expression.objectExpression
Expression result = null
- if (windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_LIST) {
+ if (windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_SET) {
def argumentListExpression = (ArgumentListExpression) windowFunctionMethodCallExpression.arguments
List<Expression> argumentExpressionList = []
if (windowFunctionMethodCallExpression.methodAsString !in [FUNCTION_ROW_NUMBER, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST] && argumentListExpression.expressions) {
@@ -1304,7 +1319,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
if (FUNCTION_COUNT == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) { // Similar to count(*) in SQL
expression.objectExpression = varX(__GROUP)
transformedExpression = expression
- } else if (methodName in AGG_FUNCTION_NAME_LIST) {
+ } else if (methodName in AGG_FUNCTION_NAME_SET) {
Expression lambdaCode = ((TupleExpression) expression.arguments).getExpression(0)
lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, findRootObjectExpression(lambdaCode).text)
transformedExpression =
@@ -1562,7 +1577,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
private static final String FUNCTION_VAR = 'var'
private static final String FUNCTION_VARP = 'varp'
private static final String FUNCTION_AGG = 'agg'
- private static final List<String> AGG_FUNCTION_NAME_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_LIST, FUNCTION_AGG]
+ private static final Set<String> AGG_FUNCTION_NAME_SET = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_LIST, FUNCTION_AGG] as HashSet
private static final String FUNCTION_ROW_NUMBER = 'rowNumber'
private static final String FUNCTION_LEAD = 'lead'
@@ -1575,9 +1590,10 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
private static final String FUNCTION_PERCENT_RANK = 'percentRank'
private static final String FUNCTION_CUME_DIST = 'cumeDist'
private static final String FUNCTION_NTILE = 'ntile'
- private static final List<String> WINDOW_FUNCTION_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG,
- FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE]
+ private static final Set<String> WINDOW_FUNCTION_SET = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG,
+ FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE] as HashSet
+ private static final Set<Integer> IN_OP_SET = [Types.KEYWORD_IN, Types.COMPARE_NOT_IN] as HashSet
private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name
private static final String USE_WINDOW_FUNCTION = 'useWindowFunction'
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index ea382d959b..662d4ae042 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -769,6 +769,259 @@ class GinqTest {
'''
}
+ @Test
+ void "testGinq - nested from where select - 0"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [1, 2, 3]
+ where n == (from m in [2] select m)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 1"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [1, 2, 3]
+ where n == (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 1 - swap operand"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) == n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 2"() {
+ assertGinqScript '''
+ assert [1, 3] == GQ {
+ from n in [1, 2, 3]
+ where n != (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 2 - swap operand"() {
+ assertGinqScript '''
+ assert [1, 3] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) != n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 3"() {
+ assertGinqScript '''
+ assert [3] == GQ {
+ from n in [1, 2, 3]
+ where n > (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 3 - swap operand"() {
+ assertGinqScript '''
+ assert [3] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) < n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 4"() {
+ assertGinqScript '''
+ assert [2, 3] == GQ {
+ from n in [1, 2, 3]
+ where n >= (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 4 - swap operand"() {
+ assertGinqScript '''
+ assert [2, 3] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) <= n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 5"() {
+ assertGinqScript '''
+ assert [1] == GQ {
+ from n in [1, 2, 3]
+ where n < (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 5 - swap operand"() {
+ assertGinqScript '''
+ assert [1] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) > n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 6"() {
+ assertGinqScript '''
+ assert [1, 2] == GQ {
+ from n in [1, 2, 3]
+ where n <= (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 6 - swap operand"() {
+ assertGinqScript '''
+ assert [1, 2] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) >= n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 7"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [1, 2, 3]
+ where n === (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 7 - swap operand"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) === n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 8"() {
+ assertGinqScript '''
+ assert [1, 3] == GQ {
+ from n in [1, 2, 3]
+ where n !== (from m in [1, 2] select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 8 - swap operand"() {
+ assertGinqScript '''
+ assert [1, 3] == GQ {
+ from n in [1, 2, 3]
+ where ((from m in [1, 2] select max(m)) !== n)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 9"() {
+ assertGinqScript '''
+ assert ['123'] == GQ {
+ from n in ['abc', '123', 'a1b2c3']
+ where n ==~ (from m in [/[0-9]+/] select m)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 10"() {
+ assertGinqScript '''
+ assert [/[a-z]+/] == GQ {
+ from n in [/[0-9]+/, /[a-z]+/]
+ where (from m in ['abc'] select m) ==~ n
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 11"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [2, 3, 4]
+ where 2 > (from m in [1, 2, 3] where m < n select max(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 11 - swap operand"() {
+ assertGinqScript '''
+ assert [2] == GQ {
+ from n in [2, 3, 4]
+ where ((from m in [1, 2, 3] where m < n select max(m)) < 2)
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 12"() {
+ assertGinqScript '''
+ assert [2, 3] == GQ {
+ from n in [2, 3, 4]
+ where n == (from m in [1, 2, 3] where m >= n select min(m))
+ select n
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - nested from where select - 12 - swap operand"() {
+ assertGinqScript '''
+ assert [2, 3] == GQ {
+ from n in [2, 3, 4]
+ where ((from m in [1, 2, 3] where m >= n select min(m)) == n)
+ select n
+ }.toList()
+ '''
+ }
+
@Test
void "testGinq - nested from select - 0"() {
assertGinqScript '''