[Numpy-svn] r8551 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sat Jul 31 00:41:08 EDT 2010


Author: rgommers
Date: 2010-07-30 23:41:08 -0500 (Fri, 30 Jul 2010)
New Revision: 8551

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
Log:
ENH: Make trapz work with ndarray subclasses. Thanks to Ryan May. Closes #1438.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2010-07-29 20:58:05 UTC (rev 8550)
+++ trunk/numpy/lib/function_base.py	2010-07-31 04:41:08 UTC (rev 8551)
@@ -2924,7 +2924,7 @@
     -----
     Given a vector V of length N, the qth percentile of V is the qth ranked
     value in a sorted copy of V.  A weighted average of the two nearest neighbors
-    is used if the normalized ranking does not match q exactly. 
+    is used if the normalized ranking does not match q exactly.
     The same as the median if q is 0.5; the same as the min if q is 0;
     and the same as the max if q is 1
 
@@ -2962,7 +2962,7 @@
         return a.min(axis=axis, out=out)
     elif q == 100:
         return a.max(axis=axis, out=out)
-        
+
     if overwrite_input:
         if axis is None:
             sorted = a.ravel()
@@ -3072,11 +3072,11 @@
     array([ 2.,  8.])
 
     """
-    y = asarray(y)
+    y = asanyarray(y)
     if x is None:
         d = dx
     else:
-        x = asarray(x)
+        x = asanyarray(x)
         if x.ndim == 1:
             d = diff(x)
             # reshape to correct shape
@@ -3090,7 +3090,13 @@
     slice2 = [slice(None)]*nd
     slice1[axis] = slice(1,None)
     slice2[axis] = slice(None,-1)
-    return add.reduce(d * (y[slice1]+y[slice2])/2.0,axis)
+    try:
+        ret = (d * (y[slice1] +y [slice2]) / 2.0).sum(axis)
+    except ValueError: # Operations didn't work, cast to ndarray
+        d = np.asarray(d)
+        y = np.asarray(y)
+        ret = add.reduce(d * (y[slice1]+y[slice2])/2.0, axis)
+    return ret
 
 #always succeed
 def add_newdoc(place, obj, doc):

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2010-07-29 20:58:05 UTC (rev 8550)
+++ trunk/numpy/lib/tests/test_function_base.py	2010-07-31 04:41:08 UTC (rev 8551)
@@ -491,7 +491,33 @@
         r = trapz(q, x=z, axis=2)
         assert_almost_equal(r, qz)
 
+    def test_masked(self):
+        #Testing that masked arrays behave as if the function is 0 where
+        #masked
+        x = arange(5)
+        y = x * x
+        mask = x == 2
+        ym = np.ma.array(y, mask=mask)
+        r = 13.0 # sum(0.5 * (0 + 1) * 1.0 + 0.5 * (9 + 16))
+        assert_almost_equal(trapz(ym, x), r)
 
+        xm = np.ma.array(x, mask=mask)
+        assert_almost_equal(trapz(ym, xm), r)
+
+        xm = np.ma.array(x, mask=mask)
+        assert_almost_equal(trapz(y, xm), r)
+
+    def test_matrix(self):
+        #Test to make sure matrices give the same answer as ndarrays
+        x = linspace(0, 5)
+        y = x * x
+        r = trapz(y, x)
+        mx = matrix(x)
+        my = matrix(y)
+        mr = trapz(my, mx)
+        assert_almost_equal(mr, r)
+
+
 class TestSinc(TestCase):
     def test_simple(self):
         assert(sinc(0) == 1)




More information about the Numpy-svn mailing list