Author: Brian Kearns <[email protected]>
Branch:
Changeset: r69546:7862a38e2a22
Date: 2014-02-28 08:23 -0500
http://bitbucket.org/pypy/pypy/changeset/7862a38e2a22/
Log: more optimizations for dot loop
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -79,7 +79,7 @@
class ArrayIter(object):
- _immutable_fields_ = ['array', 'size', 'indices', 'shape[*]',
+ _immutable_fields_ = ['array', 'size', 'indices', 'shape_m1[*]',
'strides[*]', 'backstrides[*]']
def __init__(self, array, size, shape, strides, backstrides):
@@ -87,7 +87,7 @@
self.array = array
self.size = size
self.indices = [0] * len(shape)
- self.shape = shape
+ self.shape_m1 = [s - 1 for s in shape]
self.strides = strides
self.backstrides = backstrides
self.reset()
@@ -95,15 +95,15 @@
@jit.unroll_safe
def reset(self):
self.index = 0
- for i in xrange(len(self.shape)):
+ for i in xrange(len(self.shape_m1)):
self.indices[i] = 0
self.offset = self.array.start
@jit.unroll_safe
def next(self):
self.index += 1
- for i in xrange(len(self.shape) - 1, -1, -1):
- if self.indices[i] < self.shape[i] - 1:
+ for i in xrange(len(self.shape_m1) - 1, -1, -1):
+ if self.indices[i] < self.shape_m1[i]:
self.indices[i] += 1
self.offset += self.strides[i]
break
@@ -117,14 +117,14 @@
if step == 0:
return
self.index += step
- for i in xrange(len(self.shape) - 1, -1, -1):
- if self.indices[i] < self.shape[i] - step:
+ for i in xrange(len(self.shape_m1) - 1, -1, -1):
+ if self.indices[i] < (self.shape_m1[i] + 1) - step:
self.indices[i] += step
self.offset += self.strides[i] * step
break
else:
- remaining_step = (self.indices[i] + step) // self.shape[i]
- this_i_step = step - remaining_step * self.shape[i]
+ remaining_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+ this_i_step = step - remaining_step * (self.shape_m1[i] + 1)
self.indices[i] = self.indices[i] + this_i_step
self.offset += self.strides[i] * this_i_step
step = remaining_step
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -280,23 +280,30 @@
'''
left_shape = left.get_shape()
right_shape = right.get_shape()
+ left_impl = left.implementation
+ right_impl = right.implementation
assert left_shape[-1] == right_shape[right_critical_dim]
assert result.get_dtype() == dtype
outi = result.create_iter()
- lefti = AllButAxisIter(left.implementation, len(left_shape) - 1)
- righti = AllButAxisIter(right.implementation, right_critical_dim)
+ lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
+ righti = AllButAxisIter(right_impl, right_critical_dim)
+ n = left_impl.shape[-1]
+ s1 = left_impl.strides[-1]
+ s2 = right_impl.strides[right_critical_dim]
while not lefti.done():
while not righti.done():
oval = outi.getitem()
i1 = lefti.offset
i2 = righti.offset
- for _ in xrange(left.implementation.shape[-1]):
+ i = 0
+ while i < n:
+ i += 1
dot_driver.jit_merge_point(dtype=dtype)
- lval = left.implementation.getitem(i1).convert_to(space, dtype)
- rval = right.implementation.getitem(i2).convert_to(space, dtype)
+ lval = left_impl.getitem(i1).convert_to(space, dtype)
+ rval = right_impl.getitem(i2).convert_to(space, dtype)
oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
- i1 += left.implementation.strides[-1]
- i2 += right.implementation.strides[right_critical_dim]
+ i1 += s1
+ i2 += s2
outi.setitem(oval)
outi.next()
righti.next()
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -512,36 +512,35 @@
self.check_simple_loop({'float_add': 1,
'float_mul': 1,
'guard_not_invalidated': 1,
- 'guard_false': 1,
+ 'guard_true': 1,
'int_add': 3,
- 'int_ge': 1,
+ 'int_lt': 1,
'jump': 1,
- 'raw_load': 2,
- 'setfield_gc': 1})
+ 'raw_load': 2})
self.check_resops({'arraylen_gc': 4,
'float_add': 2,
'float_mul': 2,
'getarrayitem_gc': 11,
'getarrayitem_gc_pure': 15,
- 'getfield_gc': 26,
- 'getfield_gc_pure': 32,
+ 'getfield_gc': 30,
+ 'getfield_gc_pure': 40,
'guard_class': 4,
- 'guard_false': 18,
+ 'guard_false': 14,
+ 'guard_nonnull': 8,
+ 'guard_nonnull_class': 4,
'guard_not_invalidated': 2,
- 'guard_true': 9,
+ 'guard_true': 13,
+ 'guard_value': 4,
'int_add': 25,
- 'int_ge': 8,
+ 'int_ge': 4,
'int_le': 8,
- 'int_lt': 7,
- 'int_sub': 15,
+ 'int_lt': 11,
+ 'int_sub': 8,
'jump': 3,
- 'new': 1,
- 'new_with_vtable': 1,
'raw_load': 6,
'raw_store': 1,
- 'same_as': 2,
'setarrayitem_gc': 10,
- 'setfield_gc': 19})
+ 'setfield_gc': 14})
def define_argsort():
return """
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit