Author: Maciej Fijalkowski <[email protected]>
Branch: 
Changeset: r79720:a1a333d5337a
Date: 2015-09-20 10:43 +0200
http://bitbucket.org/pypy/pypy/changeset/a1a333d5337a/

Log:    Merge remember-tracing-counts. It does not quite remember the
        tracing counts, but does expose enough of an interface to do it
        yourself. Also re-enables the JIT hooks.

diff --git a/pypy/doc/jit-hooks.rst b/pypy/doc/jit-hooks.rst
--- a/pypy/doc/jit-hooks.rst
+++ b/pypy/doc/jit-hooks.rst
@@ -5,19 +5,8 @@
 understanding what's pypy's JIT doing while running your program. There
 are three functions related to that coming from the ``pypyjit`` module:
 
-.. function:: set_optimize_hook(callable)
 
-    Set a compiling hook that will be called each time a loop is optimized,
-    but before assembler compilation. This allows adding additional
-    optimizations on Python level.
-
-    The callable will be called with the ``pypyjit.JitLoopInfo`` object.
-    Refer to it's documentation for details.
-
-    Result value will be the resulting list of operations, or None
-
-
-.. function:: set_compile_hook(callable)
+.. function:: set_compile_hook(callable, operations=True)
 
     Set a compiling hook that will be called each time a loop is compiled.
 
@@ -28,6 +17,9 @@
     inside the jit hook is itself jitted, it will get compiled, but the
     jit hook won't be called for that.
 
+    If ``operations=False``, no list of operations will be available. Useful
+    if the hook is supposed to be very lightweight.
+
 .. function:: set_abort_hook(hook)
 
     Set a hook (callable) that will be called each time there is tracing
@@ -66,3 +58,25 @@
 
     * ``loop_run_times`` - counters for number of times loops are run, only
       works when ``enable_debug`` is called.
+
+.. class:: JitLoopInfo
+
+   A class containing information about the compiled loop. Usable attributes:
+
+   * ``operations`` - list of operations, if requested
+
+   * ``jitdriver_name`` - the name of jitdriver associated with this loop
+
+   * ``greenkey`` - the green key at which the loop was compiled (e.g. a
+     ``(pycode, code position, is_being_profiled)`` tuple for the Python jitdriver)
+
+   * ``loop_no`` - ordinal (sequence) number of the loop
+
+   * ``bridge_no`` - id of the fail descr
+
+   * ``type`` - "entry bridge", "loop" or "bridge"
+
+   * ``asmaddr`` - an address in raw memory where assembler resides
+
+   * ``asmlen`` - length of the raw memory region holding the associated assembler
+
diff --git a/pypy/goal/targetpypystandalone.py b/pypy/goal/targetpypystandalone.py
--- a/pypy/goal/targetpypystandalone.py
+++ b/pypy/goal/targetpypystandalone.py
@@ -341,8 +341,8 @@
 
     def jitpolicy(self, driver):
         from pypy.module.pypyjit.policy import PyPyJitPolicy
-        #from pypy.module.pypyjit.hooks import pypy_hooks
-        return PyPyJitPolicy()#pypy_hooks)
+        from pypy.module.pypyjit.hooks import pypy_hooks
+        return PyPyJitPolicy(pypy_hooks)
 
     def get_entry_point(self, config):
         from pypy.tool.lib_pypy import import_from_lib_pypy
diff --git a/pypy/module/pypyjit/__init__.py b/pypy/module/pypyjit/__init__.py
--- a/pypy/module/pypyjit/__init__.py
+++ b/pypy/module/pypyjit/__init__.py
@@ -8,16 +8,18 @@
         'set_param':    'interp_jit.set_param',
         'residual_call': 'interp_jit.residual_call',
         'not_from_assembler': 'interp_jit.W_NotFromAssembler',
-        #'set_compile_hook': 'interp_resop.set_compile_hook',
-        #'set_optimize_hook': 'interp_resop.set_optimize_hook',
-        #'set_abort_hook': 'interp_resop.set_abort_hook',
-        #'get_stats_snapshot': 'interp_resop.get_stats_snapshot',
-        #'enable_debug': 'interp_resop.enable_debug',
-        #'disable_debug': 'interp_resop.disable_debug',
-        #'ResOperation': 'interp_resop.WrappedOp',
-        #'DebugMergePoint': 'interp_resop.DebugMergePoint',
-        #'JitLoopInfo': 'interp_resop.W_JitLoopInfo',
-        #'Box': 'interp_resop.WrappedBox',
+        'get_jitcell_at_key': 'interp_jit.get_jitcell_at_key',
+        'dont_trace_here': 'interp_jit.dont_trace_here',
+        'trace_next_iteration': 'interp_jit.trace_next_iteration',
+        'trace_next_iteration_hash': 'interp_jit.trace_next_iteration_hash',
+        'set_compile_hook': 'interp_resop.set_compile_hook',
+        'set_abort_hook': 'interp_resop.set_abort_hook',
+        'get_stats_snapshot': 'interp_resop.get_stats_snapshot',
+        'enable_debug': 'interp_resop.enable_debug',
+        'disable_debug': 'interp_resop.disable_debug',
+        'ResOperation': 'interp_resop.WrappedOp',
+        'DebugMergePoint': 'interp_resop.DebugMergePoint',
+        'JitLoopInfo': 'interp_resop.W_JitLoopInfo',
         'PARAMETER_DOCS': 'space.wrap(rpython.rlib.jit.PARAMETER_DOCS)',
     }
 
diff --git a/pypy/module/pypyjit/hooks.py b/pypy/module/pypyjit/hooks.py
--- a/pypy/module/pypyjit/hooks.py
+++ b/pypy/module/pypyjit/hooks.py
@@ -35,10 +35,10 @@
         self._compile_hook(debug_info, is_bridge=True)
 
     def before_compile(self, debug_info):
-        self._optimize_hook(debug_info, is_bridge=False)
+        pass
 
     def before_compile_bridge(self, debug_info):
-        self._optimize_hook(debug_info, is_bridge=True)
+        pass
 
     def _compile_hook(self, debug_info, is_bridge):
         space = self.space
@@ -46,7 +46,8 @@
         if cache.in_recursion:
             return
         if space.is_true(cache.w_compile_hook):
-            w_debug_info = W_JitLoopInfo(space, debug_info, is_bridge)
+            w_debug_info = W_JitLoopInfo(space, debug_info, is_bridge,
+                                         cache.compile_hook_with_ops)
             cache.in_recursion = True
             try:
                 try:
@@ -57,33 +58,4 @@
             finally:
                 cache.in_recursion = False
 
-    def _optimize_hook(self, debug_info, is_bridge=False):
-        space = self.space
-        cache = space.fromcache(Cache)
-        if cache.in_recursion:
-            return
-        if space.is_true(cache.w_optimize_hook):
-            w_debug_info = W_JitLoopInfo(space, debug_info, is_bridge)
-            cache.in_recursion = True
-            try:
-                try:
-                    w_res = space.call_function(cache.w_optimize_hook,
-                                                space.wrap(w_debug_info))
-                    if space.is_w(w_res, space.w_None):
-                        return
-                    l = []
-                    for w_item in space.listview(w_res):
-                        item = space.interp_w(WrappedOp, w_item)
-                        l.append(jit_hooks._cast_to_resop(item.op))
-                    del debug_info.operations[:] # modifying operations above is
-                    # probably not a great idea since types may not work
-                    # and we'll end up with half-working list and
-                    # a segfault/fatal RPython error
-                    for elem in l:
-                        debug_info.operations.append(elem)
-                except OperationError, e:
-                    e.write_unraisable(space, "jit hook ", cache.w_compile_hook)
-            finally:
-                cache.in_recursion = False
-
 pypy_hooks = PyPyJitIface()
diff --git a/pypy/module/pypyjit/interp_jit.py b/pypy/module/pypyjit/interp_jit.py
--- a/pypy/module/pypyjit/interp_jit.py
+++ b/pypy/module/pypyjit/interp_jit.py
@@ -5,11 +5,14 @@
 
 from rpython.rlib.rarithmetic import r_uint, intmask
 from rpython.rlib.jit import JitDriver, hint, we_are_jitted, dont_look_inside
-from rpython.rlib import jit
-from rpython.rlib.jit import current_trace_length, unroll_parameters
+from rpython.rlib import jit, jit_hooks
+from rpython.rlib.jit import current_trace_length, unroll_parameters,\
+     JitHookInterface
+from rpython.rtyper.annlowlevel import cast_instance_to_gcref
 import pypy.interpreter.pyopcode   # for side-effects
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.pycode import CO_GENERATOR, PyCode
+from pypy.interpreter.gateway import unwrap_spec
 from pypy.interpreter.pyframe import PyFrame
 from pypy.interpreter.pyopcode import ExitFrame, Yield
 from pypy.interpreter.baseobjspace import W_Root
@@ -188,3 +191,100 @@
     __call__ = interp2app(W_NotFromAssembler.descr_call),
 )
 W_NotFromAssembler.typedef.acceptable_as_base_class = False
+
+@unwrap_spec(next_instr=int, is_being_profiled=bool, w_pycode=PyCode)
+@dont_look_inside
+def get_jitcell_at_key(space, next_instr, is_being_profiled, w_pycode):
+    ll_pycode = cast_instance_to_gcref(w_pycode)
+    return space.wrap(bool(jit_hooks.get_jitcell_at_key(
+        'pypyjit', r_uint(next_instr), int(is_being_profiled), ll_pycode)))
+
+@unwrap_spec(next_instr=int, is_being_profiled=bool, w_pycode=PyCode)
+@dont_look_inside
+def dont_trace_here(space, next_instr, is_being_profiled, w_pycode):
+    ll_pycode = cast_instance_to_gcref(w_pycode)
+    jit_hooks.dont_trace_here(
+        'pypyjit', r_uint(next_instr), int(is_being_profiled), ll_pycode)
+    return space.w_None
+
+@unwrap_spec(next_instr=int, is_being_profiled=bool, w_pycode=PyCode)
+@dont_look_inside
+def trace_next_iteration(space, next_instr, is_being_profiled, w_pycode):
+    ll_pycode = cast_instance_to_gcref(w_pycode)
+    jit_hooks.trace_next_iteration(
+        'pypyjit', r_uint(next_instr), int(is_being_profiled), ll_pycode)
+    return space.w_None
+
+@unwrap_spec(hash=r_uint)
+@dont_look_inside
+def trace_next_iteration_hash(space, hash):
+    jit_hooks.trace_next_iteration_hash('pypyjit', hash)
+    return space.w_None
+
+# class Cache(object):
+#     in_recursion = False
+
+#     def __init__(self, space):
+#         self.w_compile_bridge = space.w_None
+#         self.w_compile_loop = space.w_None
+
+# def set_compile_bridge(space, w_hook):
+#     cache = space.fromcache(Cache)
+#     assert w_hook is not None
+#     cache.w_compile_bridge = w_hook
+
+# def set_compile_loop(space, w_hook):
+#     from rpython.rlib.nonconst import NonConstant
+    
+#     cache = space.fromcache(Cache)
+#     assert w_hook is not None
+#     cache.w_compile_loop = w_hook
+#     cache.in_recursion = NonConstant(False)
+
+# class PyPyJitHookInterface(JitHookInterface):
+#     def after_compile(self, debug_info):
+#         space = self.space
+#         cache = space.fromcache(Cache)
+#         if cache.in_recursion:
+#             return
+#         l_w = []
+#         if not space.is_true(cache.w_compile_loop):
+#             return
+#         for i, op in enumerate(debug_info.operations):
+#             if op.is_guard():
+#                 w_t = space.newtuple([space.wrap(i), space.wrap(op.getopnum()), space.wrap(op.getdescr().get_jitcounter_hash())])
+#                 l_w.append(w_t)
+#         try:
+#             cache.in_recursion = True
+#             try:
+#                 space.call_function(cache.w_compile_loop, space.newlist(l_w))
+#             except OperationError, e:
+#                 e.write_unraisable(space, "jit hook ", cache.w_compile_bridge)
+#         finally:
+#             cache.in_recursion = False
+
+#     def after_compile_bridge(self, debug_info):
+#         space = self.space
+#         cache = space.fromcache(Cache)
+#         if cache.in_recursion:
+#             return
+#         if not space.is_true(cache.w_compile_bridge):
+#             return
+#         w_hash = space.wrap(debug_info.fail_descr.get_jitcounter_hash())
+#         try:
+#             cache.in_recursion = True
+#             try:
+#                 space.call_function(cache.w_compile_bridge, w_hash)
+#             except OperationError, e:
+#                 e.write_unraisable(space, "jit hook ", cache.w_compile_bridge)
+#         finally:
+#             cache.in_recursion = False
+
+#     def before_compile(self, debug_info):
+#         pass
+
+#     def before_compile_bridge(self, debug_info):
+#         pass
+
+# pypy_hooks = PyPyJitHookInterface()
+
diff --git a/pypy/module/pypyjit/interp_resop.py b/pypy/module/pypyjit/interp_resop.py
--- a/pypy/module/pypyjit/interp_resop.py
+++ b/pypy/module/pypyjit/interp_resop.py
@@ -22,7 +22,6 @@
     def __init__(self, space):
         self.w_compile_hook = space.w_None
         self.w_abort_hook = space.w_None
-        self.w_optimize_hook = space.w_None
 
     def getno(self):
         self.no += 1
@@ -43,8 +42,9 @@
     else:
         return space.wrap(greenkey_repr)
 
-def set_compile_hook(space, w_hook):
-    """ set_compile_hook(hook)
+@unwrap_spec(operations=bool)
+def set_compile_hook(space, w_hook, operations=True):
+    """ set_compile_hook(hook, operations=True)
 
     Set a compiling hook that will be called each time a loop is compiled.
 
@@ -58,25 +58,9 @@
     cache = space.fromcache(Cache)
     assert w_hook is not None
     cache.w_compile_hook = w_hook
+    cache.compile_hook_with_ops = operations
     cache.in_recursion = NonConstant(False)
 
-def set_optimize_hook(space, w_hook):
-    """ set_optimize_hook(hook)
-
-    Set a compiling hook that will be called each time a loop is optimized,
-    but before assembler compilation. This allows adding additional
-    optimizations on Python level.
-
-    The hook will be called with the pypyjit.JitLoopInfo object. Refer to it's
-    docstring for details.
-
-    Result value will be the resulting list of operations, or None
-    """
-    cache = space.fromcache(Cache)
-    cache.w_optimize_hook = w_hook
-    cache.in_recursion = NonConstant(False)
-
-
 def set_abort_hook(space, w_hook):
     """ set_abort_hook(hook)
 
@@ -96,6 +80,9 @@
     cache.in_recursion = NonConstant(False)
 
 def wrap_oplist(space, logops, operations, ops_offset=None):
+    # this function is called from the JIT
+    from rpython.jit.metainterp.resoperation import rop
+    
     l_w = []
     jitdrivers_sd = logops.metainterp_sd.jitdrivers_sd
     for op in operations:
@@ -103,117 +90,58 @@
             ofs = -1
         else:
             ofs = ops_offset.get(op, 0)
-        if op.opnum == rop.DEBUG_MERGE_POINT:
+        num = op.getopnum()
+        name = op.getopname()
+        if num == rop.DEBUG_MERGE_POINT:
             jd_sd = jitdrivers_sd[op.getarg(0).getint()]
             greenkey = op.getarglist()[3:]
             repr = jd_sd.warmstate.get_location_str(greenkey)
             w_greenkey = wrap_greenkey(space, jd_sd.jitdriver, greenkey, repr)
-            l_w.append(DebugMergePoint(space, jit_hooks._cast_to_gcref(op),
+            l_w.append(DebugMergePoint(space, name,
                                        logops.repr_of_resop(op),
                                        jd_sd.jitdriver.name,
                                        op.getarg(1).getint(),
                                        op.getarg(2).getint(),
                                        w_greenkey))
         else:
-            l_w.append(WrappedOp(jit_hooks._cast_to_gcref(op), ofs,
-                                 logops.repr_of_resop(op)))
+            l_w.append(WrappedOp(name, ofs, logops.repr_of_resop(op)))
     return l_w
 
+@unwrap_spec(offset=int, repr=str, name=str)
+def descr_new_resop(space, w_tp, name, offset=-1, repr=''):
+    return WrappedOp(name, offset, repr)
 
-class WrappedBox(W_Root):
-    """ A class representing a single box
-    """
-    def __init__(self, llbox):
-        self.llbox = llbox
-
-    def descr_getint(self, space):
-        if not jit_hooks.box_isint(self.llbox):
-            raise OperationError(space.w_NotImplementedError,
-                                 space.wrap("Box has no int value"))
-        return space.wrap(jit_hooks.box_getint(self.llbox))
-
-@unwrap_spec(no=int)
-def descr_new_box(space, w_tp, no):
-    return WrappedBox(jit_hooks.boxint_new(no))
-
-WrappedBox.typedef = TypeDef(
-    'Box',
-    __new__ = interp2app(descr_new_box),
-    getint = interp2app(WrappedBox.descr_getint),
-)
-
-@unwrap_spec(num=int, offset=int, repr=str, w_res=W_Root)
-def descr_new_resop(space, w_tp, num, w_args, w_res, offset=-1,
-                    repr=''):
-    args = [space.interp_w(WrappedBox, w_arg).llbox for w_arg in
-            space.listview(w_args)]
-    if space.is_none(w_res):
-        llres = jit_hooks.emptyval()
-    else:
-        if not isinstance(w_res, WrappedBox):
-            raise OperationError(space.w_TypeError, space.wrap(
-                "expected box type, got %s" % space.type(w_res)))
-        llres = w_res.llbox
-    return WrappedOp(jit_hooks.resop_new(num, args, llres), offset, repr)
-
-@unwrap_spec(repr=str, jd_name=str, call_depth=int, call_id=int)
-def descr_new_dmp(space, w_tp, w_args, repr, jd_name, call_depth, call_id,
+@unwrap_spec(repr=str, name=str, jd_name=str, call_depth=int, call_id=int)
+def descr_new_dmp(space, w_tp, name, repr, jd_name, call_depth, call_id,
     w_greenkey):
 
-    args = [space.interp_w(WrappedBox, w_arg).llbox for w_arg in
-            space.listview(w_args)]
-    num = rop.DEBUG_MERGE_POINT
-    return DebugMergePoint(space,
-                           jit_hooks.resop_new(num, args, jit_hooks.emptyval()),
+    return DebugMergePoint(space, name,
                            repr, jd_name, call_depth, call_id, w_greenkey)
 
 
 class WrappedOp(W_Root):
     """ A class representing a single ResOperation, wrapped nicely
     """
-    def __init__(self, op, offset, repr_of_resop):
-        self.op = op
+    def __init__(self, name, offset, repr_of_resop):
         self.offset = offset
+        self.name = name
         self.repr_of_resop = repr_of_resop
 
     def descr_repr(self, space):
         return space.wrap(self.repr_of_resop)
 
-    def descr_num(self, space):
-        return space.wrap(jit_hooks.resop_getopnum(self.op))
-
     def descr_name(self, space):
-        return space.wrap(hlstr(jit_hooks.resop_getopname(self.op)))
-
-    @unwrap_spec(no=int)
-    def descr_getarg(self, space, no):
-        try:
-            box = jit_hooks.resop_getarg(self.op, no)
-        except IndexError:
-            raise OperationError(space.w_IndexError,
-                                 space.wrap("Index out of range"))
-        return WrappedBox(box)
-
-    @unwrap_spec(no=int, w_box=WrappedBox)
-    def descr_setarg(self, space, no, w_box):
-        jit_hooks.resop_setarg(self.op, no, w_box.llbox)
-
-    def descr_getresult(self, space):
-        return WrappedBox(jit_hooks.resop_getresult(self.op))
-
-    def descr_setresult(self, space, w_box):
-        box = space.interp_w(WrappedBox, w_box)
-        jit_hooks.resop_setresult(self.op, box.llbox)
+        return space.wrap(self.name)
 
 class DebugMergePoint(WrappedOp):
     """ A class representing Debug Merge Point - the entry point
     to a jitted loop.
     """
 
-    def __init__(self, space, op, repr_of_resop, jd_name, call_depth, call_id,
-        w_greenkey):
+    def __init__(self, space, name, repr_of_resop, jd_name, call_depth,
+                 call_id, w_greenkey):
 
-        WrappedOp.__init__(self, op, -1, repr_of_resop)
+        WrappedOp.__init__(self, name, -1, repr_of_resop)
         self.jd_name = jd_name
         self.call_depth = call_depth
         self.call_id = call_id
@@ -237,12 +165,7 @@
     __doc__ = WrappedOp.__doc__,
     __new__ = interp2app(descr_new_resop),
     __repr__ = interp2app(WrappedOp.descr_repr),
-    num = GetSetProperty(WrappedOp.descr_num),
     name = GetSetProperty(WrappedOp.descr_name),
-    getarg = interp2app(WrappedOp.descr_getarg),
-    setarg = interp2app(WrappedOp.descr_setarg),
-    result = GetSetProperty(WrappedOp.descr_getresult,
-                            WrappedOp.descr_setresult),
     offset = interp_attrproperty("offset", cls=WrappedOp),
 )
 WrappedOp.typedef.acceptable_as_base_class = False
@@ -278,14 +201,18 @@
     asmaddr     = 0
     asmlen      = 0
 
-    def __init__(self, space, debug_info, is_bridge=False):
-        logops = debug_info.logger._make_log_operations()
-        if debug_info.asminfo is not None:
-            ofs = debug_info.asminfo.ops_offset
+    def __init__(self, space, debug_info, is_bridge=False, wrap_ops=True):
+        if wrap_ops:
+            memo = {}
+            logops = debug_info.logger._make_log_operations(memo)
+            if debug_info.asminfo is not None:
+                ofs = debug_info.asminfo.ops_offset
+            else:
+                ofs = {}
+            ops = debug_info.operations
+            self.w_ops = space.newlist(wrap_oplist(space, logops, ops, ofs))
         else:
-            ofs = {}
-        self.w_ops = space.newlist(
-            wrap_oplist(space, logops, debug_info.operations, ofs))
+            self.w_ops = space.w_None
 
         self.jd_name = debug_info.get_jitdriver().name
         self.type = debug_info.type
diff --git a/pypy/module/pypyjit/test/test_jit_hook.py b/pypy/module/pypyjit/test/test_jit_hook.py
--- a/pypy/module/pypyjit/test/test_jit_hook.py
+++ b/pypy/module/pypyjit/test/test_jit_hook.py
@@ -136,7 +136,6 @@
         assert dmp.call_id == 0
         assert dmp.offset == -1
         assert int_add.name == 'int_add'
-        assert int_add.num == self.int_add_num
         assert int_add.offset == 0
         self.on_compile_bridge()
         expected = ('<JitLoopInfo pypyjit, 4 operations, starting at '
@@ -173,10 +172,7 @@
         self.on_compile()
         loop = loops[0]
         op = loop.operations[2]
-        # Should not crash the interpreter
-        raises(IndexError, op.getarg, 2)
         assert op.name == 'guard_nonnull'
-        raises(NotImplementedError, op.getarg(0).getint)
 
     def test_non_reentrant(self):
         import pypyjit
@@ -234,35 +230,28 @@
         assert l == ['pypyjit']
 
     def test_creation(self):
-        from pypyjit import Box, ResOperation
+        from pypyjit import ResOperation
 
-        op = ResOperation(self.int_add_num, [Box(1), Box(3)], Box(4))
-        assert op.num == self.int_add_num
+        op = ResOperation("int_add", -1, "int_add(1, 2)")
         assert op.name == 'int_add'
-        box = op.getarg(0)
-        assert box.getint() == 1
-        box2 = op.result
-        assert box2.getint() == 4
-        op.setarg(0, box2)
-        assert op.getarg(0).getint() == 4
-        op.result = box
-        assert op.result.getint() == 1
+        assert repr(op) == "int_add(1, 2)"
 
     def test_creation_dmp(self):
-        from pypyjit import DebugMergePoint, Box
+        from pypyjit import DebugMergePoint
 
         def f():
             pass
 
-        op = DebugMergePoint([Box(0)], 'repr', 'pypyjit', 2, 3, (f.func_code, 0, 0))
+        op = DebugMergePoint("debug_merge_point", 'repr', 'pypyjit', 2, 3, (f.func_code, 0, 0))
         assert op.bytecode_no == 0
         assert op.pycode is f.func_code
         assert repr(op) == 'repr'
         assert op.jitdriver_name == 'pypyjit'
-        assert op.num == self.dmp_num
+        assert op.name == 'debug_merge_point'
         assert op.call_depth == 2
         assert op.call_id == 3
-        op = DebugMergePoint([Box(0)], 'repr', 'notmain', 5, 4, ('str',))
+        op = DebugMergePoint('debug_merge_point', 'repr', 'notmain',
+                             5, 4, ('str',))
         raises(AttributeError, 'op.pycode')
         assert op.call_depth == 5
 
diff --git a/rpython/jit/metainterp/compile.py b/rpython/jit/metainterp/compile.py
--- a/rpython/jit/metainterp/compile.py
+++ b/rpython/jit/metainterp/compile.py
@@ -740,6 +740,9 @@
         metainterp.handle_guard_failure(self, deadframe)
     _trace_and_compile_from_bridge._dont_inline_ = True
 
+    def get_jitcounter_hash(self):
+        return self.status & self.ST_SHIFT_MASK
+
     def must_compile(self, deadframe, metainterp_sd, jitdriver_sd):
         jitcounter = metainterp_sd.warmrunnerdesc.jitcounter
         #
diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -1582,8 +1582,9 @@
                 resbox = self.metainterp.execute_and_record_varargs(
                     rop.CALL_MAY_FORCE_F, allboxes, descr=descr)
             elif tp == 'v':
-                resbox = self.metainterp.execute_and_record_varargs(
+                self.metainterp.execute_and_record_varargs(
                     rop.CALL_MAY_FORCE_N, allboxes, descr=descr)
+                resbox = None
             else:
                 assert False
             self.metainterp.vrefs_after_residual_call()
@@ -2961,6 +2962,8 @@
         opnum = OpHelpers.call_assembler_for_descr(op.getdescr())
         op = op.copy_and_change(opnum, args=args, descr=token)
         self.history.operations.append(op)
+        if opnum == rop.CALL_ASSEMBLER_N:
+            op = None
         #
         # To fix an obscure issue, make sure the vable stays alive
         # longer than the CALL_ASSEMBLER operation.  We do it by
diff --git a/rpython/jit/metainterp/test/test_jitiface.py b/rpython/jit/metainterp/test/test_jitiface.py
--- a/rpython/jit/metainterp/test/test_jitiface.py
+++ b/rpython/jit/metainterp/test/test_jitiface.py
@@ -1,12 +1,13 @@
 
 import py
-from rpython.rlib.jit import JitDriver, JitHookInterface, Counters
+from rpython.rlib.jit import JitDriver, JitHookInterface, Counters, dont_look_inside
 from rpython.rlib import jit_hooks
 from rpython.jit.metainterp.test.support import LLJitMixin
 from rpython.jit.codewriter.policy import JitPolicy
 from rpython.jit.metainterp.resoperation import rop
-from rpython.rtyper.annlowlevel import hlstr
+from rpython.rtyper.annlowlevel import hlstr, cast_instance_to_gcref
 from rpython.jit.metainterp.jitprof import Profiler, EmptyProfiler
+from rpython.jit.codewriter.policy import JitPolicy
 
 
 class JitHookInterfaceTests(object):
@@ -156,6 +157,127 @@
             assert jit_hooks.stats_get_times_value(None, Counters.TRACING) == 0
         self.meta_interp(main, [], ProfilerClass=EmptyProfiler)
 
+    def test_get_jitcell_at_key(self):
+        driver = JitDriver(greens = ['s'], reds = ['i'], name='jit')
+
+        def loop(i, s):
+            while i > s:
+                driver.jit_merge_point(i=i, s=s)
+                i -= 1
+
+        def main(s):
+            loop(30, s)
+            assert jit_hooks.get_jitcell_at_key("jit", s)
+            assert not jit_hooks.get_jitcell_at_key("jit", s + 1)
+            jit_hooks.trace_next_iteration("jit", s + 1)
+            loop(s + 3, s + 1)
+            assert jit_hooks.get_jitcell_at_key("jit", s + 1)
+
+        self.meta_interp(main, [5])
+        self.check_jitcell_token_count(2)
+
+    def test_get_jitcell_at_key_ptr(self):
+        driver = JitDriver(greens = ['s'], reds = ['i'], name='jit')
+
+        class Green(object):
+            pass
+
+        def loop(i, s):
+            while i > 0:
+                driver.jit_merge_point(i=i, s=s)
+                i -= 1
+
+        def main(s):
+            g1 = Green()
+            g2 = Green()
+            g1_ptr = cast_instance_to_gcref(g1)
+            g2_ptr = cast_instance_to_gcref(g2)
+            loop(10, g1)
+            assert jit_hooks.get_jitcell_at_key("jit", g1_ptr)
+            assert not jit_hooks.get_jitcell_at_key("jit", g2_ptr)
+            jit_hooks.trace_next_iteration("jit", g2_ptr)
+            loop(2, g2)
+            assert jit_hooks.get_jitcell_at_key("jit", g2_ptr)
+
+        self.meta_interp(main, [5])
+        self.check_jitcell_token_count(2)
+
+    def test_dont_trace_here(self):
+        driver = JitDriver(greens = ['s'], reds = ['i', 'k'], name='jit')
+
+        def loop(i, s):
+            k = 4
+            while i > 0:
+                driver.jit_merge_point(k=k, i=i, s=s)
+                if s == 1:
+                    loop(3, 0)
+                k -= 1
+                i -= 1
+                if k == 0:
+                    k = 4
+                    driver.can_enter_jit(k=k, i=i, s=s)
+
+        def main(s, check):
+            if check:
+                jit_hooks.dont_trace_here("jit", 0)
+            loop(30, s)
+
+        self.meta_interp(main, [1, 0], inline=True)
+        self.check_resops(call_assembler_n=0)
+        self.meta_interp(main, [1, 1], inline=True)
+        self.check_resops(call_assembler_n=8)
+
+    def test_trace_next_iteration_hash(self):
+        driver = JitDriver(greens = ['s'], reds = ['i'], name="name")
+        class Hashes(object):
+            check = False
+            
+            def __init__(self):
+                self.l = []
+                self.t = []
+
+        hashes = Hashes()
+
+        class Hooks(object):
+            def before_compile(self, debug_info):
+                pass
+
+            def after_compile(self, debug_info):
+                for op in debug_info.operations:
+                    if op.is_guard():
+                        hashes.l.append(op.getdescr().get_jitcounter_hash())
+
+            def before_compile_bridge(self, debug_info):
+                pass
+
+            def after_compile_bridge(self, debug_info):
+                hashes.t.append(debug_info.fail_descr.get_jitcounter_hash())
+
+        hooks = Hooks()
+
+        @dont_look_inside
+        def foo():
+            if hashes.l:
+                for item in hashes.l:
+                    jit_hooks.trace_next_iteration_hash("name", item)
+
+        def loop(i, s):
+            while i > 0:
+                driver.jit_merge_point(s=s, i=i)
+                foo()
+                if i == 3:
+                    i -= 1
+                i -= 1
+
+        def main(s, check):
+            hashes.check = check
+            loop(10, s)
+
+        self.meta_interp(main, [1, 0], policy=JitPolicy(hooks))
+        assert len(hashes.l) == 4
+        assert len(hashes.t) == 0
+        self.meta_interp(main, [1, 1], policy=JitPolicy(hooks))
+        assert len(hashes.t) == 1
 
 class LLJitHookInterfaceTests(JitHookInterfaceTests):
     # use this for any backend, instead of the super class
diff --git a/rpython/jit/metainterp/warmspot.py b/rpython/jit/metainterp/warmspot.py
--- a/rpython/jit/metainterp/warmspot.py
+++ b/rpython/jit/metainterp/warmspot.py
@@ -1,9 +1,9 @@
-import sys
+import sys, py
 
 from rpython.tool.sourcetools import func_with_new_name
 from rpython.rtyper.lltypesystem import lltype, llmemory
 from rpython.rtyper.annlowlevel import (llhelper, MixLevelHelperAnnotator,
-    cast_base_ptr_to_instance, hlstr)
+    cast_base_ptr_to_instance, hlstr, cast_instance_to_gcref)
 from rpython.rtyper.llannotation import lltype_to_annotation
 from rpython.annotator import model as annmodel
 from rpython.rtyper.llinterp import LLException
@@ -129,6 +129,17 @@
                     results.append((graph, block, i))
     return results
 
+def _find_jit_markers(graphs, marker_names):
+    results = []
+    for graph in graphs:
+        for block in graph.iterblocks():
+            for i in range(len(block.operations)):
+                op = block.operations[i]
+                if (op.opname == 'jit_marker' and
+                    op.args[0].value in marker_names):
+                    results.append((graph, block, i))
+    return results
+
 def find_can_enter_jit(graphs):
     return _find_jit_marker(graphs, 'can_enter_jit')
 
@@ -236,6 +247,7 @@
         self.rewrite_can_enter_jits()
         self.rewrite_set_param_and_get_stats()
         self.rewrite_force_virtual(vrefinfo)
+        self.rewrite_jitcell_accesses()
         self.rewrite_force_quasi_immutable()
         self.add_finish()
         self.metainterp_sd.finish_setup(self.codewriter)
@@ -598,6 +610,80 @@
         (_, jd._PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
             [llmemory.GCREF, llmemory.GCREF], ASMRESTYPE)
 
+    def rewrite_jitcell_accesses(self):
+        jitdrivers_by_name = {}
+        for jd in self.jitdrivers_sd:
+            name = jd.jitdriver.name
+            if name != 'jitdriver':
+                jitdrivers_by_name[name] = jd
+        m = _find_jit_markers(self.translator.graphs,
+                              ('get_jitcell_at_key', 'trace_next_iteration',
+                               'dont_trace_here', 'trace_next_iteration_hash'))
+        accessors = {}
+
+        def get_accessor(name, jitdriver_name, function, ARGS, green_arg_spec):
+            a = accessors.get((name, jitdriver_name))
+            if a:
+                return a
+            d = {'function': function,
+                 'cast_instance_to_gcref': cast_instance_to_gcref,
+                 'lltype': lltype}
+            arg_spec = ", ".join([("arg%d" % i) for i in range(len(ARGS))])
+            arg_converters = []
+            for i, spec in enumerate(green_arg_spec):
+                if isinstance(spec, lltype.Ptr):
+                    arg_converters.append("arg%d = lltype.cast_opaque_ptr(type%d, arg%d)" % (i, i, i))
+                    d['type%d' % i] = spec
+            convert = ";".join(arg_converters)
+            if name == 'get_jitcell_at_key':
+                exec py.code.Source("""
+                def accessor(%s):
+                    %s
+                    return cast_instance_to_gcref(function(%s))
+                """ % (arg_spec, convert, arg_spec)).compile() in d
+                FUNC = lltype.Ptr(lltype.FuncType(ARGS, llmemory.GCREF))
+            elif name == "trace_next_iteration_hash":
+                exec py.code.Source("""
+                def accessor(arg0):
+                    function(arg0)
+                """).compile() in d
+                FUNC = lltype.Ptr(lltype.FuncType([lltype.Unsigned],
+                                                  lltype.Void))
+            else:
+                exec py.code.Source("""
+                def accessor(%s):
+                    %s
+                    function(%s)
+                """ % (arg_spec, convert, arg_spec)).compile() in d
+                FUNC = lltype.Ptr(lltype.FuncType(ARGS, lltype.Void))
+            func = d['accessor']
+            ll_ptr = self.helper_func(FUNC, func)
+            accessors[(name, jitdriver_name)] = ll_ptr
+            return ll_ptr
+
+        for graph, block, index in m:
+            op = block.operations[index]
+            jitdriver_name = op.args[1].value
+            JitCell = jitdrivers_by_name[jitdriver_name].warmstate.JitCell
+            ARGS = [x.concretetype for x in op.args[2:]]
+            if op.args[0].value == 'get_jitcell_at_key':
+                func = JitCell.get_jitcell
+            elif op.args[0].value == 'dont_trace_here':
+                func = JitCell.dont_trace_here
+            elif op.args[0].value == 'trace_next_iteration_hash':
+                func = JitCell.trace_next_iteration_hash
+            else:
+                func = JitCell._trace_next_iteration
+            argspec = jitdrivers_by_name[jitdriver_name]._green_args_spec
+            accessor = get_accessor(op.args[0].value,
+                                    jitdriver_name, func,
+                                    ARGS, argspec)
+            v_result = op.result
+            c_accessor = Constant(accessor, concretetype=lltype.Void)
+            newop = SpaceOperation('direct_call', [c_accessor] + op.args[2:],
+                                   v_result)
+            block.operations[index] = newop
+
     def rewrite_can_enter_jits(self):
         sublists = {}
         for jd in self.jitdrivers_sd:
diff --git a/rpython/jit/metainterp/warmstate.py b/rpython/jit/metainterp/warmstate.py
--- a/rpython/jit/metainterp/warmstate.py
+++ b/rpython/jit/metainterp/warmstate.py
@@ -545,12 +545,24 @@
             @staticmethod
             def trace_next_iteration(greenkey):
                 greenargs = unwrap_greenkey(greenkey)
+                JitCell._trace_next_iteration(*greenargs)
+
+            @staticmethod
+            def _trace_next_iteration(*greenargs):
                 hash = JitCell.get_uhash(*greenargs)
                 jitcounter.change_current_fraction(hash, 0.98)
 
             @staticmethod
+            def trace_next_iteration_hash(hash):
+                jitcounter.change_current_fraction(hash, 0.98)
+
+            @staticmethod
             def ensure_jit_cell_at_key(greenkey):
                 greenargs = unwrap_greenkey(greenkey)
+                return JitCell._ensure_jit_cell_at_key(*greenargs)
+
+            @staticmethod
+            def _ensure_jit_cell_at_key(*greenargs):
                 hash = JitCell.get_uhash(*greenargs)
                 cell = jitcounter.lookup_chain(hash)
                 while cell is not None:
@@ -561,6 +573,11 @@
                 newcell = JitCell(*greenargs)
                 jitcounter.install_new_cell(hash, newcell)
                 return newcell
+
+            @staticmethod
+            def dont_trace_here(*greenargs):
+                cell = JitCell._ensure_jit_cell_at_key(*greenargs)
+                cell.flags |= JC_DONT_TRACE_HERE
         #
         self.JitCell = JitCell
         return JitCell
diff --git a/rpython/rlib/jit_hooks.py b/rpython/rlib/jit_hooks.py
--- a/rpython/rlib/jit_hooks.py
+++ b/rpython/rlib/jit_hooks.py
@@ -5,6 +5,7 @@
     cast_base_ptr_to_instance, llstr)
 from rpython.rtyper.extregistry import ExtRegistryEntry
 from rpython.rtyper.lltypesystem import llmemory, lltype
+from rpython.flowspace.model import Constant
 from rpython.rtyper import rclass
 
 
@@ -127,3 +128,33 @@
 @register_helper(lltype.Ptr(LOOP_RUN_CONTAINER))
 def stats_get_loop_run_times(warmrunnerdesc):
     return warmrunnerdesc.metainterp_sd.cpu.get_all_loop_runs()
+
+# ---------------------- jitcell interface ----------------------
+
+def _new_hook(name, resulttype):
+    def hook(name, *greenkey):
+        raise Exception("need to run translated")
+    hook.func_name = name
+
+    class GetJitCellEntry(ExtRegistryEntry):
+        _about_ = hook
+
+        def compute_result_annotation(self, s_name, *args_s):
+            assert s_name.is_constant()
+            return resulttype
+
+        def specialize_call(self, hop):
+            c_jitdriver = Constant(hop.args_s[0].const, concretetype=lltype.Void)
+            c_name = Constant(name, concretetype=lltype.Void)
+            hop.exception_cannot_occur()
+            args_v = [hop.inputarg(arg, arg=i + 1)
+                      for i, arg in enumerate(hop.args_r[1:])]
+            return hop.genop('jit_marker', [c_name, c_jitdriver] + args_v,
+                             resulttype=hop.r_result)
+
+    return hook
+
+get_jitcell_at_key = _new_hook('get_jitcell_at_key', SomePtr(llmemory.GCREF))
+trace_next_iteration = _new_hook('trace_next_iteration', None)
+dont_trace_here = _new_hook('dont_trace_here', None)
+trace_next_iteration_hash = _new_hook('trace_next_iteration_hash', None)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to