Author: Carl Friedrich Bolz-Tereick <[email protected]>
Branch: py3.7
Changeset: r98449:b7258d53fa17
Date: 2020-01-05 19:22 +0100
http://bitbucket.org/pypy/pypy/changeset/b7258d53fa17/

Log:    Add more unparse features (Constant, NameConstant, DictComp, Attribute,
        slice steps) and implement a visitor that replaces annotations with
        constant strings.

diff --git a/pypy/interpreter/astcompiler/test/test_unparse.py b/pypy/interpreter/astcompiler/test/test_unparse.py
--- a/pypy/interpreter/astcompiler/test/test_unparse.py
+++ b/pypy/interpreter/astcompiler/test/test_unparse.py
@@ -1,7 +1,7 @@
 from pypy.interpreter.pyparser import pyparse
 from pypy.interpreter.astcompiler.astbuilder import ast_from_node
 from pypy.interpreter.astcompiler import ast, consts
-from pypy.interpreter.astcompiler.unparse import unparse
+from pypy.interpreter.astcompiler.unparse import unparse, unparse_annotations
 
 
 class TestAstUnparser:
@@ -33,6 +33,11 @@
             ast = self.get_first_expr(unparsed)
             assert unparse(self.space, ast) == unparsed
 
+    def test_constant(self):
+        w_one = self.space.newint(1)
+        node = ast.Constant(w_one, 0, 0)
+        assert unparse(self.space, node) == "1"
+
     def test_num(self):
         self.check("1")
         self.check("1.64")
@@ -46,6 +51,11 @@
     def test_name(self):
         self.check('a')
 
+    def test_name_constant(self):
+        self.check('True')
+        self.check('False')
+        self.check('None')
+
     def test_unaryop(self):
         self.check('+a')
         self.check('-a')
@@ -96,7 +106,10 @@
         self.check('(a for x in y for z in b)')
 
     def test_set_comprehension(self):
-        self.check('{a for (x,) in y for z in b}')
+        self.check('{a for x, in y for z in b}')
+
+    def test_dict_comprehension(self):
+        self.check('{a: b for x in y}')
 
     def test_ellipsis(self):
         self.check('...')
@@ -105,6 +118,12 @@
         self.check('a[1]')
         self.check('a[1:5]')
         self.check('a[1:5,7:12,:,5]')
+        self.check('a[::1]')
+        self.check('dict[(str, int)]', 'dict[str, int]')
+
+    def test_attribute(self):
+        self.check('a.b.c')
+        self.check('1 .b')
 
     def test_yield(self):
         self.check('(yield)')
@@ -131,6 +150,27 @@
         self.check('lambda *, m, l=5: 1')
         self.check('lambda **foo: 1')
 
-    def test_fstring(self):
-        self.check('f"a{a + 2}b c{d}"')
 
+class TestAstUnparseAnnotations(object):
+    def setup_class(cls):
+        cls.parser = pyparse.PythonParser(cls.space)
+
+    def get_ast(self, source, p_mode="exec", flags=None):
+        if flags is None:
+            flags = consts.CO_FUTURE_WITH_STATEMENT
+        info = pyparse.CompileInfo("<test>", p_mode, flags)
+        tree = self.parser.parse_source(source, info)
+        ast_node = ast_from_node(self.space, tree, info, self.parser)
+        return ast_node
+
+    def test_function(self):
+        ast = self.get_ast("""def f(a: b) -> 1 + 2: return a + 12""")
+        func = ast.body[0]
+        res = unparse_annotations(self.space, func)
+        assert self.space.text_w(res.args.args[0].annotation.value) == "b"
+        assert self.space.text_w(res.returns.value) == "1 + 2"
+
+    def test_global(self):
+        ast = self.get_ast("""a: list[int]""")
+        res = unparse_annotations(self.space, ast)
+        assert self.space.text_w(res.body[0].annotation.value) == 'list[int]'
diff --git a/pypy/interpreter/astcompiler/unparse.py b/pypy/interpreter/astcompiler/unparse.py
--- a/pypy/interpreter/astcompiler/unparse.py
+++ b/pypy/interpreter/astcompiler/unparse.py
@@ -1,7 +1,9 @@
 from rpython.rlib.rutf8 import Utf8StringBuilder
-from pypy.interpreter.error import oefmt
+from rpython.rlib.objectmodel import specialize
+from pypy.interpreter.error import oefmt, OperationError
 from pypy.interpreter.astcompiler import ast
 
+
 PRIORITY_TUPLE = 0
 PRIORITY_TEST = 1                   # 'if'-'else', 'lambda'
 PRIORITY_OR = 2                     # 'or'
@@ -73,12 +75,20 @@
         return False
 
     def default_visitor(self, node):
-        raise oefmt(self.space.w_SystemError,
-                    "%T is not an expression", node)
+        raise OperationError(self.space.w_SystemError,
+                    self.space.newtext("expression type not supported yet" + 
str(node)))
 
     def visit_Ellipsis(self, node):
         self.append_ascii('...')
 
+    def visit_Constant(self, node):
+        w_str = self.space.str(node.value)
+        self.append_w_str(w_str)
+
+    def visit_NameConstant(self, node):
+        w_str = self.space.str(node.value)
+        self.append_w_str(w_str)
+
     def visit_Num(self, node):
         w_str = self.space.str(node.n)
         self.append_w_str(w_str)
@@ -235,14 +245,13 @@
         if node.elts is None:
             self.append_ascii("()")
             return
-        self.append_ascii("(")
-        for i, elt in enumerate(node.elts):
-            if i > 0:
-                self.append_ascii(", ")
-            self.append_expr(elt)
-        if len(node.elts) == 1:
-            self.append_ascii(",")
-        self.append_ascii(")")
+        with self.maybe_parenthesize(PRIORITY_TUPLE):
+            for i, elt in enumerate(node.elts):
+                if i > 0:
+                    self.append_ascii(", ")
+                self.append_expr(elt)
+            if len(node.elts) == 1:
+                self.append_ascii(",")
 
     def visit_Set(self, node):
         self.append_ascii("{")
@@ -272,6 +281,7 @@
 
     def append_generators(self, generators):
         for generator in generators:
+            assert isinstance(generator, ast.comprehension)
             if generator.is_async:
                 self.append_ascii(' async for ')
             else:
@@ -302,6 +312,14 @@
         self.append_generators(node.generators)
         self.append_ascii('}')
 
+    def visit_DictComp(self, node):
+        self.append_ascii('{')
+        self.append_expr(node.key)
+        self.append_ascii(': ')
+        self.append_expr(node.value)
+        self.append_generators(node.generators)
+        self.append_ascii('}')
+
     def visit_Subscript(self, node):
         self.append_expr(node.value, PRIORITY_ATOM)
         self.append_ascii('[')
@@ -309,7 +327,7 @@
         self.append_ascii(']')
 
     def visit_Index(self, node):
-        self.append_expr(node.value)
+        self.append_expr(node.value, PRIORITY_TUPLE)
 
     def visit_Slice(self, node):
         if node.lower:
@@ -319,7 +337,7 @@
             self.append_expr(node.upper)
         if node.step:
             self.append_ascii(':')
-            self.append_expr(node.upper)
+            self.append_expr(node.step)
 
     def visit_ExtSlice(self, node):
         for i, slice in enumerate(node.dims):
@@ -327,6 +345,17 @@
                 self.append_ascii(',')
             self.append_expr(slice)
 
+    def visit_Attribute(self, node):
+        value = node.value
+        self.append_expr(value, PRIORITY_ATOM)
+        if isinstance(value, ast.Num) and \
+                self.space.isinstance_w(value.n, self.space.w_int):
+            period = ' .'
+        else:
+            period = '.'
+        self.append_ascii(period)
+        self.append_utf8(node.attr)
+
     def visit_Yield(self, node):
         if node.value:
             self.append_ascii("(yield ")
@@ -358,6 +387,7 @@
         if node.keywords:
             for i, keyword in enumerate(node.keywords):
                 first = self.append_if_not_first(first, ', ')
+                assert isinstance(keyword, ast.keyword)
                 if keyword.arg is None:
                     self.append_ascii('**')
                 else:
@@ -417,7 +447,49 @@
                 self.append_ascii(': ')
             self.append_expr(node.body)
 
+
 def unparse(space, ast):
     visitor = UnparseVisitor(space)
     ast.walkabout(visitor)
     return visitor.builder.build()
+
+def w_unparse(space, ast):
+    visitor = UnparseVisitor(space)
+    ast.walkabout(visitor)
+    return space.newutf8(visitor.builder.build(), visitor.builder.getlength())
+
+class UnparseAnnotationsVisitor(ast.ASTVisitor):
+    def __init__(self, space):
+        self.space = space
+
+    @specialize.argtype(1)
+    def default_visitor(self, node):
+        return node
+
+    def unparse(self, node):
+        return ast.Constant(
+                    w_unparse(self.space, node),
+                    node.lineno,
+                    node.col_offset)
+
+    def visit_arg(self, node):
+        annotation = node.annotation
+        if annotation:
+            node.annotation = self.unparse(annotation)
+        return node
+
+    def visit_FunctionDef(self, node):
+        returns = node.returns
+        if returns:
+            node.returns = self.unparse(returns)
+        return node
+
+    def visit_AnnAssign(self, node):
+        node.annotation = self.unparse(node.annotation)
+        return node
+
+def unparse_annotations(space, ast):
+    visitor = UnparseAnnotationsVisitor(space)
+    return ast.mutate_over(visitor)
+
+
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to