[pypy-commit] pypy matrixmath-dot: refactor and rework, still need more tests
mattip
noreply at buildbot.pypy.org
Fri Jan 20 08:42:12 CET 2012
Author: mattip
Branch: matrixmath-dot
Changeset: r51506:2bcfa95fe92a
Date: 2012-01-20 09:41 +0200
http://bitbucket.org/pypy/pypy/changeset/2bcfa95fe92a/
Log: refactor and rework, still need more tests
diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/dot.py
@@ -0,0 +1,68 @@
+from pypy.module.micronumpy import interp_ufuncs
+from pypy.module.micronumpy.strides import calculate_dot_strides
+from pypy.interpreter.error import OperationError, operationerrfmt
+from pypy.module.micronumpy.interp_iter import ViewIterator
+
+
# Docstring for dot(); interp_numarray does
# `from pypy.module.micronumpy.dot import ... dot_docstring`, so this name
# must exist in this module.
dot_docstring = '''Dot product of two arrays.

    For 2-D arrays it is equivalent to matrix multiplication, and for 1-D
    arrays to inner product of vectors (without complex conjugation). For
    N dimensions it is a sum product over the last axis of `a` and
    the second-to-last of `b`::

        dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''


def match_dot_shapes(space, left, right):
    """Check that left and right can be dotted.

    The contracted axis is the last axis of ``left``. For ``right`` it is
    the second-to-last axis when right has 2+ dims, otherwise its only axis.

    Returns ``(out_shape, right_critical_dim)``, where:
      - ``right_critical_dim`` is the index of the contracted axis in
        ``right.shape``;
      - ``out_shape`` is ``left.shape[:-1]`` followed by ``right.shape``
        with that axis removed.

    Both arrays are assumed to be at least 1-D: ``left.shape[-1]`` and
    ``right.shape[right_critical_dim]`` are indexed unconditionally.
    Scalars are expected to be handled by the caller.

    Raises OperationError(ValueError, "objects are not aligned") when the
    two contracted axes differ in length.
    """
    my_critical_dim_size = left.shape[-1]
    right_critical_dim = 0
    if len(right.shape) > 1:
        right_critical_dim = len(right.shape) - 2
    # Presumably lets the RPython annotator prove that the slice bounds
    # used below are non-negative.
    assert right_critical_dim >= 0
    if my_critical_dim_size != right.shape[right_critical_dim]:
        raise OperationError(space.w_ValueError, space.wrap(
            "objects are not aligned"))
    # For a 1-D right (right_critical_dim == 0) both right slices are empty,
    # so this reduces to left.shape[:-1].
    out_shape = (left.shape[:-1] + right.shape[:right_critical_dim] +
                 right.shape[right_critical_dim + 1:])
    return out_shape, right_critical_dim
+
+
def _dot_iterator(strides, backstrides, broadcast_shape, skip_dims):
    """Return a ViewIterator walking ``broadcast_shape`` for one operand.

    ``strides``/``backstrides`` are the operand's own. They are combined
    with ``skip_dims`` by calculate_dot_strides to line the operand up with
    ``broadcast_shape``.
    """
    # NOTE(review): iteration starts at storage offset 0. The removed
    # DotTransform path started from the operand iterator's own offset, so
    # operands that are views with a non-zero start may be read from the
    # wrong place here -- confirm.
    r = calculate_dot_strides(strides, backstrides, broadcast_shape,
                              skip_dims)
    return ViewIterator(0, r[0], r[1], broadcast_shape)


def multidim_dot(space, left, right, result, dtype, right_critical_dim):
    '''Accumulate dot(left, right) into ``result`` and return ``result``.

    Assumes left, right are concrete arrays, and that ``result`` has the
    out_shape computed by match_dot_shapes and dtype ``dtype``. Each result
    element is read and added to, so it presumably starts out zero-filled
    (TODO confirm with the caller).

    All three arrays are walked over one common broadcast shape,
    left.shape[:-1] + right.shape. The summed-over axis of that shape is
    reduce_dim = len(left.shape) - 1 + right_critical_dim.

    Example: given
      left.shape   == [3, 5, 7]
      right.shape  == [2, 7, 4]      (right_critical_dim == 1)
      result.shape == [3, 5, 2, 4]
    the broadcast shape is [3, 5, 2, 7, 4] and reduce_dim is 3, so:
      - result skips dim 3 (reduce_dim);
      - left skips dims 2 and 4, i.e. len(left.shape) - 1 + i for i in
        range(len(right.shape)), except i == right_critical_dim;
      - right skips dims 0 and 1, i.e. range(len(left.shape) - 1).

    For a 1-D right (right_critical_dim == 0), e.g. left.shape == [3, 4]
    and right.shape == [4], the broadcast shape is [3, 4] and result skips
    dim 1.
    '''
    mul = interp_ufuncs.get(space).multiply.func
    add = interp_ufuncs.get(space).add.func
    broadcast_shape = left.shape[:-1] + right.shape
    shapelen = len(broadcast_shape)
    reduce_dim = len(left.shape) - 1 + right_critical_dim
    left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
                 if i != right_critical_dim]
    right_skip = range(len(left.shape) - 1)
    # The result must skip exactly the summed axis. Note that
    # len(result.shape) - 1 equals reduce_dim only when right has 2+ dims,
    # so it cannot be used for a 1-D right.
    result_skip = [reduce_dim]
    outi = _dot_iterator(result.strides, result.backstrides,
                         broadcast_shape, result_skip)
    lefti = _dot_iterator(left.strides, left.backstrides,
                          broadcast_shape, left_skip)
    righti = _dot_iterator(right.strides, right.backstrides,
                           broadcast_shape, right_skip)
    while not outi.done():
        # Operands may carry a different dtype than the result. Convert
        # them before combining (the removed loop also applied
        # convert_to(sig.calc_dtype)).
        lval = left.getitem(lefti.offset).convert_to(dtype)
        rval = right.getitem(righti.offset).convert_to(dtype)
        value = add(dtype, mul(dtype, lval, rval),
                    result.getitem(outi.offset))
        result.setitem(outi.offset, value)
        outi = outi.next(shapelen)
        righti = righti.next(shapelen)
        lefti = lefti.next(shapelen)
    assert lefti.done()
    assert righti.done()
    return result
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -16,10 +16,6 @@
def __init__(self, res_shape):
self.res_shape = res_shape
-class DotTransform(BaseTransform):
- def __init__(self, res_shape, skip_dims):
- self.res_shape = res_shape
- self.skip_dims = skip_dims
class BaseIterator(object):
def next(self, shapelen):
@@ -90,10 +86,6 @@
self.strides,
self.backstrides, t.chunks)
return ViewIterator(r[1], r[2], r[3], r[0])
- elif isinstance(t, DotTransform):
- r = calculate_dot_strides(self.strides, self.backstrides,
- t.res_shape, t.skip_dims)
- return ViewIterator(self.offset, r[0], r[1], t.res_shape)
@jit.unroll_safe
def next(self, shapelen):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,14 +3,15 @@
from pypy.interpreter.gateway import interp2app, NoneNotWrapped
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature
-from pypy.module.micronumpy.strides import calculate_slice_strides,\
- calculate_dot_strides
+from pypy.module.micronumpy.strides import calculate_slice_strides
from pypy.rlib import jit
from pypy.rpython.lltypesystem import lltype, rffi
from pypy.tool.sourcetools import func_with_new_name
from pypy.rlib.rstring import StringBuilder
from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
- SkipLastAxisIterator, ViewIterator
+ SkipLastAxisIterator
+from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes, dot_docstring
+
numpy_driver = jit.JitDriver(
greens=['shapelen', 'sig'],
@@ -212,28 +213,6 @@
n_old_elems_to_use *= old_shape[oldI]
return new_strides
-def match_dot_shapes(space, self, other):
- my_critical_dim_size = self.shape[-1]
- other_critical_dim_size = other.shape[0]
- other_critical_dim = 0
- other_critical_dim_stride = other.strides[0]
- out_shape = []
- if len(other.shape) > 1:
- other_critical_dim = len(other.shape) - 2
- other_critical_dim_size = other.shape[other_critical_dim]
- other_critical_dim_stride = other.strides[other_critical_dim]
- assert other_critical_dim >= 0
- out_shape += self.shape[:-1] + \
- other.shape[0:other_critical_dim] + \
- other.shape[other_critical_dim + 1:]
- elif len(other.shape) > 0:
- #dot does not reduce for scalars
- out_shape += self.shape[:-1]
- if my_critical_dim_size != other_critical_dim_size:
- raise OperationError(space.w_ValueError, space.wrap(
- "objects are not aligned"))
- return out_shape, other_critical_dim
-
class BaseArray(Wrappable):
_attrs_ = ["invalidates", "shape", 'size']
@@ -399,14 +378,6 @@
descr_argmin = _reduce_argmax_argmin_impl("min")
def descr_dot(self, space, w_other):
- '''Dot product of two arrays.
-
- For 2-D arrays it is equivalent to matrix multiplication, and for 1-D
- arrays to inner product of vectors (without complex conjugation). For
- N dimensions it is a sum product over the last axis of `a` and
- the second-to-last of `b`::
-
- dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''
other = convert_to_array(space, w_other)
if isinstance(other, Scalar):
return self.descr_mul(space, other)
@@ -425,43 +396,10 @@
for o in out_shape:
out_size *= o
result = W_NDimArray(out_size, out_shape, dtype)
- # given a.shape == [3, 5, 7],
- # b.shape == [2, 7, 4]
- # result.shape == [3, 5, 2, 4]
- # all iterators shapes should be [3, 5, 2, 7, 4]
- # result should skip dims 3 which is results.ndims - 1
- # a should skip 2, 4 which is a.ndims-1 + range(b.ndims)
- # except where it==(b.ndims-2)
- # b should skip 0, 1
- mul = interp_ufuncs.get(space).multiply.func
- add = interp_ufuncs.get(space).add.func
- broadcast_shape = self.shape[:-1] + other.shape
- #Aww, cmon, this is the product of a warped mind.
- left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape)) if i != other_critical_dim]
- right_skip = range(len(self.shape) - 1)
- arr = DotArray(mul, 'DotName', out_shape, dtype, self, other,
- left_skip, right_skip)
- arr.broadcast_shape = broadcast_shape
- arr.result_skip = [len(out_shape) - 1]
- #Make this lazy someday...
- sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype,
- self.create_sig(), other.create_sig()), arr)
- assert isinstance(sig, signature.DotSignature)
- self.do_dot_loop(sig, result, arr, add)
- return result
-
- def do_dot_loop(self, sig, result, arr, add):
- frame = sig.create_frame(arr)
- shapelen = len(arr.broadcast_shape)
- _r = calculate_dot_strides(result.strides, result.backstrides,
- arr.broadcast_shape, arr.result_skip)
- ri = ViewIterator(0, _r[0], _r[1], arr.broadcast_shape)
- while not frame.done():
- v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
- value = add(sig.calc_dtype, v, result.getitem(ri.offset))
- result.setitem(ri.offset, value)
- frame.next(shapelen)
- ri = ri.next(shapelen)
+ # This is the place to add fpypy and blas
+ return multidim_dot(space, self.get_concrete(),
+ other.get_concrete(), result, dtype,
+ other_critical_dim)
def get_concrete(self):
raise NotImplementedError
@@ -933,23 +871,6 @@
left, right)
self.dim = dim
-class DotArray(Call2):
- """ NOTE: this is only used as a container, you should never
- encounter such things in the wild. Remove this comment
- when we'll make Dot lazy
- """
- _immutable_fields_ = ['left', 'right']
-
- def __init__(self, ufunc, name, shape, dtype, left, right, left_skip, right_skip):
- Call2.__init__(self, ufunc, name, shape, dtype, dtype,
- left, right)
- self.left_skip = left_skip
- self.right_skip = right_skip
- def create_sig(self):
- #if self.forced_result is not None:
- # return self.forced_result.create_sig()
- assert NotImplementedError
-
class ConcreteArray(BaseArray):
""" An array that have actual storage, whether owned or not
"""
@@ -1304,6 +1225,8 @@
return space.wrap(arr)
def dot(space, w_obj, w_obj2):
+ '''see numpypy.dot. Does not exist as an ndarray method in numpy.
+ '''
w_arr = convert_to_array(space, w_obj)
if isinstance(w_arr, Scalar):
return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -2,7 +2,7 @@
from pypy.rlib.rarithmetic import intmask
from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
ConstantIterator, AxisIterator, ViewTransform,\
- BroadcastTransform, DotTransform
+ BroadcastTransform
from pypy.rlib.jit import hint, unroll_safe, promote
""" Signature specifies both the numpy expression that has been constructed
@@ -449,21 +449,3 @@
def debug_repr(self):
return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr())
-
-class DotSignature(Call2):
- def _invent_numbering(self, cache, allnumbers):
- self.left._invent_numbering(new_cache(), allnumbers)
- self.right._invent_numbering(new_cache(), allnumbers)
-
- def _create_iter(self, iterlist, arraylist, arr, transforms):
- from pypy.module.micronumpy.interp_numarray import DotArray
-
- assert isinstance(arr, DotArray)
- rtransforms = transforms + [DotTransform(arr.broadcast_shape, arr.right_skip)]
- ltransforms = transforms + [DotTransform(arr.broadcast_shape, arr.left_skip)]
- self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
- self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
-
- def debug_repr(self):
- return 'DotSig(%s, %s %s)' % (self.name, self.right.debug_repr(),
- self.left.debug_repr())
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -869,7 +869,7 @@
def test_dot(self):
from _numpypy import array, dot, arange
a = array(range(5))
- assert a.dot(a) == 30.0
+ assert dot(a, a) == 30.0
a = array(range(5))
assert a.dot(range(5)) == 30
@@ -887,9 +887,11 @@
#Superfluous shape test makes the intention of the test clearer
assert a.shape == (2, 3, 4)
assert b.shape == (4, 3)
- c = a.dot(b)
+ c = dot(a, b)
assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
[[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+ c = dot(a, b[:, :, 2])
+ assert (c == [[38, 126, 214], [302, 390, 478]]).all()
def test_dot_constant(self):
from _numpypy import array
More information about the pypy-commit
mailing list