[pypy-commit] pypy numpypy-nditer: implement op_dtypes

mattip noreply at buildbot.pypy.org
Thu Apr 17 01:34:34 CEST 2014


Author: Matti Picus <matti.picus at gmail.com>
Branch: numpypy-nditer
Changeset: r70671:0dca5996f880
Date: 2014-04-16 23:10 +0300
http://bitbucket.org/pypy/pypy/changeset/0dca5996f880/

Log:	implement op_dtypes

diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -7,6 +7,7 @@
                                              shape_agreement, shape_agreement_multiple)
 from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
 from pypy.module.micronumpy.concrete import SliceArray
+from pypy.module.micronumpy.descriptor import decode_w_dtype
 from pypy.module.micronumpy import ufuncs
 
 
@@ -201,8 +202,8 @@
     else:
         raise NotImplementedError('not implemented yet')
 
-def get_iter(space, order, arr, shape):
-    imp = arr.implementation
+def get_iter(space, order, arr, shape, dtype):
+    imp = arr.implementation.astype(space, dtype)
     backward = is_backward(imp, order)
     if (imp.strides[0] < imp.strides[-1] and not backward) or \
        (imp.strides[0] > imp.strides[-1] and backward):
@@ -291,8 +292,13 @@
         if not space.is_none(w_op_axes):
             self.set_op_axes(space, w_op_axes)
         if not space.is_none(w_op_dtypes):
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'nditer op_dtypes kwarg not implemented yet'))
+            w_seq_as_list = space.listview(w_op_dtypes)
+            self.dtypes = [decode_w_dtype(space, w_elem) for w_elem in w_seq_as_list]
+            if len(self.dtypes) != len(self.seq):
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "op_dtypes must be a tuple/list matching the number of ops"))
+        else:
+            self.dtypes = []
         self.iters=[]
         outargs = [i for i in range(len(self.seq)) \
                         if self.seq[i] is None or self.op_flags[i].rw == 'w']
@@ -304,7 +310,7 @@
                                                            shape=out_shape)
         if len(outargs) > 0:
             # Make None operands writeonly and flagged for allocation
-            out_dtype = None
+            out_dtype = self.dtypes[0] if len(self.dtypes) > 0 else None
             for i in range(len(self.seq)):
                 if self.seq[i] is None:
                     self.op_flags[i].get_it_item = (get_readwrite_item,
@@ -331,6 +337,19 @@
             else:
                 backward = self.order != self.tracked_index
             self.index_iter = IndexIterator(iter_shape, backward=backward)
+        if len(self.dtypes) > 0:
+            # Make sure dtypes make sense
+            for i in range(len(self.seq)):
+                selfd = self.dtypes[i]
+                seq_d = self.seq[i].get_dtype()
+                if not selfd:
+                    self.dtypes[i] = seq_d
+                elif selfd != seq_d and not 'r' in self.op_flags[i].tmp_copy:
+                    raise OperationError(space.w_TypeError, space.wrap(
+                        "Iterator operand required copying or buffering"))
+        else:
+            #copy them from seq
+            self.dtypes = [s.get_dtype() for s in self.seq]
         if self.external_loop:
             for i in range(len(self.seq)):
                 self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
@@ -338,7 +357,8 @@
         else:
             for i in range(len(self.seq)):
                 self.iters.append(BoxIterator(get_iter(space, self.order,
-                                self.seq[i], iter_shape), self.op_flags[i]))
+                                    self.seq[i], iter_shape, self.dtypes[i]),
+                                 self.op_flags[i]))
 
     def set_op_axes(self, space, w_op_axes):
         if space.len_w(w_op_axes) != len(self.seq):
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -140,11 +140,7 @@
 
     def test_op_dtype(self):
         from numpy import arange, nditer, sqrt, array
-        import sys
         a = arange(6).reshape(2,3) - 3
-        if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, nditer, a, op_dtypes=['complex'])
-            skip('nditer op_dtypes kwarg not implemented yet')
         exc = raises(TypeError, nditer, a, op_dtypes=['complex'])
         assert str(exc.value).startswith("Iterator operand required copying or buffering")
         r = []
@@ -154,7 +150,7 @@
         assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
                 1+0j, 1.41421356237+0j]).sum()) < 1e-5
         r = []
-        for x in nditer(a, flags=['buffered'],
+        for x in nditer(a, op_flags=['copy'],
                         op_dtypes=['complex128']):
             r.append(sqrt(x))
         assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,


More information about the pypy-commit mailing list