This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY_4_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 74f9283223b02cf1d3f330a5ae6856d34d1ad0ec
Author: Eric Milles <eric.mil...@thomsonreuters.com>
AuthorDate: Tue Apr 30 10:33:49 2024 -0500

    GROOVY-11364: STC: propagate receiver generics to candidate return types
    
    4_0_X backport
---
 .../transform/stc/StaticTypeCheckingVisitor.java   |  12 +-
 src/test/groovy/transform/stc/LambdaTest.groovy    | 307 +++++++++++----------
 .../transform/stc/MethodReferenceTest.groovy       |  23 ++
 3 files changed, 197 insertions(+), 145 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 8d8ab524ec..707b5c054f 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -2509,7 +2509,17 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 int nParameters = candidates.stream().mapToInt(m -> 
m.getParameters().length).reduce((i,j) -> i == j ? i : -1).getAsInt();
                 Map<GenericsTypeName, GenericsType> gts = 
GenericsUtils.extractPlaceholders(receiverType);
                 
stubMissingTypeVariables(receiverType.redirect().getGenericsTypes(), gts); // 
GROOVY-11241
-                candidates.stream().map(candidate -> applyGenericsContext(gts, 
candidate.getReturnType()))
+                ClassNode ownerType = receiverType;
+                candidates.stream()
+                        .map(candidate -> {
+                            ClassNode returnType = candidate.getReturnType();
+                            if (!candidate.isStatic() && 
GenericsUtils.hasUnresolvedGenerics(returnType)) {
+                                Map<GenericsTypeName, GenericsType> spec = new 
HashMap<>(); // GROOVY-11364
+                                extractGenericsConnections(spec, ownerType, 
candidate.getDeclaringClass());
+                                returnType = applyGenericsContext(spec, 
returnType);
+                            }
+                            return returnType;
+                        })
                         
.reduce(WideningCategories::lowestUpperBound).ifPresent(returnType -> {
                             ClassNode closureType = 
wrapClosureType(returnType);
                             closureType.putNodeMetaData(CLOSURE_ARGUMENTS, 
nParameters); // GROOVY-10714
diff --git a/src/test/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/transform/stc/LambdaTest.groovy
index 974827c2d9..909928391b 100644
--- a/src/test/groovy/transform/stc/LambdaTest.groovy
+++ b/src/test/groovy/transform/stc/LambdaTest.groovy
@@ -18,6 +18,7 @@
  */
 package groovy.transform.stc
 
+import groovy.test.NotYetImplemented
 import org.junit.Test
 
 import static groovy.test.GroovyAssert.assertScript
@@ -38,7 +39,7 @@ final class LambdaTest {
     void testFunction() {
         assertScript shell, '''
             def f() {
-                [1, 2, 3].stream().map(e -> e + 1).collect(Collectors.toList())
+                [1, 2, 3].stream().map(i -> i + 1).toList()
             }
             assert f() == [2, 3, 4]
         '''
@@ -48,37 +49,61 @@ final class LambdaTest {
     void testFunction2() {
         assertScript shell, '''
             def f() {
-                [1, 2, 3].stream().map(e -> 
e.plus(1)).collect(Collectors.toList())
+                [1, 2, 3].stream().map(i -> i.plus(1)).toList()
             }
             assert f() == [2, 3, 4]
         '''
     }
 
     @Test
-    void testFunctionWithTypeArgument() {
+    void testFunction3() {
         assertScript shell, '''
             def f() {
-                [1, 2, 3].stream().<String>map(i -> 
null).collect(Collectors.toList())
+                [1, 2, 3].stream().<String>map(i -> null).toList()
             }
             assert f() == [null, null, null]
         '''
     }
 
+    // GROOVY-11364
+    @Test
+    void testFunction4() {
+        assertScript shell, '''
+            abstract class A<N extends Number> {
+                protected N process(N n) { n }
+            }
+
+            class C extends A<Integer> {
+                static void consume(Optional<Integer> option) {
+                    def result = option.orElse(null)
+                    assert result instanceof Integer
+                    assert result == 42
+                }
+                void test() {
+                    consume(Optional.of(42).map(i -> process(i)))
+                }
+            }
+
+            new C().test()
+        '''
+    }
+
     @Test
     void testBinaryOperator() {
         assertScript shell, '''
             def f() {
-                [1, 2, 3].stream().reduce(7, (Integer r, Integer e) -> r + e)
+                [1, 2, 3].stream().reduce(7, (Integer r, Integer i) -> r + i)
             }
             assert f() == 13
         '''
     }
 
-    @Test // GROOVY-8917
+    // GROOVY-8917
+    @Test
     void testBinaryOperatorWithoutExplicitTypes() {
         assertScript shell, '''
             def f() {
-                [1, 2, 3].stream().reduce(7, (r, e) -> r + e)
+                [1, 2, 3].stream().reduce(7, (r, i) -> r + i)
             }
             assert f() == 13
         '''
@@ -88,14 +113,15 @@ final class LambdaTest {
     void testBinaryOperatorWithoutExplicitTypes2() {
         assertScript shell, '''
             def f() {
-                BinaryOperator<Integer> accumulator = (r, e) -> r + e
+                BinaryOperator<Integer> accumulator = (r, i) -> r + i
                 return [1, 2, 3].stream().reduce(7, accumulator)
             }
             assert f() == 13
         '''
     }
 
-    @Test // GROOVY-10282
+    // GROOVY-10282
+    @Test
     void testBiFunctionAndBinaryOperatorWithSharedTypeParameter() {
         assertScript shell, '''
             def f() {
@@ -186,7 +212,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-10372
+    // GROOVY-10372
+    @Test
     void testComparator2() {
         def err = shouldFail shell, '''
             class T {
@@ -196,7 +223,8 @@ final class LambdaTest {
         assert err =~ /Expected type java.lang.Integer for lambda parameter: b/
     }
 
-    @Test // GROOVY-9977
+    // GROOVY-9977
+    @Test
     void testComparator3() {
         assertScript shell, '''
             class T {
@@ -214,7 +242,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9997
+    // GROOVY-9997
+    @Test
     void testComparator4() {
         assertScript '''
             @groovy.transform.TypeChecked
@@ -253,7 +282,8 @@ final class LambdaTest {
         }
     }
 
-    @Test // GROOVY-11304
+    // GROOVY-11304
+    @Test
     void testCollectors3() {
         assertScript shell, '''
             List<String> list = ['a', 'b', 'c']
@@ -617,7 +647,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9347
+    // GROOVY-9347
+    @Test
     void testConsumer7() {
         assertScript shell, '''
             void test() {
@@ -631,7 +662,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9340
+    // GROOVY-9340
+    @Test
     void testConsumer8() {
         assertScript shell, '''
             class Test1 {
@@ -754,7 +786,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9881
+    // GROOVY-9881
+    @Test
     void testFunctionalInterface4() {
         assertScript shell, '''
             class Value<V> {
@@ -778,8 +811,32 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-10372
+    // GROOVY-11121
+    @NotYetImplemented @Test
     void testFunctionalInterface5() {
+        String clazz = '''
+            class C<T> {
+                public String which
+                static <T> C<T> of(T value) { new C<>(which: 'T') }
+                static <T> C<T> of(Iterable<T> values) { new C<>(which: 
'Iterable<T>') }
+            }
+        '''
+        assertScript shell, clazz + '''
+            assert C.<IntUnaryOperator>of( (int i) -> i + 1 ).which == 'T'
+        '''
+        assertScript shell, clazz + '''
+            @groovy.transform.CompileDynamic
+            void p() {
+                assert C.of( (int i) -> i + 1 ).which == 'T'
+            }
+
+            p()
+        '''
+    }
+
+    // GROOVY-10372
+    @Test
+    void testFunctionalInterface6() {
         def err = shouldFail shell, '''
             interface I {
                 def m(List<String> strings)
@@ -790,8 +847,9 @@ final class LambdaTest {
         assert err =~ /Expected type java.util.List<java.lang.String> for 
lambda parameter: list/
     }
 
-    @Test // GROOVY-11013
-    void testFunctionalInterface6() {
+    // GROOVY-11013
+    @Test
+    void testFunctionalInterface7() {
         assertScript shell, '''
             interface I<T> {
                 def m(List<T> list_of_t)
@@ -801,8 +859,9 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-11072
-    void testFunctionalInterface7() {
+    // GROOVY-11072
+    @Test
+    void testFunctionalInterface8() {
         assertScript shell, '''
             class Model {
             }
@@ -828,167 +887,107 @@ final class LambdaTest {
         '''
     }
 
-    @Test
-    void testFunctionWithUpdatingLocalVariable() {
+    // GROOVY-11092
+    @NotYetImplemented @Test
+    void testFunctionalInterface9() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
-
-                static void p() {
-                    int i = 1
-                    assert [2, 4, 7] == [1, 2, 3].stream().map(e -> i += 
e).collect(Collectors.toList())
-                    assert 7 == i
-                }
+            Function<List<String>,String> f = (one,two) -> { one + two }
+            assert f.apply(['foo','bar']) == 'foobar'
+        '''
+        assertScript shell, '''
+            void setFun(Function<List<String>,String> f) {
+                assert f.apply(['foo','bar']) == 'foobar'
             }
+            fun = (one,two) -> { one + two }
+        '''
+        assertScript shell, '''
+            void test(Function<List<String>,String> f) {
+                assert f.apply(['foo','bar']) == 'foobar'
+            }
+            test((one, two) -> { one + two })
         '''
     }
 
     @Test
-    void testFunctionWithUpdatingLocalVariable2() {
-        assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    new Test1().p()
-                }
-
-                void p() {
+    void testFunctionWithUpdatingLocalVariable() {
+        for (mode in ['','static']) {
+            assertScript shell, """
+                $mode void p() {
                     int i = 1
-                    assert [2, 4, 7] == [1, 2, 3].stream().map(e -> i += 
e).collect(Collectors.toList())
-                    assert 7 == i
+                    Object result =  [1, 2, 3].stream().map(e -> i += 
e).collect(Collectors.toList())
+                    assert result == [2, 4, 7]
+                    assert i == 7
                 }
-            }
-        '''
+
+                p()
+            """
+        }
     }
 
     @Test
     void testFunctionWithVariableDeclaration() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
-
-                public static void p() {
-                    Function<Integer, String> f = (Integer e) -> 'a' + e
-                    assert ['a1', 'a2', 'a3'] == [1, 2, 
3].stream().map(f).collect(Collectors.toList())
-                }
-            }
+            Function<Integer, String> f = (Integer e) -> 'a' + e
+            assert ['a1', 'a2', 'a3'] == [1, 2, 
3].stream().map(f).collect(Collectors.toList())
         '''
     }
 
     @Test
     void testFunctionWithMixingVariableDeclarationAndMethodInvocation() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
+            String x = '#'
+            Integer y = 23
+            assert [1, 2, 3].stream().map(e -> '' + y + x + 
e).collect(Collectors.toList()) == ['23#1', '23#2', '23#3']
 
-                static void p() {
-                    String x = '#'
-                    Integer y = 23
-                    assert ['23#1', '23#2', '23#3'] == [1, 2, 
3].stream().map(e -> '' + y + x + e).collect(Collectors.toList())
-
-                    Function<Integer, String> f = (Integer e) -> 'a' + e
-                    assert ['a1', 'a2', 'a3'] == [1, 2, 
3].stream().map(f).collect(Collectors.toList())
-
-                    assert [2, 3, 4] == [1, 2, 3].stream().map(e -> 
e.plus(1)).collect(Collectors.toList());
-                }
-            }
+            Function<Integer, String> f = (Integer e) -> 'a' + e
+            assert [1, 2, 3].stream().map(f).collect(Collectors.toList()) == 
['a1', 'a2', 'a3']
+            assert [1, 2, 3].stream().map(e -> 
e.plus(1)).collect(Collectors.toList()) == [2, 3, 4]
         '''
     }
 
     @Test
     void testFunctionWithNestedLambda() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
-
-                static void p() {
-                    [1, 2].stream().forEach(e -> {
-                        def list = ['a', 'b'].stream().map(f -> f + e).toList()
-                        if (1 == e) {
-                            assert ['a1', 'b1'] == list
-                        } else if (2 == e) {
-                            assert ['a2', 'b2'] == list
-                        }
-                    })
-                }
-            }
+            [1, 2].stream().forEach(e -> {
+                def list = ['a', 'b'].stream().map(f -> f + e).toList()
+                assert list == (e == 1 ? ['a1', 'b1'] : ['a2', 'b2'])
+            })
         '''
     }
 
     @Test
     void testFunctionWithNestedLambda2() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
+            def list = ['a', 'b'].stream()
+            .map(e -> {
+                [1, 2].stream().map(f -> e + f).toList()
+            }).toList()
 
-                static void p() {
-                    def list = ['a', 'b'].stream()
-                    .map(e -> {
-                        [1, 2].stream().map(f -> e + f).toList()
-                    }).toList()
-
-                    assert ['a1', 'a2'] == list[0]
-                    assert ['b1', 'b2'] == list[1]
-                }
-            }
+            assert list[0] == ['a1', 'a2']
+            assert list[1] == ['b1', 'b2']
         '''
     }
 
     @Test
     void testFunctionWithNestedLambda3() {
         assertScript shell, '''
-            class Test1 {
-                static main(args) {
-                    p()
-                }
+            def list = ['a', 'b'].stream()
+            .map(e -> {
+                Function<Integer, String> x = (Integer f) -> e + f
+                [1, 2].stream().map(x).toList()
+            }).toList()
 
-                static void p() {
-                    def list = ['a', 'b'].stream()
-                    .map(e -> {
-                        Function<Integer, String> x = (Integer f) -> e + f
-                        [1, 2].stream().map(x).toList()
-                    }).toList()
-
-                    assert ['a1', 'a2'] == list[0]
-                    assert ['b1', 'b2'] == list[1]
-                }
-            }
+            assert list[0] == ['a1', 'a2']
+            assert list[1] == ['b1', 'b2']
         '''
     }
 
     @Test
     void testMixingLambdaAndMethodReference() {
         assertScript shell, '''
-            assert ['1', '2', '3'] == [1, 2, 
3].stream().map(Object::toString).collect(Collectors.toList())
-            assert [2, 3, 4] == [1, 2, 3].stream().map(e -> 
e.plus(1)).collect(Collectors.toList())
-            assert ['1', '2', '3'] == [1, 2, 
3].stream().map(Object::toString).collect(Collectors.toList())
-        '''
-    }
-
-    @Test
-    void testInitializeBlocks() {
-        assertScript shell, '''
-            class Test1 {
-                static sl
-                def il
-                static { sl = [1, 2, 3].stream().map(e -> e + 1).toList() }
-
-                {
-                    il = [1, 2, 3].stream().map(e -> e + 2).toList()
-                }
-            }
-
-            assert [2, 3, 4] == Test1.sl
-            assert [3, 4, 5] == new Test1().il
+            assert [1, 2, 
3].stream().map(Object::toString).collect(Collectors.toList()) == ['1', '2', 
'3']
+            assert [1, 2, 3].stream().map( e -> e.plus(1) 
).collect(Collectors.toList()) == [2, 3, 4]
+            assert [1, 2, 
3].stream().map(Object::toString).collect(Collectors.toList()) == ['1', '2', 
'3']
         '''
     }
 
@@ -1012,7 +1011,26 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9332
+    @Test
+    void testInitializeBlocks() {
+        assertScript shell, '''
+            class Test1 {
+                static sl
+                def il
+                static { sl = [1, 2, 3].stream().map(e -> e + 1).toList() }
+
+                {
+                    il = [1, 2, 3].stream().map(e -> e + 2).toList()
+                }
+            }
+
+            assert [2, 3, 4] == Test1.sl
+            assert [3, 4, 5] == new Test1().il
+        '''
+    }
+
+    // GROOVY-9332
+    @Test
     void testStaticInitializeBlocks1() {
         assertScript shell, '''
             class Test1 {
@@ -1025,7 +1043,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9347
+    // GROOVY-9347
+    @Test
     void testStaticInitializeBlocks2() {
         assertScript shell, '''
             class Test1 {
@@ -1037,7 +1056,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9342
+    // GROOVY-9342
+    @Test
     void testStaticInitializeBlocks3() {
         assertScript shell, '''
             class Test1 {
@@ -1157,7 +1177,6 @@ final class LambdaTest {
 
             test()
         '''
-
         assert err.message.contains('$Lambda')
     }
 
@@ -1420,7 +1439,6 @@ final class LambdaTest {
                 }
             }
         '''
-
         assert err.message.contains('tests.lambda.C')
     }
 
@@ -1483,7 +1501,6 @@ final class LambdaTest {
                 }
             }
         '''
-
         assert err.message.contains('tests.lambda.C')
     }
 
@@ -1761,7 +1778,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9146
+    // GROOVY-9146
+    @Test
     void testScriptWithExistingMainCS() {
         assertScript shell, '''
             static void main(args) {
@@ -1771,7 +1789,8 @@ final class LambdaTest {
         '''
     }
 
-    @Test // GROOVY-9770
+    // GROOVY-9770
+    @Test
     void testLambdaClassIsntSynthetic() {
         assertScript shell, '''
             class Foo {
diff --git a/src/test/groovy/transform/stc/MethodReferenceTest.groovy b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
index f9c2004357..f189fb6b91 100644
--- a/src/test/groovy/transform/stc/MethodReferenceTest.groovy
+++ b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
@@ -620,6 +620,29 @@ final class MethodReferenceTest {
         '''
     }
 
+    @Test // instance::instanceMethod -- GROOVY-11364
+    void testFunctionII5() {
+        assertScript shell, '''
+            abstract class A<N extends Number> {
+                protected N process(N n) { n }
+            }
+
+            @CompileStatic
+            class C extends A<Integer> {
+                static void consume(Optional<Integer> option) {
+                    def result = option.orElse(null)
+                    assert result instanceof Integer
+                    assert result == 42
+                }
+                void test() {
+                    consume(Optional.of(42).map(this::process))
+                }
+            }
+
+            new C().test()
+        '''
+    }
+
     @Test // instance::instanceMethod -- GROOVY-10057
     void testPredicateII() {
         assertScript shell, '''

Reply via email to