[pypy-commit] pypy numpypy-argminmax: fix, add more passing tests
mattip
noreply at buildbot.pypy.org
Wed Jul 4 22:10:24 CEST 2012
Author: mattip <matti.picus at gmail.com>
Branch: numpypy-argminmax
Changeset: r55921:881739f5e4db
Date: 2012-07-04 23:09 +0300
http://bitbucket.org/pypy/pypy/changeset/881739f5e4db/
Log: fix, add more passing tests
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -379,4 +379,4 @@
self.done = True
def get_dim_index(self):
- return self.indices[0]
+ return self.indices[self.dimorder[0]]
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -224,6 +224,9 @@
return out
return Scalar(res_dtype, res_dtype.box(result))
def do_axisminmax(self, space, axis, out):
+ # Use an AxisFirstIterator to walk along self, with dimensions
+ # reordered so that 'axis' varies fastest. Every time the index
+ # along 'axis' is 0, move on to the next element of out.
dtype = self.find_dtype()
source = AxisFirstIterator(self, axis)
dest = ViewIterator(out.start, out.strides, out.backstrides,
@@ -231,7 +234,6 @@
firsttime = True
while not source.done:
cur_val = self.getitem(source.offset)
- #print 'indices are',source.indices
cur_index = source.get_dim_index()
if cur_index == 0:
if not firsttime:
@@ -239,13 +241,11 @@
firsttime = False
cur_best = cur_val
out.setitem(dest.offset, dtype.box(0))
- #print 'setting out[',dest.offset,'] to 0'
else:
new_best = getattr(dtype.itemtype, op_name)(cur_best, cur_val)
if dtype.itemtype.ne(new_best, cur_best):
cur_best = new_best
out.setitem(dest.offset, dtype.box(cur_index))
- #print 'setting out[',dest.offset,'] to',cur_index
source.next()
return out
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1739,6 +1739,36 @@
assert a.argmax() == 5
assert a[:2, ].argmax() == 3
+ def test_argmax_axis(self):
+ # Exercise argmax with an explicit axis: once per axis of a 3-d
+ # array, then once on a sliced (strided) sub-array.
+ from _numpypy import array
+ # Some random values, tested via cut-and-paste
+ # from numpy
+ vals = [57, 42, 57, 20, 81, 82, 65, 16, 52, 32,
+ 24, 95, 99, 4, 86, 60, 38, 28, 67, 45,
+ 68, 66, 13, 76, 98, 96, 61, 4, 0, 13,
+ 94, 30, 36, 89, 31, 54, 43, 6, 58, 84,
+ 15, 22, 41, 3, 49, 81, 65, 53, 85, 14,
+ 56, 37, 60, 11, 77, 9, 16, 80, 94, 43]
+ # 60 values -> shape (5, 3, 4)
+ a = array(vals).reshape(5,3,4)
+ # Reduce over axis 0 (length 5): result is 3x4
+ b = a.argmax(0)
+ assert (b == [[1, 2, 1, 3],
+ [0, 0, 2, 1],
+ [1, 2, 4, 0]]).all()
+ # Reduce over axis 1 (length 3): result is 5x4
+ b = a.argmax(1)
+ assert (b == [[1, 1, 1, 2],
+ [0, 2, 0, 2],
+ [0, 0, 1, 2],
+ [2, 2, 2, 0],
+ [0, 2, 2, 2]]).all()
+ # Reduce over axis 2 (the last axis, length 4): result is 5x3
+ b = a.argmax(2)
+ assert (b == [[0, 1, 3], [0, 2, 3],
+ [0, 2, 1], [3, 2, 1],
+ [0, 2, 2]]).all()
+ # a[:,2,:] is a 5x4 slice that does not start at offset 0 and
+ # skips elements between rows; reduce along its last axis
+ b = a[:,2,:].argmax(1)
+ assert(b == [3, 3, 1, 1, 2]).all()
+
+
+
def test_broadcast_wrong_shapes(self):
from _numpypy import zeros
a = zeros((4, 3, 2))
More information about the pypy-commit
mailing list