[Numpy-svn] r8552 - in branches/1.5.x/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Sat Jul 31 00:44:00 EDT 2010
Author: rgommers
Date: 2010-07-30 23:44:00 -0500 (Fri, 30 Jul 2010)
New Revision: 8552
Modified:
branches/1.5.x/numpy/lib/function_base.py
branches/1.5.x/numpy/lib/tests/test_function_base.py
Log:
ENH: (backport of r8551) Make trapz work with ndarray subclasses. Thanks to Ryan May.
Modified: branches/1.5.x/numpy/lib/function_base.py
===================================================================
--- branches/1.5.x/numpy/lib/function_base.py 2010-07-31 04:41:08 UTC (rev 8551)
+++ branches/1.5.x/numpy/lib/function_base.py 2010-07-31 04:44:00 UTC (rev 8552)
@@ -2924,7 +2924,7 @@
-----
Given a vector V of length N, the qth percentile of V is the qth ranked
value in a sorted copy of V. A weighted average of the two nearest neighbors
- is used if the normalized ranking does not match q exactly.
+ is used if the normalized ranking does not match q exactly.
The same as the median if q is 0.5; the same as the min if q is 0;
and the same as the max if q is 1
@@ -2962,7 +2962,7 @@
return a.min(axis=axis, out=out)
elif q == 100:
return a.max(axis=axis, out=out)
-
+
if overwrite_input:
if axis is None:
sorted = a.ravel()
@@ -3072,11 +3072,11 @@
array([ 2., 8.])
"""
- y = asarray(y)
+ y = asanyarray(y)
if x is None:
d = dx
else:
- x = asarray(x)
+ x = asanyarray(x)
if x.ndim == 1:
d = diff(x)
# reshape to correct shape
@@ -3090,7 +3090,13 @@
slice2 = [slice(None)]*nd
slice1[axis] = slice(1,None)
slice2[axis] = slice(None,-1)
- return add.reduce(d * (y[slice1]+y[slice2])/2.0,axis)
+ try:
+ ret = (d * (y[slice1] +y [slice2]) / 2.0).sum(axis)
+ except ValueError: # Operations didn't work, cast to ndarray
+ d = np.asarray(d)
+ y = np.asarray(y)
+ ret = add.reduce(d * (y[slice1]+y[slice2])/2.0, axis)
+ return ret
#always succeed
def add_newdoc(place, obj, doc):
Modified: branches/1.5.x/numpy/lib/tests/test_function_base.py
===================================================================
--- branches/1.5.x/numpy/lib/tests/test_function_base.py 2010-07-31 04:41:08 UTC (rev 8551)
+++ branches/1.5.x/numpy/lib/tests/test_function_base.py 2010-07-31 04:44:00 UTC (rev 8552)
@@ -491,7 +491,33 @@
r = trapz(q, x=z, axis=2)
assert_almost_equal(r, qz)
+ def test_masked(self):
+ #Testing that masked arrays behave as if the function is 0 where
+ #masked
+ x = arange(5)
+ y = x * x
+ mask = x == 2
+ ym = np.ma.array(y, mask=mask)
+ r = 13.0 # sum(0.5 * (0 + 1) * 1.0 + 0.5 * (9 + 16))
+ assert_almost_equal(trapz(ym, x), r)
+ xm = np.ma.array(x, mask=mask)
+ assert_almost_equal(trapz(ym, xm), r)
+
+ xm = np.ma.array(x, mask=mask)
+ assert_almost_equal(trapz(y, xm), r)
+
+ def test_matrix(self):
+ #Test to make sure matrices give the same answer as ndarrays
+ x = linspace(0, 5)
+ y = x * x
+ r = trapz(y, x)
+ mx = matrix(x)
+ my = matrix(y)
+ mr = trapz(my, mx)
+ assert_almost_equal(mr, r)
+
+
class TestSinc(TestCase):
def test_simple(self):
assert(sinc(0) == 1)
More information about the Numpy-svn mailing list