[pypy-commit] pypy numpy-multidim: argmax works for multidim now
alex_gaynor
noreply at buildbot.pypy.org
Thu Nov 24 14:43:01 CET 2011
Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-multidim
Changeset: r49728:4f7747b4ff1c
Date: 2011-11-24 07:42 -0600
http://bitbucket.org/pypy/pypy/changeset/4f7747b4ff1c/
Log: argmax works for multidim now
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -389,36 +389,35 @@
def _reduce_argmax_argmin_impl(op_name):
reduce_driver = jit.JitDriver(greens=['shapelen', 'signature'],
- reds=['result', 'i', 'self', 'cur_best', 'dtype'])
+ reds=['result', 'idx', 'i', 'self', 'cur_best', 'dtype'])
def loop(self):
i = self.start_iter()
- result = i.get_offset()
cur_best = self.eval(i)
shapelen = len(self.shape)
i = i.next(shapelen)
dtype = self.find_dtype()
+ result = 0
+ idx = 1
while not i.done():
reduce_driver.jit_merge_point(signature=self.signature,
shapelen=shapelen,
self=self, dtype=dtype,
- i=i, result=result,
+ i=i, result=result, idx=idx,
cur_best=cur_best)
new_best = getattr(dtype, op_name)(cur_best, self.eval(i))
if dtype.ne(new_best, cur_best):
- result = i.get_offset()
+ result = idx
cur_best = new_best
i = i.next(shapelen)
+ idx += 1
return result
def impl(self, space):
size = self.find_size()
- if len(self.shape) > 1:
- raise OperationError(space.w_TypeError,
- space.wrap("argmin/max does not work on multidimensional arrays yet"))
if size == 0:
raise OperationError(space.w_ValueError,
space.wrap("Can't call %s on zero-size arrays" \
% op_name))
- return self.compute_index(space, loop(self))
+ return space.wrap(loop(self))
return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
def _all(self):
@@ -734,11 +733,6 @@
def start_iter(self, res_shape=None):
raise NotImplementedError
- def compute_index(self, space, offset):
- offset -= self.start
- assert len(self.shape) == 1
- return space.wrap(offset // self.strides[0])
-
def convert_to_array(space, w_obj):
if isinstance(w_obj, BaseArray):
return w_obj
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -640,12 +640,7 @@
r = a.argmax()
assert r == 2
b = array([])
- try:
- b.argmax()
- except:
- pass
- else:
- raise Exception("Did not raise")
+ raises(ValueError, b.argmax)
a = array(range(-5, 5))
r = a.argmax()
@@ -909,8 +904,9 @@
def test_argmax(self):
from numpypy import array
- a = array([[1, 2], [3, 4]])
- raises(TypeError, a.argmax)
+ a = array([[1, 2], [3, 4], [5, 6]])
+ assert a.argmax() == 5
+ assert a[:2,].argmax() == 3
class AppTestSupport(object):
def setup_class(cls):
@@ -1001,7 +997,7 @@
a = zeros((2, 2, 2))
r = str(a)
assert r == '[[[0.0 0.0]\n [0.0 0.0]]\n\n [[0.0 0.0]\n [0.0 0.0]]]'
-
+
def test_str_slice(self):
from numpypy import array, zeros
a = array(range(5), float)
More information about the pypy-commit mailing list