[pypy-commit] pypy numpy-multidim: argmax works for multidim now

alex_gaynor noreply at buildbot.pypy.org
Thu Nov 24 14:43:01 CET 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-multidim
Changeset: r49728:4f7747b4ff1c
Date: 2011-11-24 07:42 -0600
http://bitbucket.org/pypy/pypy/changeset/4f7747b4ff1c/

Log:	argmax works for multidim now

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -389,36 +389,35 @@
 
     def _reduce_argmax_argmin_impl(op_name):
         reduce_driver = jit.JitDriver(greens=['shapelen', 'signature'],
-                         reds=['result', 'i', 'self', 'cur_best', 'dtype'])
+                         reds=['result', 'idx', 'i', 'self', 'cur_best', 'dtype'])
         def loop(self):
             i = self.start_iter()
-            result = i.get_offset()
             cur_best = self.eval(i)
             shapelen = len(self.shape)
             i = i.next(shapelen)
             dtype = self.find_dtype()
+            result = 0
+            idx = 1
             while not i.done():
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               shapelen=shapelen,
                                               self=self, dtype=dtype,
-                                              i=i, result=result,
+                                              i=i, result=result, idx=idx,
                                               cur_best=cur_best)
                 new_best = getattr(dtype, op_name)(cur_best, self.eval(i))
                 if dtype.ne(new_best, cur_best):
-                    result = i.get_offset()
+                    result = idx
                     cur_best = new_best
                 i = i.next(shapelen)
+                idx += 1
             return result
         def impl(self, space):
             size = self.find_size()
-            if len(self.shape) > 1:
-                raise OperationError(space.w_TypeError,
-                                     space.wrap("argmin/max does not work on multidimensional arrays yet"))
             if size == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" \
                             % op_name))
-            return self.compute_index(space, loop(self))
+            return space.wrap(loop(self))
         return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
 
     def _all(self):
@@ -734,11 +733,6 @@
     def start_iter(self, res_shape=None):
         raise NotImplementedError
 
-    def compute_index(self, space, offset):
-        offset -= self.start
-        assert len(self.shape) == 1
-        return space.wrap(offset // self.strides[0])
-
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
         return w_obj
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -640,12 +640,7 @@
         r = a.argmax()
         assert r == 2
         b = array([])
-        try:
-            b.argmax()
-        except:
-            pass
-        else:
-            raise Exception("Did not raise")
+        raises(ValueError, b.argmax)
 
         a = array(range(-5, 5))
         r = a.argmax()
@@ -909,8 +904,9 @@
 
     def test_argmax(self):
         from numpypy import array
-        a = array([[1, 2], [3, 4]])
-        raises(TypeError, a.argmax)
+        a = array([[1, 2], [3, 4], [5, 6]])
+        assert a.argmax() == 5
+        assert a[:2,].argmax() == 3
 
 class AppTestSupport(object):
     def setup_class(cls):
@@ -1001,7 +997,7 @@
         a = zeros((2, 2, 2))
         r = str(a)
         assert r == '[[[0.0 0.0]\n  [0.0 0.0]]\n\n [[0.0 0.0]\n  [0.0 0.0]]]'
-        
+
     def test_str_slice(self):
         from numpypy import array, zeros
         a = array(range(5), float)


More information about the pypy-commit mailing list