Author: Armin Rigo <ar...@tunes.org> Branch: Changeset: r87044:6305cfb3bad2 Date: 2016-09-12 16:31 +0200 http://bitbucket.org/pypy/pypy/changeset/6305cfb3bad2/
Log: hg merge conditional_call_value_3 JIT residual calls: if the called function starts with a fast-path like "if x.foo != 0: return x.foo", then inline the check before doing the CALL. Right now only implemented on the x86 backend. Other backends specify supports_cond_call_value = False. diff --git a/pypy/module/pypyjit/test_pypy_c/test_containers.py b/pypy/module/pypyjit/test_pypy_c/test_containers.py --- a/pypy/module/pypyjit/test_pypy_c/test_containers.py +++ b/pypy/module/pypyjit/test_pypy_c/test_containers.py @@ -67,7 +67,7 @@ p10 = call_r(ConstClass(ll_str__IntegerR_SignedConst_Signed), i5, descr=<Callr . i EF=3>) guard_no_exception(descr=...) guard_nonnull(p10, descr=...) - i12 = call_i(ConstClass(ll_strhash), p10, descr=<Calli . r EF=0>) + i12 = call_i(ConstClass(_ll_strhash__rpy_stringPtr), p10, descr=<Calli . r EF=0>) p13 = new(descr=...) p15 = new_array_clear(16, descr=<ArrayU 1>) {{{ diff --git a/rpython/jit/backend/arm/regalloc.py b/rpython/jit/backend/arm/regalloc.py --- a/rpython/jit/backend/arm/regalloc.py +++ b/rpython/jit/backend/arm/regalloc.py @@ -1002,6 +1002,9 @@ prepare_op_cond_call_gc_wb_array = prepare_op_cond_call_gc_wb def prepare_op_cond_call(self, op, fcond): + # XXX don't force the arguments to be loaded in specific + # locations before knowing if we can take the fast path + # XXX add cond_call_value support assert 2 <= op.numargs() <= 4 + 2 tmpreg = self.get_scratch_reg(INT, selected_reg=r.r4) v = op.getarg(1) diff --git a/rpython/jit/backend/llgraph/runner.py b/rpython/jit/backend/llgraph/runner.py --- a/rpython/jit/backend/llgraph/runner.py +++ b/rpython/jit/backend/llgraph/runner.py @@ -325,6 +325,7 @@ supports_longlong = r_uint is not r_ulonglong supports_singlefloats = True supports_guard_gc_type = True + supports_cond_call_value = True translate_support_code = False is_llgraph = True vector_extension = True @@ -1334,6 +1335,16 @@ # cond_call can't have a return value self.execute_call_n(calldescr, func, *args) + def 
execute_cond_call_value_i(self, calldescr, value, func, *args): + if not value: + value = self.execute_call_i(calldescr, func, *args) + return value + + def execute_cond_call_value_r(self, calldescr, value, func, *args): + if not value: + value = self.execute_call_r(calldescr, func, *args) + return value + def _execute_call(self, calldescr, func, *args): effectinfo = calldescr.get_extra_info() if effectinfo is not None and hasattr(effectinfo, 'oopspecindex'): diff --git a/rpython/jit/backend/llsupport/regalloc.py b/rpython/jit/backend/llsupport/regalloc.py --- a/rpython/jit/backend/llsupport/regalloc.py +++ b/rpython/jit/backend/llsupport/regalloc.py @@ -759,6 +759,8 @@ if (opnum != rop.GUARD_TRUE and opnum != rop.GUARD_FALSE and opnum != rop.COND_CALL): return False + # NB: don't list COND_CALL_VALUE_I/R here, these two variants + # of COND_CALL don't accept a cc as input if next_op.getarg(0) is not op: return False if self.longevity[op][1] > i + 1: diff --git a/rpython/jit/backend/llsupport/rewrite.py b/rpython/jit/backend/llsupport/rewrite.py --- a/rpython/jit/backend/llsupport/rewrite.py +++ b/rpython/jit/backend/llsupport/rewrite.py @@ -11,7 +11,7 @@ from rpython.jit.backend.llsupport.symbolic import (WORD, get_array_token) from rpython.jit.backend.llsupport.descr import SizeDescr, ArrayDescr,\ - FLAG_POINTER + FLAG_POINTER, CallDescr from rpython.jit.metainterp.history import JitCellToken from rpython.jit.backend.llsupport.descr import (unpack_arraydescr, unpack_fielddescr, unpack_interiorfielddescr) @@ -370,7 +370,9 @@ self.consider_setfield_gc(op) elif op.getopnum() == rop.SETARRAYITEM_GC: self.consider_setarrayitem_gc(op) - # ---------- call assembler ----------- + # ---------- calls ----------- + if OpHelpers.is_plain_call(op.getopnum()): + self.expand_call_shortcut(op) if OpHelpers.is_call_assembler(op.getopnum()): self.handle_call_assembler(op) continue @@ -616,6 +618,30 @@ self.emit_gc_store_or_indexed(None, ptr, ConstInt(0), value, size, 1, ofs) + def 
expand_call_shortcut(self, op): + if not self.cpu.supports_cond_call_value: + return + descr = op.getdescr() + if descr is None: + return + assert isinstance(descr, CallDescr) + effectinfo = descr.get_extra_info() + if effectinfo is None or effectinfo.call_shortcut is None: + return + if op.type == 'r': + cond_call_opnum = rop.COND_CALL_VALUE_R + elif op.type == 'i': + cond_call_opnum = rop.COND_CALL_VALUE_I + else: + return + cs = effectinfo.call_shortcut + ptr_box = op.getarg(1 + cs.argnum) + value_box = self.emit_getfield(ptr_box, descr=cs.fielddescr, + raw=(ptr_box.type == 'i')) + self.replace_op_with(op, ResOperation(cond_call_opnum, + [value_box] + op.getarglist(), + descr=descr)) + def handle_call_assembler(self, op): descrs = self.gc_ll_descr.getframedescrs(self.cpu) loop_token = op.getdescr() diff --git a/rpython/jit/backend/llsupport/test/test_rewrite.py b/rpython/jit/backend/llsupport/test/test_rewrite.py --- a/rpython/jit/backend/llsupport/test/test_rewrite.py +++ b/rpython/jit/backend/llsupport/test/test_rewrite.py @@ -1,7 +1,8 @@ import py from rpython.jit.backend.llsupport.descr import get_size_descr,\ get_field_descr, get_array_descr, ArrayDescr, FieldDescr,\ - SizeDescr, get_interiorfield_descr + SizeDescr, get_interiorfield_descr, get_call_descr +from rpython.jit.codewriter.effectinfo import EffectInfo, CallShortcut from rpython.jit.backend.llsupport.gc import GcLLDescr_boehm,\ GcLLDescr_framework from rpython.jit.backend.llsupport import jitframe @@ -80,6 +81,14 @@ lltype.malloc(T, zero=True)) self.myT = myT # + call_shortcut = CallShortcut(0, tzdescr) + effectinfo = EffectInfo(None, None, None, None, None, None, + EffectInfo.EF_RANDOM_EFFECTS, + call_shortcut=call_shortcut) + call_shortcut_descr = get_call_descr(self.gc_ll_descr, + [lltype.Ptr(T)], lltype.Signed, + effectinfo) + # A = lltype.GcArray(lltype.Signed) adescr = get_array_descr(self.gc_ll_descr, A) adescr.tid = 4321 @@ -200,6 +209,7 @@ load_constant_offset = True 
load_supported_factors = (1,2,4,8) + supports_cond_call_value = True translate_support_code = None @@ -1429,3 +1439,15 @@ jump() """) assert len(self.gcrefs) == 2 + + def test_handle_call_shortcut(self): + self.check_rewrite(""" + [p0] + i1 = call_i(123, p0, descr=call_shortcut_descr) + jump(i1) + """, """ + [p0] + i2 = gc_load_i(p0, %(tzdescr.offset)s, %(tzdescr.field_size)s) + i1 = cond_call_value_i(i2, 123, p0, descr=call_shortcut_descr) + jump(i1) + """) diff --git a/rpython/jit/backend/model.py b/rpython/jit/backend/model.py --- a/rpython/jit/backend/model.py +++ b/rpython/jit/backend/model.py @@ -16,6 +16,7 @@ # Boxes and Consts are BoxFloats and ConstFloats. supports_singlefloats = False supports_guard_gc_type = False + supports_cond_call_value = False propagate_exception_descr = None diff --git a/rpython/jit/backend/test/runner_test.py b/rpython/jit/backend/test/runner_test.py --- a/rpython/jit/backend/test/runner_test.py +++ b/rpython/jit/backend/test/runner_test.py @@ -2389,7 +2389,7 @@ f2 = longlong.getfloatstorage(3.4) frame = self.cpu.execute_token(looptoken, 1, 0, 1, 2, 3, 4, 5, f1, f2) assert not called - for j in range(5): + for j in range(6): assert self.cpu.get_int_value(frame, j) == j assert longlong.getrealfloat(self.cpu.get_float_value(frame, 6)) == 1.2 assert longlong.getrealfloat(self.cpu.get_float_value(frame, 7)) == 3.4 @@ -2447,6 +2447,54 @@ 67, 89) assert called == [(67, 89)] + def test_cond_call_value(self): + if not self.cpu.supports_cond_call_value: + py.test.skip("missing supports_cond_call_value") + + def func_int(*args): + called.append(args) + return len(args) * 100 + 1000 + + for i in range(5): + called = [] + + FUNC = self.FuncType([lltype.Signed] * i, lltype.Signed) + func_ptr = llhelper(lltype.Ptr(FUNC), func_int) + calldescr = self.cpu.calldescrof(FUNC, FUNC.ARGS, FUNC.RESULT, + EffectInfo.MOST_GENERAL) + + ops = ''' + [i0, i1, i2, i3, i4, i5, i6, f0, f1] + i15 = cond_call_value_i(i1, ConstClass(func_ptr), %s) + 
guard_false(i0, descr=faildescr) [i1,i2,i3,i4,i5,i6,i15, f0,f1] + finish(i15) + ''' % ', '.join(['i%d' % (j + 2) for j in range(i)] + + ["descr=calldescr"]) + loop = parse(ops, namespace={'faildescr': BasicFailDescr(), + 'func_ptr': func_ptr, + 'calldescr': calldescr}) + looptoken = JitCellToken() + self.cpu.compile_loop(loop.inputargs, loop.operations, looptoken) + f1 = longlong.getfloatstorage(1.2) + f2 = longlong.getfloatstorage(3.4) + frame = self.cpu.execute_token(looptoken, 1, 50, 1, 2, 3, 4, 5, + f1, f2) + assert not called + assert [self.cpu.get_int_value(frame, j) for j in range(7)] == [ + 50, 1, 2, 3, 4, 5, 50] + assert longlong.getrealfloat( + self.cpu.get_float_value(frame, 7)) == 1.2 + assert longlong.getrealfloat( + self.cpu.get_float_value(frame, 8)) == 3.4 + # + frame = self.cpu.execute_token(looptoken, 1, 0, 1, 2, 3, 4, 5, + f1, f2) + assert called == [(1, 2, 3, 4)[:i]] + assert [self.cpu.get_int_value(frame, j) for j in range(7)] == [ + 0, 1, 2, 3, 4, 5, i * 100 + 1000] + assert longlong.getrealfloat(self.cpu.get_float_value(frame, 7)) == 1.2 + assert longlong.getrealfloat(self.cpu.get_float_value(frame, 8)) == 3.4 + def test_force_operations_returning_void(self): values = [] def maybe_force(token, flag): diff --git a/rpython/jit/backend/test/test_ll_random.py b/rpython/jit/backend/test/test_ll_random.py --- a/rpython/jit/backend/test/test_ll_random.py +++ b/rpython/jit/backend/test/test_ll_random.py @@ -594,7 +594,7 @@ return subset, d['f'], vtableptr def getresulttype(self): - if self.opnum == rop.CALL_I: + if self.opnum == rop.CALL_I or self.opnum == rop.COND_CALL_VALUE_I: return lltype.Signed elif self.opnum == rop.CALL_F: return lltype.Float @@ -712,7 +712,12 @@ class CondCallOperation(BaseCallOperation): def produce_into(self, builder, r): fail_subset = builder.subset_of_intvars(r) - v_cond = builder.get_bool_var(r) + if self.opnum == rop.COND_CALL: + RESULT_TYPE = lltype.Void + v_cond = builder.get_bool_var(r) + else: + RESULT_TYPE = 
lltype.Signed + v_cond = r.choice(builder.intvars) subset = builder.subset_of_intvars(r)[:4] for i in range(len(subset)): if r.random() < 0.35: @@ -724,8 +729,10 @@ seen.append(args) else: assert seen[0] == args + if RESULT_TYPE is lltype.Signed: + return len(args) - 42000 # - TP = lltype.FuncType([lltype.Signed] * len(subset), lltype.Void) + TP = lltype.FuncType([lltype.Signed] * len(subset), RESULT_TYPE) ptr = llhelper(lltype.Ptr(TP), call_me) c_addr = ConstAddr(llmemory.cast_ptr_to_adr(ptr), builder.cpu) args = [v_cond, c_addr] + subset @@ -769,6 +776,7 @@ for i in range(2): OPERATIONS.append(GuardClassOperation(rop.GUARD_CLASS)) OPERATIONS.append(CondCallOperation(rop.COND_CALL)) + OPERATIONS.append(CondCallOperation(rop.COND_CALL_VALUE_I)) OPERATIONS.append(RaisingCallOperation(rop.CALL_N)) OPERATIONS.append(RaisingCallOperationGuardNoException(rop.CALL_N)) OPERATIONS.append(RaisingCallOperationWrongGuardException(rop.CALL_N)) diff --git a/rpython/jit/backend/x86/assembler.py b/rpython/jit/backend/x86/assembler.py --- a/rpython/jit/backend/x86/assembler.py +++ b/rpython/jit/backend/x86/assembler.py @@ -174,8 +174,8 @@ # copy registers to the frame, with the exception of the # 'cond_call_register_arguments' and eax, because these have already # been saved by the caller. Note that this is not symmetrical: - # these 5 registers are saved by the caller but restored here at - # the end of this function. + # these 5 registers are saved by the caller but 4 of them are + # restored here at the end of this function. 
self._push_all_regs_to_frame(mc, cond_call_register_arguments + [eax], supports_floats, callee_only) # the caller already did push_gcmap(store=True) @@ -198,7 +198,7 @@ mc.ADD(esp, imm(WORD * 7)) self.set_extra_stack_depth(mc, 0) self.pop_gcmap(mc) # cancel the push_gcmap(store=True) in the caller - self._pop_all_regs_from_frame(mc, [], supports_floats, callee_only) + self._pop_all_regs_from_frame(mc, [eax], supports_floats, callee_only) mc.RET() return mc.materialize(self.cpu, []) @@ -1703,7 +1703,8 @@ self.implement_guard(guard_token) # If the previous operation was a COND_CALL, overwrite its conditional # jump to jump over this GUARD_NO_EXCEPTION as well, if we can - if self._find_nearby_operation(-1).getopnum() == rop.COND_CALL: + if self._find_nearby_operation(-1).getopnum() in ( + rop.COND_CALL, rop.COND_CALL_VALUE_I, rop.COND_CALL_VALUE_R): jmp_adr = self.previous_cond_call_jcond offset = self.mc.get_relative_pos() - jmp_adr if offset <= 127: @@ -2381,7 +2382,7 @@ def label(self): self._check_frame_depth_debug(self.mc) - def cond_call(self, op, gcmap, imm_func, arglocs): + def cond_call(self, gcmap, imm_func, arglocs, resloc=None): assert self.guard_success_cc >= 0 self.mc.J_il8(rx86.invert_condition(self.guard_success_cc), 0) # patched later @@ -2394,11 +2395,14 @@ # plus the register 'eax' base_ofs = self.cpu.get_baseofs_of_frame_field() should_be_saved = self._regalloc.rm.reg_bindings.values() + restore_eax = False for gpr in cond_call_register_arguments + [eax]: - if gpr not in should_be_saved: + if gpr not in should_be_saved or gpr is resloc: continue v = gpr_reg_mgr_cls.all_reg_indexes[gpr.value] self.mc.MOV_br(v * WORD + base_ofs, gpr.value) + if gpr is eax: + restore_eax = True # # load the 0-to-4 arguments into these registers from rpython.jit.backend.x86.jump import remap_frame_layout @@ -2422,8 +2426,16 @@ floats = True cond_call_adr = self.cond_call_slowpath[floats * 2 + callee_only] self.mc.CALL(imm(follow_jump(cond_call_adr))) + # if this is a 
COND_CALL_VALUE, we need to move the result in place + if resloc is not None and resloc is not eax: + self.mc.MOV(resloc, eax) # restoring the registers saved above, and doing pop_gcmap(), is left - # to the cond_call_slowpath helper. We never have any result value. + # to the cond_call_slowpath helper. We must only restore eax, if + # needed. + if restore_eax: + v = gpr_reg_mgr_cls.all_reg_indexes[eax.value] + self.mc.MOV_rb(eax.value, v * WORD + base_ofs) + # offset = self.mc.get_relative_pos() - jmp_adr assert 0 < offset <= 127 self.mc.overwrite(jmp_adr-1, chr(offset)) diff --git a/rpython/jit/backend/x86/regalloc.py b/rpython/jit/backend/x86/regalloc.py --- a/rpython/jit/backend/x86/regalloc.py +++ b/rpython/jit/backend/x86/regalloc.py @@ -938,16 +938,45 @@ self.rm.force_spill_var(box) assert box not in self.rm.reg_bindings # - assert op.type == 'v' args = op.getarglist() assert 2 <= len(args) <= 4 + 2 # maximum 4 arguments - v = args[1] - assert isinstance(v, Const) - imm_func = self.rm.convert_to_imm(v) + v_func = args[1] + assert isinstance(v_func, Const) + imm_func = self.rm.convert_to_imm(v_func) + + # Delicate ordering here. First get the argument's locations. + # If this also contains args[0], this returns the current + # location too. arglocs = [self.loc(args[i]) for i in range(2, len(args))] gcmap = self.get_gcmap() - self.load_condition_into_cc(op.getarg(0)) - self.assembler.cond_call(op, gcmap, imm_func, arglocs) + + if op.type == 'v': + # a plain COND_CALL. Calls the function when args[0] is + # true. Often used just after a comparison operation. + self.load_condition_into_cc(op.getarg(0)) + resloc = None + else: + # COND_CALL_VALUE_I/R. Calls the function when args[0] + # is equal to 0 or NULL. Returns the result from the + # function call if done, or args[0] if it was not 0/NULL. + # Implemented by forcing the result to live in the same + # register as args[0], and overwriting it if we really do + # the call. 
+ + # Load the register for the result. Possibly reuse 'args[0]'. + # But the old value of args[0], if it survives, is first + # spilled away. We can't overwrite any of op.args[2:] here. + resloc = self.rm.force_result_in_reg(op, args[0], + forbidden_vars=args[2:]) + + # Test the register for the result. + self.assembler.test_location(resloc) + self.assembler.guard_success_cc = rx86.Conditions['Z'] + + self.assembler.cond_call(gcmap, imm_func, arglocs, resloc) + + consider_cond_call_value_i = consider_cond_call + consider_cond_call_value_r = consider_cond_call def consider_call_malloc_nursery(self, op): size_box = op.getarg(0) diff --git a/rpython/jit/backend/x86/runner.py b/rpython/jit/backend/x86/runner.py --- a/rpython/jit/backend/x86/runner.py +++ b/rpython/jit/backend/x86/runner.py @@ -15,6 +15,7 @@ debug = True supports_floats = True supports_singlefloats = True + supports_cond_call_value = True dont_keepalive_stuff = False # for tests with_threads = False diff --git a/rpython/jit/codewriter/call.py b/rpython/jit/codewriter/call.py --- a/rpython/jit/codewriter/call.py +++ b/rpython/jit/codewriter/call.py @@ -7,9 +7,10 @@ from rpython.jit.codewriter.jitcode import JitCode from rpython.jit.codewriter.effectinfo import (VirtualizableAnalyzer, QuasiImmutAnalyzer, RandomEffectsAnalyzer, effectinfo_from_writeanalyze, - EffectInfo, CallInfoCollection) + EffectInfo, CallInfoCollection, CallShortcut) from rpython.rtyper.lltypesystem import lltype, llmemory from rpython.rtyper.lltypesystem.lltype import getfunctionptr +from rpython.flowspace.model import Constant, Variable from rpython.rlib import rposix from rpython.translator.backendopt.canraise import RaiseAnalyzer from rpython.translator.backendopt.writeanalyze import ReadWriteAnalyzer @@ -214,6 +215,7 @@ elidable = False loopinvariant = False call_release_gil_target = EffectInfo._NO_CALL_RELEASE_GIL_TARGET + call_shortcut = None if op.opname == "direct_call": funcobj = op.args[0].value._obj assert getattr(funcobj, 
'calling_conv', 'c') == 'c', ( @@ -228,6 +230,12 @@ tgt_func, tgt_saveerr = func._call_aroundstate_target_ tgt_func = llmemory.cast_ptr_to_adr(tgt_func) call_release_gil_target = (tgt_func, tgt_saveerr) + if hasattr(funcobj, 'graph'): + call_shortcut = self.find_call_shortcut(funcobj.graph) + if getattr(func, "_call_shortcut_", False): + assert call_shortcut is not None, ( + "%r: marked as @jit.call_shortcut but shortcut not found" + % (func,)) elif op.opname == 'indirect_call': # check that we're not trying to call indirectly some # function with the special flags @@ -298,6 +306,7 @@ self.readwrite_analyzer.analyze(op, self.seen_rw), self.cpu, extraeffect, oopspecindex, can_invalidate, call_release_gil_target, extradescr, self.collect_analyzer.analyze(op, self.seen_gc), + call_shortcut, ) # assert effectinfo is not None @@ -368,3 +377,65 @@ if GTYPE_fieldname in jd.greenfield_info.green_fields: return True return False + + def find_call_shortcut(self, graph): + """Identifies graphs that start like this: + + def graph(x, y, z): def graph(x, y, z): + if y.field: r = y.field + return y.field if r: return r + """ + block = graph.startblock + if len(block.operations) == 0: + return + op = block.operations[0] + if op.opname != 'getfield': + return + [v_inst, c_fieldname] = op.args + if not isinstance(v_inst, Variable): + return + v_result = op.result + if v_result.concretetype != graph.getreturnvar().concretetype: + return + if v_result.concretetype == lltype.Void: + return + argnum = i = 0 + while block.inputargs[i] is not v_inst: + if block.inputargs[i].concretetype != lltype.Void: + argnum += 1 + i += 1 + PSTRUCT = v_inst.concretetype + v_check = v_result + fastcase = True + for op in block.operations[1:]: + if (op.opname in ('int_is_true', 'ptr_nonzero', 'same_as') + and v_check is op.args[0]): + v_check = op.result + elif op.opname == 'ptr_iszero' and v_check is op.args[0]: + v_check = op.result + fastcase = not fastcase + elif (op.opname in ('int_eq', 'int_ne') + 
and v_check is op.args[0] + and isinstance(op.args[1], Constant) + and op.args[1].value == 0): + v_check = op.result + if op.opname == 'int_eq': + fastcase = not fastcase + else: + return + if v_check.concretetype is not lltype.Bool: + return + if block.exitswitch is not v_check: + return + + links = [link for link in block.exits if link.exitcase == fastcase] + if len(links) != 1: + return + [link] = links + if link.args != [v_result]: + return + if not link.target.is_final_block(): + return + + fielddescr = self.cpu.fielddescrof(PSTRUCT.TO, c_fieldname.value) + return CallShortcut(argnum, fielddescr) diff --git a/rpython/jit/codewriter/effectinfo.py b/rpython/jit/codewriter/effectinfo.py --- a/rpython/jit/codewriter/effectinfo.py +++ b/rpython/jit/codewriter/effectinfo.py @@ -117,7 +117,8 @@ can_invalidate=False, call_release_gil_target=_NO_CALL_RELEASE_GIL_TARGET, extradescrs=None, - can_collect=True): + can_collect=True, + call_shortcut=None): readonly_descrs_fields = frozenset_or_none(readonly_descrs_fields) readonly_descrs_arrays = frozenset_or_none(readonly_descrs_arrays) readonly_descrs_interiorfields = frozenset_or_none( @@ -135,7 +136,8 @@ extraeffect, oopspecindex, can_invalidate, - can_collect) + can_collect, + call_shortcut) tgt_func, tgt_saveerr = call_release_gil_target if tgt_func: key += (object(),) # don't care about caching in this case @@ -190,6 +192,7 @@ result.oopspecindex = oopspecindex result.extradescrs = extradescrs result.call_release_gil_target = call_release_gil_target + result.call_shortcut = call_shortcut if result.check_can_raise(ignore_memoryerror=True): assert oopspecindex in cls._OS_CANRAISE @@ -275,7 +278,8 @@ call_release_gil_target= EffectInfo._NO_CALL_RELEASE_GIL_TARGET, extradescr=None, - can_collect=True): + can_collect=True, + call_shortcut=None): from rpython.translator.backendopt.writeanalyze import top_set if effects is top_set or extraeffect == EffectInfo.EF_RANDOM_EFFECTS: readonly_descrs_fields = None @@ -364,7 +368,8 
@@ can_invalidate, call_release_gil_target, extradescr, - can_collect) + can_collect, + call_shortcut) def consider_struct(TYPE, fieldname): if fieldType(TYPE, fieldname) is lltype.Void: @@ -387,6 +392,24 @@ # ____________________________________________________________ + +class CallShortcut(object): + def __init__(self, argnum, fielddescr): + self.argnum = argnum + self.fielddescr = fielddescr + + def __eq__(self, other): + return (isinstance(other, CallShortcut) and + self.argnum == other.argnum and + self.fielddescr == other.fielddescr) + def __ne__(self, other): + return not (self == other) + def __hash__(self): + return hash((self.argnum, self.fielddescr)) + +# ____________________________________________________________ + + class VirtualizableAnalyzer(BoolGraphAnalyzer): def analyze_simple_operation(self, op, graphinfo): return op.opname in ('jit_force_virtualizable', diff --git a/rpython/jit/codewriter/test/test_call.py b/rpython/jit/codewriter/test/test_call.py --- a/rpython/jit/codewriter/test/test_call.py +++ b/rpython/jit/codewriter/test/test_call.py @@ -6,7 +6,7 @@ from rpython.rlib import jit from rpython.jit.codewriter import support, call from rpython.jit.codewriter.call import CallControl -from rpython.jit.codewriter.effectinfo import EffectInfo +from rpython.jit.codewriter.effectinfo import EffectInfo, CallShortcut class FakePolicy: @@ -368,3 +368,100 @@ assert call_op.opname == 'direct_call' call_descr = cc.getcalldescr(call_op) assert call_descr.extrainfo.check_can_collect() == expected + +def test_find_call_shortcut(): + class FakeCPU: + def fielddescrof(self, TYPE, fieldname): + if isinstance(TYPE, lltype.GcStruct): + if fieldname == 'inst_foobar': + return 'foobardescr' + if fieldname == 'inst_fooref': + return 'foorefdescr' + if TYPE == RAW and fieldname == 'x': + return 'xdescr' + assert False, (TYPE, fieldname) + cc = CallControl(FakeCPU()) + + class B(object): + foobar = 0 + fooref = None + + def f1(a, b, c): + if b.foobar: + return 
b.foobar + b.foobar = a + c + return b.foobar + + def f2(x, y, z, b): + r = b.fooref + if r is not None: + return r + r = b.fooref = B() + return r + + class Space(object): + def _freeze_(self): + return True + space = Space() + + def f3(space, b): + r = b.foobar + if not r: + r = b.foobar = 123 + return r + + def f4(raw): + r = raw.x + if r != 0: + return r + raw.x = 123 + return 123 + RAW = lltype.Struct('RAW', ('x', lltype.Signed)) + + def f5(b): + r = b.foobar + if r == 0: + r = b.foobar = 123 + return r + + def f(a, c): + b = B() + f1(a, b, c) + f2(a, c, a, b) + f3(space, b) + r = lltype.malloc(RAW, flavor='raw') + f4(r) + f5(b) + + rtyper = support.annotate(f, [10, 20]) + f1_graph = rtyper.annotator.translator._graphof(f1) + assert cc.find_call_shortcut(f1_graph) == CallShortcut(1, "foobardescr") + f2_graph = rtyper.annotator.translator._graphof(f2) + assert cc.find_call_shortcut(f2_graph) == CallShortcut(3, "foorefdescr") + f3_graph = rtyper.annotator.translator._graphof(f3) + assert cc.find_call_shortcut(f3_graph) == CallShortcut(0, "foobardescr") + f4_graph = rtyper.annotator.translator._graphof(f4) + assert cc.find_call_shortcut(f4_graph) == CallShortcut(0, "xdescr") + f5_graph = rtyper.annotator.translator._graphof(f5) + assert cc.find_call_shortcut(f5_graph) == CallShortcut(0, "foobardescr") + +def test_cant_find_call_shortcut(): + from rpython.jit.backend.llgraph.runner import LLGraphCPU + + @jit.dont_look_inside + @jit.call_shortcut + def f1(n): + return n + 17 # no call shortcut found + + def f(n): + return f1(n) + + rtyper = support.annotate(f, [1]) + jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0]) + cc = CallControl(LLGraphCPU(rtyper), jitdrivers_sd=[jitdriver_sd]) + res = cc.find_all_graphs(FakePolicy()) + [f_graph] = [x for x in res if x.func is f] + call_op = f_graph.startblock.operations[0] + assert call_op.opname == 'direct_call' + e = py.test.raises(AssertionError, cc.getcalldescr, call_op) + assert "shortcut not found" 
in str(e.value) diff --git a/rpython/jit/metainterp/executor.py b/rpython/jit/metainterp/executor.py --- a/rpython/jit/metainterp/executor.py +++ b/rpython/jit/metainterp/executor.py @@ -101,6 +101,18 @@ if condbox.getint(): do_call_n(cpu, metainterp, argboxes[1:], descr) +def do_cond_call_value_i(cpu, metainterp, argboxes, descr): + value = argboxes[0].getint() + if value == 0: + value = do_call_i(cpu, metainterp, argboxes[1:], descr) + return value + +def do_cond_call_value_r(cpu, metainterp, argboxes, descr): + value = argboxes[0].getref_base() + if not value: + value = do_call_r(cpu, metainterp, argboxes[1:], descr) + return value + def do_getarrayitem_gc_i(cpu, _, arraybox, indexbox, arraydescr): array = arraybox.getref_base() index = indexbox.getint() @@ -366,6 +378,8 @@ rop.CALL_ASSEMBLER_I, rop.CALL_ASSEMBLER_N, rop.INCREMENT_DEBUG_COUNTER, + rop.COND_CALL_VALUE_R, + rop.COND_CALL_VALUE_I, rop.COND_CALL_GC_WB, rop.COND_CALL_GC_WB_ARRAY, rop.ZERO_ARRAY, diff --git a/rpython/jit/metainterp/resoperation.py b/rpython/jit/metainterp/resoperation.py --- a/rpython/jit/metainterp/resoperation.py +++ b/rpython/jit/metainterp/resoperation.py @@ -1149,8 +1149,8 @@ '_CANRAISE_FIRST', # ----- start of can_raise operations ----- '_CALL_FIRST', 'CALL/*d/rfin', - 'COND_CALL/*d/n', - # a conditional call, with first argument as a condition + 'COND_CALL/*d/n', # a conditional call, with first argument as a condition + 'COND_CALL_VALUE/*d/ri', # same but returns a result; emitted by rewrite 'CALL_ASSEMBLER/*d/rfin', # call already compiled assembler 'CALL_MAY_FORCE/*d/rfin', 'CALL_LOOPINVARIANT/*d/rfin', diff --git a/rpython/jit/metainterp/test/test_dict.py b/rpython/jit/metainterp/test/test_dict.py --- a/rpython/jit/metainterp/test/test_dict.py +++ b/rpython/jit/metainterp/test/test_dict.py @@ -195,7 +195,8 @@ 'new_with_vtable': 2, 'getinteriorfield_gc_i': 2, 'setfield_gc': 14, 'int_gt': 2, 'int_sub': 2, 'call_i': 6, 'call_n': 2, 'call_r': 2, 'int_ge': 2, - 
'guard_no_exception': 8, 'new': 2}) + 'guard_no_exception': 8, 'new': 2, + 'guard_nonnull': 2}) def test_unrolling_of_dict_iter(self): driver = JitDriver(greens = [], reds = ['n']) diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py --- a/rpython/rlib/jit.py +++ b/rpython/rlib/jit.py @@ -257,6 +257,26 @@ func.oopspec = "jit.not_in_trace()" # note that 'func' may take arguments return func +def call_shortcut(func): + """A decorator to ensure that a function has a fast-path. + Only useful on functions that the JIT doesn't normally look inside. + It still replaces residual calls to that function with inline code + that checks for a fast path, and only does the call if not. For + now, graphs made by the following kinds of functions are detected: + + def func(x, y, z): def func(x, y, z): + if y.field: r = y.field + return y.field if r is None: + ... ... + return r + + Fast-path detection is always on, but this decorator makes the + codewriter complain if it cannot find the promised fast-path. + """ + func._call_shortcut_ = True + return func + + @oopspec("jit.isconstant(value)") def isconstant(value): """ diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py --- a/rpython/rtyper/lltypesystem/rstr.py +++ b/rpython/rtyper/lltypesystem/rstr.py @@ -369,13 +369,19 @@ return b @staticmethod + def ll_strhash(s): + if s: + return LLHelpers._ll_strhash(s) + else: + return 0 + + @staticmethod @jit.elidable - def ll_strhash(s): + @jit.call_shortcut + def _ll_strhash(s): # unlike CPython, there is no reason to avoid to return -1 # but our malloc initializes the memory to zero, so we use zero as the # special non-computed-yet value. - if not s: - return 0 x = s.hash if x == 0: x = _hash_string(s.chars) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit