[pypy-commit] pypy default: optimize multidim_dot loop
bdkearns
noreply at buildbot.pypy.org
Fri Feb 28 02:13:37 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r69530:e431aa28d934
Date: 2014-02-27 20:01 -0500
http://bitbucket.org/pypy/pypy/changeset/e431aa28d934/
Log: optimize multidim_dot loop
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -8,7 +8,8 @@
from rpython.rtyper.lltypesystem import lltype, rffi
from pypy.module.micronumpy import support, constants as NPY
from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter
+from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \
+ AllButAxisIter
call2_driver = jit.JitDriver(name='numpy_call2',
@@ -259,7 +260,6 @@
argmin = _new_argmin_argmax('min')
argmax = _new_argmin_argmax('max')
-# note that shapelen == 2 always
dot_driver = jit.JitDriver(name = 'numpy_dot',
greens = ['dtype'],
reds = 'auto')
@@ -280,25 +280,30 @@
'''
left_shape = left.get_shape()
right_shape = right.get_shape()
- broadcast_shape = left_shape[:-1] + right_shape
- left_skip = [len(left_shape) - 1 + i for i in range(len(right_shape))
- if i != right_critical_dim]
- right_skip = range(len(left_shape) - 1)
- result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+ assert left_shape[-1] == right_shape[right_critical_dim]
assert result.get_dtype() == dtype
- outi = result.implementation.create_dot_iter(broadcast_shape, result_skip)
- lefti = left.implementation.create_dot_iter(broadcast_shape, left_skip)
- righti = right.implementation.create_dot_iter(broadcast_shape, right_skip)
- while not outi.done():
- dot_driver.jit_merge_point(dtype=dtype)
- lval = lefti.getitem().convert_to(space, dtype)
- rval = righti.getitem().convert_to(space, dtype)
- outval = outi.getitem()
- v = dtype.itemtype.mul(lval, rval)
- v = dtype.itemtype.add(v, outval)
- outi.setitem(v)
- outi.next()
- righti.next()
+ outi = result.create_iter()
+ lefti = AllButAxisIter(left.implementation, len(left_shape) - 1)
+ righti = AllButAxisIter(right.implementation, right_critical_dim)
+ n = left.implementation.shape[-1]
+ s1 = left.implementation.strides[-1]
+ s2 = right.implementation.strides[right_critical_dim]
+ while not lefti.done():
+ while not righti.done():
+ oval = outi.getitem()
+ i1 = lefti.offset
+ i2 = righti.offset
+ for _ in xrange(n):
+ dot_driver.jit_merge_point(dtype=dtype)
+ lval = left.implementation.getitem(i1).convert_to(space, dtype)
+ rval = right.implementation.getitem(i2).convert_to(space, dtype)
+ oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
+ i1 += s1
+ i2 += s2
+ outi.setitem(oval)
+ outi.next()
+ righti.next()
+ righti.reset()
lefti.next()
return result
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -41,8 +41,7 @@
a[0] = 0
assert (b == [1, 1, 1, 0, 0]).all()
-
- def test_dot(self):
+ def test_dot_basic(self):
from numpypy import array, dot, arange
a = array(range(5))
assert dot(a, a) == 30.0
@@ -69,7 +68,7 @@
assert b.shape == (4, 3)
c = dot(a, b)
assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
- [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+ [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
c = dot(a, b[:, 2])
assert (c == [[62, 214, 366], [518, 670, 822]]).all()
a = arange(3*2*6).reshape((3,2,6))
More information about the pypy-commit mailing list.