Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r69546:7862a38e2a22
Date: 2014-02-28 08:23 -0500
http://bitbucket.org/pypy/pypy/changeset/7862a38e2a22/

Log:    more optimizations for dot loop

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -79,7 +79,7 @@
 
 
 class ArrayIter(object):
-    _immutable_fields_ = ['array', 'size', 'indices', 'shape[*]',
+    _immutable_fields_ = ['array', 'size', 'indices', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]']
 
     def __init__(self, array, size, shape, strides, backstrides):
@@ -87,7 +87,7 @@
         self.array = array
         self.size = size
         self.indices = [0] * len(shape)
-        self.shape = shape
+        self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
         self.backstrides = backstrides
         self.reset()
@@ -95,15 +95,15 @@
     @jit.unroll_safe
     def reset(self):
         self.index = 0
-        for i in xrange(len(self.shape)):
+        for i in xrange(len(self.shape_m1)):
             self.indices[i] = 0
         self.offset = self.array.start
 
     @jit.unroll_safe
     def next(self):
         self.index += 1
-        for i in xrange(len(self.shape) - 1, -1, -1):
-            if self.indices[i] < self.shape[i] - 1:
+        for i in xrange(len(self.shape_m1) - 1, -1, -1):
+            if self.indices[i] < self.shape_m1[i]:
                 self.indices[i] += 1
                 self.offset += self.strides[i]
                 break
@@ -117,14 +117,14 @@
         if step == 0:
             return
         self.index += step
-        for i in xrange(len(self.shape) - 1, -1, -1):
-            if self.indices[i] < self.shape[i] - step:
+        for i in xrange(len(self.shape_m1) - 1, -1, -1):
+            if self.indices[i] < (self.shape_m1[i] + 1) - step:
                 self.indices[i] += step
                 self.offset += self.strides[i] * step
                 break
             else:
-                remaining_step = (self.indices[i] + step) // self.shape[i]
-                this_i_step = step - remaining_step * self.shape[i]
+                remaining_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+                this_i_step = step - remaining_step * (self.shape_m1[i] + 1)
                 self.indices[i] = self.indices[i] + this_i_step
                 self.offset += self.strides[i] * this_i_step
                 step = remaining_step
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -280,23 +280,30 @@
     '''
     left_shape = left.get_shape()
     right_shape = right.get_shape()
+    left_impl = left.implementation
+    right_impl = right.implementation
     assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
     outi = result.create_iter()
-    lefti = AllButAxisIter(left.implementation, len(left_shape) - 1)
-    righti = AllButAxisIter(right.implementation, right_critical_dim)
+    lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
+    righti = AllButAxisIter(right_impl, right_critical_dim)
+    n = left_impl.shape[-1]
+    s1 = left_impl.strides[-1]
+    s2 = right_impl.strides[right_critical_dim]
     while not lefti.done():
         while not righti.done():
             oval = outi.getitem()
             i1 = lefti.offset
             i2 = righti.offset
-            for _ in xrange(left.implementation.shape[-1]):
+            i = 0
+            while i < n:
+                i += 1
                 dot_driver.jit_merge_point(dtype=dtype)
-                lval = left.implementation.getitem(i1).convert_to(space, dtype)
-                rval = right.implementation.getitem(i2).convert_to(space, dtype)
+                lval = left_impl.getitem(i1).convert_to(space, dtype)
+                rval = right_impl.getitem(i2).convert_to(space, dtype)
                 oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
-                i1 += left.implementation.strides[-1]
-                i2 += right.implementation.strides[right_critical_dim]
+                i1 += s1
+                i2 += s2
             outi.setitem(oval)
             outi.next()
             righti.next()
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -512,36 +512,35 @@
         self.check_simple_loop({'float_add': 1,
                                 'float_mul': 1,
                                 'guard_not_invalidated': 1,
-                                'guard_false': 1,
+                                'guard_true': 1,
                                 'int_add': 3,
-                                'int_ge': 1,
+                                'int_lt': 1,
                                 'jump': 1,
-                                'raw_load': 2,
-                                'setfield_gc': 1})
+                                'raw_load': 2})
         self.check_resops({'arraylen_gc': 4,
                            'float_add': 2,
                            'float_mul': 2,
                            'getarrayitem_gc': 11,
                            'getarrayitem_gc_pure': 15,
-                           'getfield_gc': 26,
-                           'getfield_gc_pure': 32,
+                           'getfield_gc': 30,
+                           'getfield_gc_pure': 40,
                            'guard_class': 4,
-                           'guard_false': 18,
+                           'guard_false': 14,
+                           'guard_nonnull': 8,
+                           'guard_nonnull_class': 4,
                            'guard_not_invalidated': 2,
-                           'guard_true': 9,
+                           'guard_true': 13,
+                           'guard_value': 4,
                            'int_add': 25,
-                           'int_ge': 8,
+                           'int_ge': 4,
                            'int_le': 8,
-                           'int_lt': 7,
-                           'int_sub': 15,
+                           'int_lt': 11,
+                           'int_sub': 8,
                            'jump': 3,
-                           'new': 1,
-                           'new_with_vtable': 1,
                            'raw_load': 6,
                            'raw_store': 1,
-                           'same_as': 2,
                            'setarrayitem_gc': 10,
-                           'setfield_gc': 19})
+                           'setfield_gc': 14})
 
     def define_argsort():
         return """
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to