[Numpy-svn] r6325 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sat Jan 17 16:24:27 EST 2009


Author: ptvirtan
Date: 2009-01-17 15:24:13 -0600 (Sat, 17 Jan 2009)
New Revision: 6325

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
Log:
Make `trapz` accept a 1-D `x` for an n-d `y`, even when `axis != -1`.

Additional tests included.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2009-01-17 00:15:15 UTC (rev 6324)
+++ trunk/numpy/lib/function_base.py	2009-01-17 21:24:13 UTC (rev 6325)
@@ -2818,9 +2818,9 @@
     y : array_like
         Input array to integrate.
     x : array_like, optional
-        If `x` is None, then spacing between all `y` elements is 1.
+        If `x` is None, then spacing between all `y` elements is `dx`.
     dx : scalar, optional
-        If `x` is None, spacing given by `dx` is assumed.
+        If `x` is None, spacing given by `dx` is assumed. Default is 1.
     axis : int, optional
         Specify the axis.
 
@@ -2836,7 +2836,15 @@
     if x is None:
         d = dx
     else:
-        d = diff(x,axis=axis)
+        x = asarray(x)
+        if x.ndim == 1:
+            d = diff(x)
+            # reshape to correct shape
+            shape = [1]*y.ndim
+            shape[axis] = d.shape[0]
+            d = d.reshape(shape)
+        else:
+            d = diff(x, axis=axis)
     nd = len(y.shape)
     slice1 = [slice(None)]*nd
     slice2 = [slice(None)]*nd

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2009-01-17 00:15:15 UTC (rev 6324)
+++ trunk/numpy/lib/tests/test_function_base.py	2009-01-17 21:24:13 UTC (rev 6325)
@@ -430,6 +430,44 @@
         #check integral of normal equals 1
         assert_almost_equal(sum(r,axis=0),1,7)
 
+    def test_ndim(self):
+        x = linspace(0, 1, 3)
+        y = linspace(0, 2, 8)
+        z = linspace(0, 3, 13)
+
+        wx = ones_like(x) * (x[1]-x[0])
+        wx[0] /= 2
+        wx[-1] /= 2
+        wy = ones_like(y) * (y[1]-y[0])
+        wy[0] /= 2
+        wy[-1] /= 2
+        wz = ones_like(z) * (z[1]-z[0])
+        wz[0] /= 2
+        wz[-1] /= 2
+
+        q = x[:,None,None] + y[None,:,None] + z[None,None,:]
+
+        qx = (q*wx[:,None,None]).sum(axis=0)
+        qy = (q*wy[None,:,None]).sum(axis=1)
+        qz = (q*wz[None,None,:]).sum(axis=2)
+
+        # n-d `x`
+        r = trapz(q, x=x[:,None,None], axis=0)
+        assert_almost_equal(r, qx)
+        r = trapz(q, x=y[None,:,None], axis=1)
+        assert_almost_equal(r, qy)
+        r = trapz(q, x=z[None,None,:], axis=2)
+        assert_almost_equal(r, qz)
+
+        # 1-d `x`
+        r = trapz(q, x=x, axis=0)
+        assert_almost_equal(r, qx)
+        r = trapz(q, x=y, axis=1)
+        assert_almost_equal(r, qy)
+        r = trapz(q, x=z, axis=2)
+        assert_almost_equal(r, qz)
+
+
 class TestSinc(TestCase):
     def test_simple(self):
         assert(sinc(0)==1)




More information about the Numpy-svn mailing list