[pypy-commit] pypy matrixmath-dot: experimental approach to dot iterator problem

mattip noreply at buildbot.pypy.org
Mon Dec 5 22:42:08 CET 2011


Author: mattip
Branch: matrixmath-dot
Changeset: r50188:885d36165f89
Date: 2011-12-04 23:22 +0200
http://bitbucket.org/pypy/pypy/changeset/885d36165f89/

Log:	experimental approach to dot iterator problem

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -121,6 +121,11 @@
     def get_offset(self):
         raise NotImplementedError
 
+class DummyIterator(object):
+    '''Dummy placeholder
+    '''
+    pass
+    
 class ArrayIterator(BaseIterator):
     def __init__(self, size):
         self.offset = 0
@@ -354,8 +359,10 @@
     descr_abs = _unaryop_impl("absolute")
 
     def _binop_impl(ufunc_name):
-        def impl(self, space, w_other):
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self, w_other])
+        def impl(self, space, w_other, w_selfiter=DummyIterator(), 
+                                      w_otheriter=DummyIterator()):
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, 
+                                     [self, w_other, w_selfiter, w_otheriter])
         return func_with_new_name(impl, "binop_%s_impl" % ufunc_name)
 
     descr_add = _binop_impl("add")
@@ -980,12 +987,15 @@
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, signature, shape, calc_dtype, res_dtype, left, right):
+    def __init__(self, signature, shape, calc_dtype, res_dtype, left, right, 
+                     liter = DummyIterator(), riter = DummyIterator()):
         # XXX do something if left.order != right.order
         VirtualArray.__init__(self, signature, shape, res_dtype, left.order)
         self.left = left
         self.right = right
         self.calc_dtype = calc_dtype
+        self.liter = liter
+        self.riter = riter
         self.size = 1
         for s in self.shape:
             self.size *= s
@@ -1002,8 +1012,15 @@
             return self.forced_result.start_iter(res_shape)
         if res_shape is None:
             res_shape = self.shape  # we still force the shape on children
-        return Call2Iterator(self.left.start_iter(res_shape),
-                             self.right.start_iter(res_shape))
+        if not getattr(self.liter, 'get_offest', ''):
+            _liter = self.left.start_iter(res_shape)
+        else:
+            _liter = self.liter
+        if not getattr(self.riter, 'get_offest', ''):
+            _riter = self.right.start_iter(res_shape)
+        else:
+            _riter = self.riter
+        return Call2Iterator(_liter, _riter)
 
     def _eval(self, iter):
         assert isinstance(iter, Call2Iterator)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -139,8 +139,10 @@
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
             convert_to_array, Scalar, shape_agreement)
-
-        [w_lhs, w_rhs] = args_w
+        if len(args_w)>2:
+            [w_lhs, w_rhs, w_liter, w_riter] = args_w
+        else:
+            [w_lhs, w_rhs] = args_w
         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
         calc_dtype = find_binop_result_dtype(space,
@@ -162,8 +164,12 @@
             self.signature, w_lhs.signature, w_rhs.signature
         ])
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
-        w_res = Call2(new_sig, new_shape, calc_dtype,
-                      res_dtype, w_lhs, w_rhs)
+        if len(args_w)>2:
+            w_res = Call2(new_sig, new_shape, calc_dtype,
+                          res_dtype, w_lhs, w_rhs, w_liter, w_riter)
+        else:
+            w_res = Call2(new_sig, new_shape, calc_dtype, 
+                          res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res


More information about the pypy-commit mailing list