[pypy-commit] pypy numpypy-nditer: pass a test

mattip noreply at buildbot.pypy.org
Sun Apr 6 23:43:09 CEST 2014


Author: Matti Picus <matti.picus at gmail.com>
Branch: numpypy-nditer
Changeset: r70479:f9af4c723ac6
Date: 2014-04-07 00:41 +0300
http://bitbucket.org/pypy/pypy/changeset/f9af4c723ac6/

Log:	pass a test

diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -4,7 +4,7 @@
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
-                                             shape_agreement_multiple)
+                                             shape_agreement, shape_agreement_multiple)
 from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
 from pypy.module.micronumpy.concrete import SliceArray
 from pypy.module.micronumpy import ufuncs
@@ -134,6 +134,9 @@
             op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
         elif op_flag.rw == 'rw':
             op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
+        elif op_flag.rw == 'w':
+            # XXX write-only access is not enforced; 'w' is treated like 'rw'
+            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
     return op_flag
 
 def parse_func_flags(space, nditer, w_flags):
@@ -154,8 +157,7 @@
         if item == 'external_loop':
             nditer.external_loop = True
         elif item == 'buffered':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'nditer buffered not implemented yet'))
+            # Accepted for numpy compatibility; only the flag is recorded here
             nditer.buffered = True
         elif item == 'c_index':
             nditer.tracked_index = 'C'
@@ -293,24 +295,35 @@
             raise OperationError(space.w_NotImplementedError, space.wrap(
                 'nditer op_dtypes kwarg not implemented yet'))
         self.iters=[]
-        self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
-        outarg = [i for i in range(len(self.seq)) if self.seq[i] is None]
-        if len(outarg) > 0:
-            # Make None operands writeonly and flagged for
-            # allocation, and everything else defaults to readonly.  To write
-            # to a provided operand, you must specify the write flag manually.
+        outargs = [i for i in range(len(self.seq))
+                        if self.seq[i] is None or self.op_flags[i].rw == 'w']
+        if len(outargs) > 0:
+            out_shape = shape_agreement_multiple(space, [self.seq[i] for i in outargs])
+        else:
+            out_shape = None
+        self.shape = iter_shape = shape_agreement_multiple(space, self.seq,
+                                                           shape=out_shape)
+        if len(outargs) > 0:
+            # None operands get allocated; 'w' operands don't affect out_dtype
             out_dtype = None
-            for elem in self.seq:
-                if elem is None:
+            for i in range(len(self.seq)):
+                if self.seq[i] is None:
+                    self.op_flags[i].get_it_item = (get_readwrite_item,
+                                                    get_readwrite_slice)
+                    self.op_flags[i].allocate = True
+                    continue
+                if self.op_flags[i].rw == 'w':
                     continue
                 out_dtype = ufuncs.find_binop_result_dtype(space,
-                                                elem.get_dtype(), out_dtype)
-            for i in outarg:
-                self.op_flags[i].get_it_item = (get_readwrite_item,
-                                                get_readwrite_slice)
-                self.op_flags[i].allocate = True
-                # XXX can we postpone allocation to later?
-                self.seq[i] = W_NDimArray.from_shape(space, iter_shape, out_dtype)
+                                                self.seq[i].get_dtype(), out_dtype)
+            for i in outargs:
+                if self.seq[i] is None:
+                    # XXX can we postpone allocation to later?
+                    self.seq[i] = W_NDimArray.from_shape(space, iter_shape, out_dtype)
+                else:
+                    if not self.op_flags[i].broadcast:
+                        # Raises if output cannot be broadcast
+                        shape_agreement(space, iter_shape, self.seq[i], False)
         if self.tracked_index != "":
             if self.order == "K":
                 self.order = self.seq[0].implementation.order
@@ -430,8 +443,11 @@
             'not implemented yet'))
 
     def descr_get_operands(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        # Return the operands as a list, each passed through descr_view.
+        l_w = []
+        for op in self.seq:
+            l_w.append(op.descr_view(space))
+        return space.newlist(l_w)
 
     def descr_get_dtypes(self, space):
         res = [None] * len(self.seq)
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -282,14 +282,19 @@


 @jit.unroll_safe
-def shape_agreement_multiple(space, array_list):
+def shape_agreement_multiple(space, array_list, shape=None):
-    """ call shape_agreement recursively, allow elements from array_list to
-    be None (like w_out)
-    """
+    """ call shape_agreement recursively, allow elements from array_list to
+    be None (like w_out).  If shape is given it is used as the starting
+    shape that each non-None element is passed to shape_agreement against;
+    otherwise the first non-None element's shape is used.  Returns None
+    when shape is None and every element of array_list is None.
+    """
-    shape = array_list[0].get_shape()
-    for arr in array_list[1:]:
+    for arr in array_list:
         if not space.is_none(arr):
-            shape = shape_agreement(space, shape, arr)
+            if shape is None:
+                shape = arr.get_shape()
+            else:
+                shape = shape_agreement(space, shape, arr)
     return shape


diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -211,7 +211,7 @@
         assert (c == [1., 4., 9.]).all()
         assert (b == c).all()
         exc = raises(ValueError, square2, arange(6).reshape(2, 3), out=b)
-        assert str(exc.value).startswith('non-broadcastable output')
+        assert 'cannot be broadcasted' in str(exc.value)
 
     def test_outer_product(self):
         from numpy import nditer, arange


More information about the pypy-commit mailing list