[pypy-commit] pypy matrixmath-dot: refactor and rework, still need more tests

mattip noreply at buildbot.pypy.org
Fri Jan 20 08:42:12 CET 2012


Author: mattip
Branch: matrixmath-dot
Changeset: r51506:2bcfa95fe92a
Date: 2012-01-20 09:41 +0200
http://bitbucket.org/pypy/pypy/changeset/2bcfa95fe92a/

Log:	refactor and rework, still need more tests

diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/dot.py
@@ -0,0 +1,68 @@
+from pypy.module.micronumpy import interp_ufuncs
+from pypy.module.micronumpy.strides import calculate_dot_strides
+from pypy.interpreter.error import OperationError, operationerrfmt
+from pypy.module.micronumpy.interp_iter import ViewIterator
+
+
+def match_dot_shapes(space, left, right):
+    my_critical_dim_size = left.shape[-1]
+    right_critical_dim_size = right.shape[0]
+    right_critical_dim = 0
+    right_critical_dim_stride = right.strides[0]
+    out_shape = []
+    if len(right.shape) > 1:
+        right_critical_dim = len(right.shape) - 2
+        right_critical_dim_size = right.shape[right_critical_dim]
+        right_critical_dim_stride = right.strides[right_critical_dim]
+        assert right_critical_dim >= 0
+        out_shape += left.shape[:-1] + \
+                     right.shape[0:right_critical_dim] + \
+                     right.shape[right_critical_dim + 1:]
+    elif len(right.shape) > 0:
+        #dot does not reduce for scalars
+        out_shape += left.shape[:-1]
+    if my_critical_dim_size != right_critical_dim_size:
+        raise OperationError(space.w_ValueError, space.wrap(
+                                        "objects are not aligned"))
+    return out_shape, right_critical_dim
+
+
+def multidim_dot(space, left, right, result, dtype, right_critical_dim):
+    ''' assumes left, right are concrete arrays
+    given left.shape == [3, 5, 7],
+          right.shape == [2, 7, 4]
+     result.shape == [3, 5, 2, 4]
+    broadcast shape should be [3, 5, 2, 7, 4]
+    result should skip dims 3 which is results.ndims - 1
+    left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
+          except where it==(right.ndims-2)
+    right should skip 0, 1
+    '''
+    mul = interp_ufuncs.get(space).multiply.func
+    add = interp_ufuncs.get(space).add.func
+    broadcast_shape = left.shape[:-1] + right.shape
+    left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
+                                         if i != right_critical_dim]
+    right_skip = range(len(left.shape) - 1)
+    result_skip = [len(result.shape) - 1]
+    shapelen = len(broadcast_shape)
+    _r = calculate_dot_strides(result.strides, result.backstrides,
+                                  broadcast_shape, result_skip)
+    outi = ViewIterator(0, _r[0], _r[1], broadcast_shape)
+    _r = calculate_dot_strides(left.strides, left.backstrides,
+                                  broadcast_shape, left_skip)
+    lefti = ViewIterator(0, _r[0], _r[1], broadcast_shape)
+    _r = calculate_dot_strides(right.strides, right.backstrides,
+                                  broadcast_shape, right_skip)
+    righti = ViewIterator(0, _r[0], _r[1], broadcast_shape)
+    while not outi.done():
+        v = mul(dtype, left.getitem(lefti.offset),
+                       right.getitem(righti.offset))
+        value = add(dtype, v, result.getitem(outi.offset))
+        result.setitem(outi.offset, value)
+        outi = outi.next(shapelen)
+        righti = righti.next(shapelen)
+        lefti = lefti.next(shapelen)
+    assert lefti.done()
+    assert righti.done()
+    return result
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -16,10 +16,6 @@
     def __init__(self, res_shape):
         self.res_shape = res_shape
 
-class DotTransform(BaseTransform):
-    def __init__(self, res_shape, skip_dims):
-        self.res_shape = res_shape
-        self.skip_dims = skip_dims
 
 class BaseIterator(object):
     def next(self, shapelen):
@@ -90,10 +86,6 @@
                                         self.strides,
                                         self.backstrides, t.chunks)
             return ViewIterator(r[1], r[2], r[3], r[0])
-        elif isinstance(t, DotTransform):
-            r = calculate_dot_strides(self.strides, self.backstrides,
-                                     t.res_shape, t.skip_dims)
-            return ViewIterator(self.offset, r[0], r[1], t.res_shape)
 
     @jit.unroll_safe
     def next(self, shapelen):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,14 +3,15 @@
 from pypy.interpreter.gateway import interp2app, NoneNotWrapped
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature
-from pypy.module.micronumpy.strides import calculate_slice_strides,\
-                                           calculate_dot_strides
+from pypy.module.micronumpy.strides import calculate_slice_strides
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
-     SkipLastAxisIterator, ViewIterator
+     SkipLastAxisIterator
+from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes, dot_docstring
+
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
@@ -212,28 +213,6 @@
                 n_old_elems_to_use *= old_shape[oldI]
     return new_strides
 
-def match_dot_shapes(space, self, other):
-    my_critical_dim_size = self.shape[-1]
-    other_critical_dim_size = other.shape[0]
-    other_critical_dim = 0
-    other_critical_dim_stride = other.strides[0]
-    out_shape = []
-    if len(other.shape) > 1:
-        other_critical_dim = len(other.shape) - 2
-        other_critical_dim_size = other.shape[other_critical_dim]
-        other_critical_dim_stride = other.strides[other_critical_dim]
-        assert other_critical_dim >= 0
-        out_shape += self.shape[:-1] + \
-                     other.shape[0:other_critical_dim] + \
-                     other.shape[other_critical_dim + 1:]
-    elif len(other.shape) > 0:
-        #dot does not reduce for scalars
-        out_shape += self.shape[:-1]
-    if my_critical_dim_size != other_critical_dim_size:
-        raise OperationError(space.w_ValueError, space.wrap(
-                                        "objects are not aligned"))
-    return out_shape, other_critical_dim
-
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "shape", 'size']
 
@@ -399,14 +378,6 @@
     descr_argmin = _reduce_argmax_argmin_impl("min")
 
     def descr_dot(self, space, w_other):
-        '''Dot product of two arrays.
-
-    For 2-D arrays it is equivalent to matrix multiplication, and for 1-D
-    arrays to inner product of vectors (without complex conjugation). For
-    N dimensions it is a sum product over the last axis of `a` and
-    the second-to-last of `b`::
-
-        dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''
         other = convert_to_array(space, w_other)
         if isinstance(other, Scalar):
             return self.descr_mul(space, other)
@@ -425,43 +396,10 @@
         for o in out_shape:
             out_size *= o
         result = W_NDimArray(out_size, out_shape, dtype)
-        # given a.shape == [3, 5, 7],
-        #       b.shape == [2, 7, 4]
-        #  result.shape == [3, 5, 2, 4]
-        # all iterators shapes should be [3, 5, 2, 7, 4]
-        # result should skip dims 3 which is results.ndims - 1
-        # a should skip 2, 4 which is a.ndims-1 + range(b.ndims) 
-        #       except where it==(b.ndims-2)
-        # b should skip 0, 1
-        mul = interp_ufuncs.get(space).multiply.func
-        add = interp_ufuncs.get(space).add.func
-        broadcast_shape = self.shape[:-1] + other.shape
-        #Aww, cmon, this is the product of a warped mind.
-        left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape)) if i != other_critical_dim]
-        right_skip = range(len(self.shape) - 1)
-        arr = DotArray(mul, 'DotName', out_shape, dtype, self, other,
-                                        left_skip, right_skip)
-        arr.broadcast_shape = broadcast_shape
-        arr.result_skip = [len(out_shape) - 1]
-        #Make this lazy someday...
-        sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype,
-                                  self.create_sig(), other.create_sig()), arr)
-        assert isinstance(sig, signature.DotSignature)
-        self.do_dot_loop(sig, result, arr, add)
-        return result
-
-    def do_dot_loop(self, sig, result, arr, add):
-        frame = sig.create_frame(arr)
-        shapelen = len(arr.broadcast_shape)
-        _r = calculate_dot_strides(result.strides, result.backstrides,
-                                      arr.broadcast_shape, arr.result_skip)
-        ri = ViewIterator(0, _r[0], _r[1], arr.broadcast_shape)
-        while not frame.done():
-            v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
-            value = add(sig.calc_dtype, v, result.getitem(ri.offset))
-            result.setitem(ri.offset, value)
-            frame.next(shapelen)
-            ri = ri.next(shapelen)
+        # This is the place to add fpypy and blas
+        return multidim_dot(space, self.get_concrete(), 
+                            other.get_concrete(), result, dtype,
+                            other_critical_dim)
 
     def get_concrete(self):
         raise NotImplementedError
@@ -933,23 +871,6 @@
                        left, right)
         self.dim = dim
 
-class DotArray(Call2):
-    """ NOTE: this is only used as a container, you should never
-    encounter such things in the wild. Remove this comment
-    when we'll make Dot lazy
-    """
-    _immutable_fields_ = ['left', 'right']
-    
-    def __init__(self, ufunc, name, shape, dtype, left, right, left_skip, right_skip):
-        Call2.__init__(self, ufunc, name, shape, dtype, dtype,
-                       left, right)
-        self.left_skip = left_skip
-        self.right_skip = right_skip
-    def create_sig(self):
-        #if self.forced_result is not None:
-        #    return self.forced_result.create_sig()
-        assert NotImplementedError 
-
 class ConcreteArray(BaseArray):
     """ An array that have actual storage, whether owned or not
     """
@@ -1304,6 +1225,8 @@
     return space.wrap(arr)
 
 def dot(space, w_obj, w_obj2):
+    '''see numpypy.dot. Does not exist as an ndarray method in numpy.
+    '''
     w_arr = convert_to_array(space, w_obj)
     if isinstance(w_arr, Scalar):
         return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -2,7 +2,7 @@
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
      ConstantIterator, AxisIterator, ViewTransform,\
-     BroadcastTransform, DotTransform
+     BroadcastTransform
 from pypy.rlib.jit import hint, unroll_safe, promote
 
 """ Signature specifies both the numpy expression that has been constructed
@@ -449,21 +449,3 @@
     
     def debug_repr(self):
         return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr())
-
-class DotSignature(Call2):
-    def _invent_numbering(self, cache, allnumbers):
-        self.left._invent_numbering(new_cache(), allnumbers)
-        self.right._invent_numbering(new_cache(), allnumbers)
-
-    def _create_iter(self, iterlist, arraylist, arr, transforms):
-        from pypy.module.micronumpy.interp_numarray import DotArray
-
-        assert isinstance(arr, DotArray)
-        rtransforms = transforms + [DotTransform(arr.broadcast_shape, arr.right_skip)]
-        ltransforms = transforms + [DotTransform(arr.broadcast_shape, arr.left_skip)]
-        self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
-        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
-
-    def debug_repr(self):
-        return 'DotSig(%s, %s %s)' % (self.name, self.right.debug_repr(),
-						 self.left.debug_repr())
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -869,7 +869,7 @@
     def test_dot(self):
         from _numpypy import array, dot, arange
         a = array(range(5))
-        assert a.dot(a) == 30.0
+        assert dot(a, a) == 30.0
 
         a = array(range(5))
         assert a.dot(range(5)) == 30
@@ -887,9 +887,11 @@
         #Superfluous shape test makes the intention of the test clearer
         assert a.shape == (2, 3, 4)
         assert b.shape == (4, 3)
-        c = a.dot(b)
+        c = dot(a, b)
         assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
                    [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+        c = dot(a, b[:, :, 2])
+        assert (c == [[38, 126, 214], [302, 390, 478]]).all()
 
     def test_dot_constant(self):
         from _numpypy import array


More information about the pypy-commit mailing list