[pypy-commit] pypy default: optimize multidim_dot loop

bdkearns noreply at buildbot.pypy.org
Fri Feb 28 02:13:37 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r69530:e431aa28d934
Date: 2014-02-27 20:01 -0500
http://bitbucket.org/pypy/pypy/changeset/e431aa28d934/

Log:	optimize multidim_dot loop

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -8,7 +8,8 @@
 from rpython.rtyper.lltypesystem import lltype, rffi
 from pypy.module.micronumpy import support, constants as NPY
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter
+from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \
+    AllButAxisIter
 
 
 call2_driver = jit.JitDriver(name='numpy_call2',
@@ -259,7 +260,6 @@
 argmin = _new_argmin_argmax('min')
 argmax = _new_argmin_argmax('max')
 
-# note that shapelen == 2 always
 dot_driver = jit.JitDriver(name = 'numpy_dot',
                            greens = ['dtype'],
                            reds = 'auto')
@@ -280,25 +280,30 @@
     '''
     left_shape = left.get_shape()
     right_shape = right.get_shape()
-    broadcast_shape = left_shape[:-1] + right_shape
-    left_skip = [len(left_shape) - 1 + i for i in range(len(right_shape))
-                                         if i != right_critical_dim]
-    right_skip = range(len(left_shape) - 1)
-    result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+    assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
-    outi = result.implementation.create_dot_iter(broadcast_shape, result_skip)
-    lefti = left.implementation.create_dot_iter(broadcast_shape, left_skip)
-    righti = right.implementation.create_dot_iter(broadcast_shape, right_skip)
-    while not outi.done():
-        dot_driver.jit_merge_point(dtype=dtype)
-        lval = lefti.getitem().convert_to(space, dtype)
-        rval = righti.getitem().convert_to(space, dtype)
-        outval = outi.getitem()
-        v = dtype.itemtype.mul(lval, rval)
-        v = dtype.itemtype.add(v, outval)
-        outi.setitem(v)
-        outi.next()
-        righti.next()
+    outi = result.create_iter()
+    lefti = AllButAxisIter(left.implementation, len(left_shape) - 1)
+    righti = AllButAxisIter(right.implementation, right_critical_dim)
+    n = left.implementation.shape[-1]
+    s1 = left.implementation.strides[-1]
+    s2 = right.implementation.strides[right_critical_dim]
+    while not lefti.done():
+        while not righti.done():
+            oval = outi.getitem()
+            i1 = lefti.offset
+            i2 = righti.offset
+            for _ in xrange(n):
+                dot_driver.jit_merge_point(dtype=dtype)
+                lval = left.implementation.getitem(i1).convert_to(space, dtype)
+                rval = right.implementation.getitem(i2).convert_to(space, dtype)
+                oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
+                i1 += s1
+                i2 += s2
+            outi.setitem(oval)
+            outi.next()
+            righti.next()
+        righti.reset()
         lefti.next()
     return result
 
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -41,8 +41,7 @@
         a[0] = 0
         assert (b == [1, 1, 1, 0, 0]).all()
 
-
-    def test_dot(self):
+    def test_dot_basic(self):
         from numpypy import array, dot, arange
         a = array(range(5))
         assert dot(a, a) == 30.0
@@ -69,7 +68,7 @@
         assert b.shape == (4, 3)
         c = dot(a, b)
         assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
-                   [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+                      [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
         c = dot(a, b[:, 2])
         assert (c == [[62, 214, 366], [518, 670, 822]]).all()
         a = arange(3*2*6).reshape((3,2,6))


More information about the pypy-commit mailing list