[Scipy-svn] r4938 - in trunk/scipy/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Nov 2 20:42:39 EST 2008


Author: ptvirtan
Date: 2008-11-02 19:42:26 -0600 (Sun, 02 Nov 2008)
New Revision: 4938

Modified:
   trunk/scipy/interpolate/interpolate.py
   trunk/scipy/interpolate/tests/test_interpolate.py
Log:
scipy.interpolate: clarify interp2d docstrings, and fix broken logic in __init__ concerning meshgrid-like input. Addresses #703.

Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py	2008-11-03 01:42:04 UTC (rev 4937)
+++ trunk/scipy/interpolate/interpolate.py	2008-11-03 01:42:26 UTC (rev 4938)
@@ -8,7 +8,7 @@
 
 from numpy import shape, sometrue, rank, array, transpose, searchsorted, \
                   ones, logical_or, atleast_1d, atleast_2d, meshgrid, ravel, \
-                  dot, poly1d
+                  dot, poly1d, asarray
 import numpy as np
 import scipy.special as spec
 import math
@@ -47,38 +47,35 @@
     """
     interp2d(x, y, z, kind='linear', copy=True, bounds_error=False,
              fill_value=nan)
-    
+
     Interpolate over a 2D grid.
 
     Parameters
     ----------
-    x : 1D array
-    y : 1D array
+    x, y : 1D arrays
         Arrays defining the coordinates of a 2D grid.  If the
-        points lie on a regular grid, x and y can simply specify
-        the rows and colums, i.e.
+        points lie on a regular grid, `x` can specify the column coordinates
+        and `y` the row coordinates, e.g.::
 
-        x = [0,1,2]  y = [0,1,2]
+            x = [0,1,2];  y = [0,3,7]
 
-        otherwise x and y must specify the full coordinates, i.e.
+        otherwise x and y must specify the full coordinates, i.e.::
 
-        x = [0,1,2,0,1.5,2,0,1,2]  y = [0,1,2,0,1,2,0,1,2]
+            x = [0,1,2,0,1,2,0,1,2];  y = [0,0,0,3,3,3,7,7,7]
 
-        If x and y are 2-dimensional, they are flattened (allowing
-        the use of meshgrid, for example).
+        If `x` and `y` are multi-dimensional, they are flattened before use.
 
     z : 1D array
-        The values of the interpolated function on the grid
-        points. If z is a 2-dimensional array, it is flattened.
-
-    kind : 'linear', 'cubic', 'quintic'
+        The values of the interpolated function on the grid points. If
+        z is a multi-dimensional array, it is flattened before use.
+    kind : {'linear', 'cubic', 'quintic'}
         The kind of interpolation to use.
     copy : bool
         If True, then data is copied, otherwise only a reference is held.
     bounds_error : bool
-        If True, when interoplated values are requested outside of the
+        If True, when interpolated values are requested outside of the
         domain of the input data, an error is raised.
-        If False, then fill_value is used.
+        If False, then `fill_value` is used.
     fill_value : number
         If provided, the value to use for points outside of the
         interpolation domain. Defaults to NaN.
@@ -89,20 +86,20 @@
 
     See Also
     --------
-    bisplrep, bisplev - spline interpolation based on FITPACK
-    BivariateSpline - a more recent wrapper of the FITPACK routines
+    bisplrep, bisplev : spline interpolation based on FITPACK
+    BivariateSpline : a more recent wrapper of the FITPACK routines
+
     """
 
     def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
-        fill_value=np.nan):
-        self.x, self.y, self.z = map(ravel, map(array, [x, y, z]))
-        if not map(rank, [self.x, self.y, self.z]) == [1,1,1]:
-            raise ValueError("One of the input arrays is not 1-d.")
-        if len(self.x) != len(self.y):
-            raise ValueError("x and y must have equal lengths")
+                 fill_value=np.nan):
+        self.x, self.y, self.z = map(ravel, map(asarray, [x, y, z]))
+        
         if len(self.z) == len(self.x) * len(self.y):
             self.x, self.y = meshgrid(x,y)
             self.x, self.y = map(ravel, [self.x, self.y])
+        if len(self.x) != len(self.y):
+            raise ValueError("x and y must have equal lengths")
         if len(self.z) != len(self.x):
             raise ValueError("Invalid length for input z")
 
@@ -116,21 +113,24 @@
         self.tck = fitpack.bisplrep(self.x, self.y, self.z, kx=kx, ky=ky, s=0.)
 
     def __call__(self,x,y,dx=0,dy=0):
-        """ Interpolate the function.
+        """Interpolate the function.
 
         Parameters
         ----------
         x : 1D array
+            x-coordinates of the mesh on which to interpolate.
         y : 1D array
-            The points to interpolate.
+            y-coordinates of the mesh on which to interpolate.
         dx : int >= 0, < kx
+            Order of partial derivatives in x.
         dy : int >= 0, < ky
-            The order of partial derivatives in x and y, respectively.
+            Order of partial derivatives in y.
 
         Returns
         -------
         z : 2D array with shape (len(y), len(x))
             The interpolated values.
+        
         """
 
         x = atleast_1d(x)

Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py	2008-11-03 01:42:04 UTC (rev 4937)
+++ trunk/scipy/interpolate/tests/test_interpolate.py	2008-11-03 01:42:26 UTC (rev 4938)
@@ -1,5 +1,5 @@
 from numpy.testing import *
-from numpy import mgrid, pi, sin, ogrid, poly1d
+from numpy import mgrid, pi, sin, ogrid, poly1d, linspace
 import numpy as np
 
 from scipy.interpolate import interp1d, interp2d, lagrange
@@ -7,15 +7,22 @@
 
 class TestInterp2D(TestCase):
     def test_interp2d(self):
-        y, x = mgrid[0:pi:20j, 0:pi:21j]
-        z = sin(x+y)
+        y, x = mgrid[0:2:20j, 0:pi:21j]
+        z = sin(x+0.5*y)
         I = interp2d(x, y, z)
-        assert_almost_equal(I(1.0, 1.0), sin(2.0), decimal=2)
+        assert_almost_equal(I(1.0, 2.0), sin(2.0), decimal=2)
 
-        v,u = ogrid[0:pi:24j, 0:pi:25j]
-        assert_almost_equal(I(u.ravel(), v.ravel()), sin(v+u), decimal=2)
+        v,u = ogrid[0:2:24j, 0:pi:25j]
+        assert_almost_equal(I(u.ravel(), v.ravel()), sin(u+0.5*v), decimal=2)
+        
+    def test_interp2d_meshgrid_input(self):
+        # Ticket #703
+        x = linspace(0, 2, 16)
+        y = linspace(0, pi, 21)
+        z = sin(x[None,:] + y[:,None]/2.)
+        I = interp2d(x, y, z)
+        assert_almost_equal(I(1.0, 2.0), sin(2.0), decimal=2)
 
-
 class TestInterp1D(TestCase):
 
     def setUp(self):




More information about the Scipy-svn mailing list