[pypy-commit] pypy default: Added argmin and argmax to numpy arrays

justinpeel noreply at buildbot.pypy.org
Tue Jul 12 00:47:15 CEST 2011


Author: Justin Peel <notmuchtotell at gmail.com>
Branch: 
Changeset: r45487:dee586a9d707
Date: 2011-07-11 09:47 -0600
http://bitbucket.org/pypy/pypy/changeset/dee586a9d707/

Log:	Added argmin and argmax to numpy arrays

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -181,6 +181,32 @@
             return space.wrap(loop(self, self.eval(0), size))
         return func_with_new_name(impl, "reduce_%s_impl" % function.__name__)
 
+    def _reduce_argmax_argmin_impl(function):
+        reduce_driver = jit.JitDriver(greens=['signature'],
+                         reds = ['i', 'size', 'self', 'result'])
+        def loop(self, size):
+            result = 0
+            cur_best = self.eval(0)
+            i = 1
+            while i < size:
+                reduce_driver.jit_merge_point(signature=self.signature,
+                                              self=self, size=size, i=i,
+                                              result=result)
+                new_best = function(cur_best, self.eval(i))
+                if new_best != cur_best:
+                    result = i
+                    cur_best = new_best
+                i += 1
+            return result
+        def impl(self, space):
+            size = self.find_size()
+            if size == 0:
+                raise OperationError(space.w_ValueError,
+                    space.wrap("Can't call %s on zero-size arrays" \
+                            % function.__name__))
+            return space.wrap(loop(self, size))
+        return func_with_new_name(impl, "reduce_arg%s_impl" % function.__name__)
+
     def _reduce_all_impl():
         reduce_driver = jit.JitDriver(greens=['signature'],
                          reds = ['i', 'size', 'result', 'self'])
@@ -225,6 +251,8 @@
     descr_prod = _reduce_sum_prod_impl(mul, 1.0)
     descr_max = _reduce_max_min_impl(maximum)
     descr_min = _reduce_max_min_impl(minimum)
+    descr_argmax = _reduce_argmax_argmin_impl(maximum)
+    descr_argmin = _reduce_argmax_argmin_impl(minimum)
     descr_all = _reduce_all_impl()
     descr_any = _reduce_any_impl()
 
@@ -519,6 +547,8 @@
     prod = interp2app(BaseArray.descr_prod),
     max = interp2app(BaseArray.descr_max),
     min = interp2app(BaseArray.descr_min),
+    argmax = interp2app(BaseArray.descr_argmax),
+    argmin = interp2app(BaseArray.descr_argmin),
     all = interp2app(BaseArray.descr_all),
     any = interp2app(BaseArray.descr_any),
     dot = interp2app(BaseArray.descr_dot),
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -313,6 +313,20 @@
         b = array([])
         raises(ValueError, "b.min()")
 
+    def test_argmax(self):
+        from numpy import array
+        a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
+        assert a.argmax() == 2  # 5.7, the largest value, sits at index 2
+        b = array([])
+        raises(ValueError, "b.argmax()")  # an empty array has no index to return
+
+    def test_argmin(self):
+        from numpy import array
+        a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
+        assert a.argmin() == 3  # -3.0, the smallest value, sits at index 3
+        b = array([])
+        raises(ValueError, "b.argmin()")  # an empty array has no index to return
+
     def test_all(self):
         from numpy import array
         a = array(range(5))


More information about the pypy-commit mailing list