[pypy-commit] pypy numpypy-nditer: implement op_dtypes
mattip
noreply at buildbot.pypy.org
Thu Apr 17 01:34:34 CEST 2014
Author: Matti Picus <matti.picus at gmail.com>
Branch: numpypy-nditer
Changeset: r70671:0dca5996f880
Date: 2014-04-16 23:10 +0300
http://bitbucket.org/pypy/pypy/changeset/0dca5996f880/
Log: implement op_dtypes
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -7,6 +7,7 @@
shape_agreement, shape_agreement_multiple)
from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
from pypy.module.micronumpy.concrete import SliceArray
+from pypy.module.micronumpy.descriptor import decode_w_dtype
from pypy.module.micronumpy import ufuncs
@@ -201,8 +202,8 @@
else:
raise NotImplementedError('not implemented yet')
-def get_iter(space, order, arr, shape):
- imp = arr.implementation
+def get_iter(space, order, arr, shape, dtype):
+ imp = arr.implementation.astype(space, dtype)
backward = is_backward(imp, order)
if (imp.strides[0] < imp.strides[-1] and not backward) or \
(imp.strides[0] > imp.strides[-1] and backward):
@@ -291,8 +292,13 @@
if not space.is_none(w_op_axes):
self.set_op_axes(space, w_op_axes)
if not space.is_none(w_op_dtypes):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'nditer op_dtypes kwarg not implemented yet'))
+ w_seq_as_list = space.listview(w_op_dtypes)
+ self.dtypes = [decode_w_dtype(space, w_elem) for w_elem in w_seq_as_list]
+ if len(self.dtypes) != len(self.seq):
+ raise OperationError(space.w_ValueError, space.wrap(
+ "op_dtypes must be a tuple/list matching the number of ops"))
+ else:
+ self.dtypes = []
self.iters=[]
outargs = [i for i in range(len(self.seq)) \
if self.seq[i] is None or self.op_flags[i].rw == 'w']
@@ -304,7 +310,7 @@
shape=out_shape)
if len(outargs) > 0:
# Make None operands writeonly and flagged for allocation
- out_dtype = None
+ out_dtype = self.dtypes[0] if len(self.dtypes) > 0 else None
for i in range(len(self.seq)):
if self.seq[i] is None:
self.op_flags[i].get_it_item = (get_readwrite_item,
@@ -331,6 +337,19 @@
else:
backward = self.order != self.tracked_index
self.index_iter = IndexIterator(iter_shape, backward=backward)
+ if len(self.dtypes) > 0:
+ # Make sure dtypes make sense
+ for i in range(len(self.seq)):
+ selfd = self.dtypes[i]
+ seq_d = self.seq[i].get_dtype()
+ if not selfd:
+ self.dtypes[i] = seq_d
+ elif selfd != seq_d and not 'r' in self.op_flags[i].tmp_copy:
+ raise OperationError(space.w_TypeError, space.wrap(
+ "Iterator operand required copying or buffering"))
+ else:
+ #copy them from seq
+ self.dtypes = [s.get_dtype() for s in self.seq]
if self.external_loop:
for i in range(len(self.seq)):
self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
@@ -338,7 +357,8 @@
else:
for i in range(len(self.seq)):
self.iters.append(BoxIterator(get_iter(space, self.order,
- self.seq[i], iter_shape), self.op_flags[i]))
+ self.seq[i], iter_shape, self.dtypes[i]),
+ self.op_flags[i]))
def set_op_axes(self, space, w_op_axes):
if space.len_w(w_op_axes) != len(self.seq):
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -140,11 +140,7 @@
def test_op_dtype(self):
from numpy import arange, nditer, sqrt, array
- import sys
a = arange(6).reshape(2,3) - 3
- if '__pypy__' in sys.builtin_module_names:
- raises(NotImplementedError, nditer, a, op_dtypes=['complex'])
- skip('nditer op_dtypes kwarg not implemented yet')
exc = raises(TypeError, nditer, a, op_dtypes=['complex'])
assert str(exc.value).startswith("Iterator operand required copying or buffering")
r = []
@@ -154,7 +150,7 @@
assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
1+0j, 1.41421356237+0j]).sum()) < 1e-5
r = []
- for x in nditer(a, flags=['buffered'],
+ for x in nditer(a, op_flags=['copy'],
op_dtypes=['complex128']):
r.append(sqrt(x))
assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
More information about the pypy-commit mailing list