[pypy-commit] pypy default: Added argmin and argmax to numpy arrays
justinpeel
noreply at buildbot.pypy.org
Tue Jul 12 00:47:15 CEST 2011
Author: Justin Peel <notmuchtotell at gmail.com>
Branch:
Changeset: r45487:dee586a9d707
Date: 2011-07-11 09:47 -0600
http://bitbucket.org/pypy/pypy/changeset/dee586a9d707/
Log: Added argmin and argmax to numpy arrays
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -181,6 +181,32 @@
return space.wrap(loop(self, self.eval(0), size))
return func_with_new_name(impl, "reduce_%s_impl" % function.__name__)
def _reduce_argmax_argmin_impl(function):
    """Build an app-level ``argmax``/``argmin`` method from ``function``.

    ``function`` is called as ``function(cur_best, next_value)``. In this
    file it is ``maximum`` or ``minimum``, and it is assumed to return
    whichever of its two arguments should win. The generated method scans
    the array once and returns, wrapped for app-level, the index of the
    winning element.

    ``result`` only moves when the running best value changes, so on ties
    the earliest index is kept.

    Raises an app-level ValueError (via OperationError) for zero-size
    arrays, because there is no element whose index could be returned.
    """
    # Every variable that is live across the merge point must be declared
    # to the JIT as a green or a red. ``cur_best`` is reassigned inside the
    # loop, so it has to be listed as a red and passed to jit_merge_point.
    reduce_driver = jit.JitDriver(greens=['signature'],
                                  reds=['i', 'size', 'self', 'result',
                                        'cur_best'])

    def loop(self, size):
        # Seed with element 0; callers guarantee size >= 1.
        result = 0
        cur_best = self.eval(0)
        i = 1
        while i < size:
            reduce_driver.jit_merge_point(signature=self.signature,
                                          self=self, size=size, i=i,
                                          result=result, cur_best=cur_best)
            new_best = function(cur_best, self.eval(i))
            # NOTE(review): if cur_best is NaN, ``!=`` is always true, so
            # ``result`` tracks the last element rather than the first NaN
            # as CPython NumPy does. Confirm whether NaN handling matters.
            if new_best != cur_best:
                result = i
                cur_best = new_best
            i += 1
        return result

    def impl(self, space):
        size = self.find_size()
        if size == 0:
            raise OperationError(space.w_ValueError,
                                 space.wrap("Can't call %s on zero-size arrays"
                                            % function.__name__))
        return space.wrap(loop(self, size))
    return func_with_new_name(impl, "reduce_arg%s_impl" % function.__name__)
+
def _reduce_all_impl():
reduce_driver = jit.JitDriver(greens=['signature'],
reds = ['i', 'size', 'result', 'self'])
@@ -225,6 +251,8 @@
descr_prod = _reduce_sum_prod_impl(mul, 1.0)
descr_max = _reduce_max_min_impl(maximum)
descr_min = _reduce_max_min_impl(minimum)
+ # argmax/argmin reuse the same maximum/minimum helpers as descr_max and
+ # descr_min, but return the index of the winning element, not its value.
+ descr_argmax = _reduce_argmax_argmin_impl(maximum)
+ descr_argmin = _reduce_argmax_argmin_impl(minimum)
descr_all = _reduce_all_impl()
descr_any = _reduce_any_impl()
@@ -519,6 +547,8 @@
prod = interp2app(BaseArray.descr_prod),
max = interp2app(BaseArray.descr_max),
min = interp2app(BaseArray.descr_min),
+ argmax = interp2app(BaseArray.descr_argmax),
+ argmin = interp2app(BaseArray.descr_argmin),
all = interp2app(BaseArray.descr_all),
any = interp2app(BaseArray.descr_any),
dot = interp2app(BaseArray.descr_dot),
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -313,6 +313,20 @@
b = array([])
raises(ValueError, "b.min()")
def test_argmax(self):
    """argmax returns the index of the largest element."""
    from numpy import array
    a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
    assert a.argmax() == 2
    # Ties resolve to the first occurrence, matching CPython NumPy.
    c = array([1.0, 5.7, 5.7, 0.0])
    assert c.argmax() == 1
    # A single-element array never enters the scan loop.
    assert array([4.2]).argmax() == 0
    # Zero-size arrays have no index to return.
    b = array([])
    raises(ValueError, "b.argmax()")
+
def test_argmin(self):
    """argmin returns the index of the smallest element."""
    from numpy import array
    a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
    assert a.argmin() == 3
    # Ties resolve to the first occurrence, matching CPython NumPy.
    c = array([2.0, -3.0, 1.0, -3.0])
    assert c.argmin() == 1
    # A single-element array never enters the scan loop.
    assert array([4.2]).argmin() == 0
    # Zero-size arrays have no index to return.
    b = array([])
    raises(ValueError, "b.argmin()")
+
def test_all(self):
from numpy import array
a = array(range(5))
More information about the pypy-commit
mailing list