Author: Armin Rigo <[email protected]>
Branch: conditional_call_value
Changeset: r79454:f010addba075
Date: 2015-09-05 12:33 +0200
http://bitbucket.org/pypy/pypy/changeset/f010addba075/

Log:    in-progress: tweak the user API

diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -1106,67 +1106,64 @@
         return hop.genop('jit_record_known_class', [v_inst, v_cls],
                          resulttype=lltype.Void)
 
-def _jit_conditional_call(condition, function, *args):
-    pass
 
[email protected]_location()
-def conditional_call(condition, function, *args):
+def conditional_call(condition, function, *args, **kwds):
+    default = kwds.pop('default', None)
+    assert not kwds
+    if condition:
+        return function(*args)
+    return default
+
+def _ll_cond_call(condition, ll_default, ll_function, *ll_args):
     if we_are_jitted():
-        _jit_conditional_call(condition, function, *args)
+        from rpython.rtyper.lltypesystem import lltype
+        from rpython.rtyper.lltypesystem.lloperation import llop
+        RESTYPE = lltype.typeOf(ll_default)
+        return llop.jit_conditional_call(RESTYPE, condition, ll_default,
+                                         ll_function, *ll_args)
     else:
         if condition:
-            return function(*args)
-conditional_call._always_inline_ = True
+            return ll_function(*ll_args)
+        return ll_default
+_ll_cond_call._always_inline_ = True
 
 class ConditionalCallEntry(ExtRegistryEntry):
-    _about_ = _jit_conditional_call
+    _about_ = conditional_call
 
-    def compute_result_annotation(self, *args_s):
-        self.bookkeeper.emulate_pbc_call(self.bookkeeper.position_key,
-                                         args_s[1], args_s[2:])
+    def compute_result_annotation(self, *args_s, **kwds_s):
+        from rpython.annotator import model as annmodel
 
-    def specialize_call(self, hop):
+        s_res = self.bookkeeper.emulate_pbc_call(self.bookkeeper.position_key,
+                                                 args_s[1], args_s[2:])
+        if 's_default' in kwds_s:
+            assert kwds_s.keys() == ['s_default']
+            return annmodel.unionof(s_res, kwds_s['s_default'])
+        else:
+            assert not kwds_s
+            return None
+
+    def specialize_call(self, hop, i_default=None):
         from rpython.rtyper.lltypesystem import lltype
 
-        args_v = hop.inputargs(lltype.Bool, lltype.Void, *hop.args_r[2:])
+        end = len(hop.args_r) - (i_default is not None)
+        inputargs = [lltype.Bool, lltype.Void] + hop.args_r[2:end]
+        if i_default is not None:
+            assert i_default == end
+            inputargs.append(hop.r_result)
+
+        args_v = hop.inputargs(*inputargs)
         args_v[1] = hop.args_r[1].get_concrete_llfn(hop.args_s[1],
-                                                    hop.args_s[2:], hop.spaceop)
+                                                    hop.args_s[2:end],
+                                                    hop.spaceop)
+        if i_default is not None:
+            v_default = args_v.pop()
+        else:
+            v_default = hop.inputconst(lltype.Void, None)
+        args_v.insert(1, v_default)
+
         hop.exception_is_here()
-        return hop.genop('jit_conditional_call', args_v)
+        return hop.gendirectcall(_ll_cond_call, *args_v)
 
-def _jit_conditional_call_value(condition, function, default_value, *args):
-    return default_value
-
[email protected]_location()
-def conditional_call_value(condition, function, default_value, *args):
-    if we_are_jitted():
-        return _jit_conditional_call_value(condition, function, default_value,
-                                           *args)
-    else:
-        if condition:
-            return function(*args)
-        return default_value
-conditional_call._always_inline_ = True
-
-class ConditionalCallValueEntry(ExtRegistryEntry):
-    _about_ = _jit_conditional_call_value
-
-    def compute_result_annotation(self, *args_s):
-        s_result = self.bookkeeper.emulate_pbc_call(
-            self.bookkeeper.position_key, args_s[1], args_s[3:],
-            callback = self.bookkeeper.position_key)
-        return s_result
-
-    def specialize_call(self, hop):
-        from rpython.rtyper.lltypesystem import lltype
-
-        args_v = hop.inputargs(lltype.Bool, lltype.Void, *hop.args_r[2:])
-        args_v[1] = hop.args_r[1].get_concrete_llfn(hop.args_s[1],
-                                                    hop.args_s[3:], hop.spaceop)
-        hop.exception_is_here()
-        resulttype = hop.r_result
-        return hop.genop('jit_conditional_call_value', args_v,
-                         resulttype=resulttype)
 
 class Counters(object):
     counters="""
diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py
--- a/rpython/rlib/test/test_jit.py
+++ b/rpython/rlib/test/test_jit.py
@@ -300,3 +300,38 @@
         mix = MixLevelHelperAnnotator(t.rtyper)
         mix.getgraph(later, [annmodel.s_Bool], annmodel.s_None)
         mix.finish()
+
+    def test_conditional_call_value(self):
+        def g(m):
+            return m + 42
+        def f(n, m):
+            return conditional_call(n >= 0, g, m, default=678)
+
+        res = self.interpret(f, [10, 20])
+        assert res == 20 + 42
+        res = self.interpret(f, [-10, 20])
+        assert res == 678
+
+    def test_conditional_call_void(self):
+        class X:
+            pass
+        glob = X()
+        #
+        def g(m):
+            glob.x += m
+        #
+        def h():
+            glob.x += 2
+        #
+        def f(n, m):
+            glob.x = 0
+            conditional_call(n >= 0, g, m)
+            conditional_call(n >= 5, h)
+            return glob.x
+
+        res = self.interpret(f, [10, 20])
+        assert res == 22
+        res = self.interpret(f, [2, 20])
+        assert res == 20
+        res = self.interpret(f, [-2, 20])
+        assert res == 0
diff --git a/rpython/rtyper/llinterp.py b/rpython/rtyper/llinterp.py
--- a/rpython/rtyper/llinterp.py
+++ b/rpython/rtyper/llinterp.py
@@ -548,9 +548,6 @@
     def op_jit_conditional_call(self, *args):
         raise NotImplementedError("should not be called while not jitted")
 
-    def op_jit_conditional_call_value(self, *args):
-        raise NotImplementedError("should not be called while not jitted")
-
     def op_get_exception_addr(self, *args):
         raise NotImplementedError
 
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -451,8 +451,7 @@
     'jit_force_quasi_immutable': LLOp(canrun=True),
     'jit_record_known_class'  : LLOp(canrun=True),
     'jit_ffi_save_result':  LLOp(canrun=True),
-    'jit_conditional_call':       LLOp(),
-    'jit_conditional_call_value': LLOp(),
+    'jit_conditional_call': LLOp(),
     'get_exception_addr':   LLOp(),
     'get_exc_value_addr':   LLOp(),
     'do_malloc_fixedsize':LLOp(canmallocgc=True),
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -374,7 +374,8 @@
         if not s:
             return 0
         x = s.hash
-        return jit.conditional_call_value(x == 0, LLHelpers._ll_strhash, x, s)
+        return jit.conditional_call(x == 0, LLHelpers._ll_strhash, s,
+                                    default=x)
 
     @staticmethod
     def ll_length(s):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to