[Scipy-svn] r2220 - in trunk/Lib/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Sep 23 18:50:21 EDT 2006


Author: rkern
Date: 2006-09-23 17:50:16 -0500 (Sat, 23 Sep 2006)
New Revision: 2220

Modified:
   trunk/Lib/interpolate/interpolate.py
   trunk/Lib/interpolate/tests/test_interpolate.py
Log:
Cleaned up interp1d and interp2d, and added tests for interp1d.

Modified: trunk/Lib/interpolate/interpolate.py
===================================================================
--- trunk/Lib/interpolate/interpolate.py	2006-09-23 21:56:21 UTC (rev 2219)
+++ trunk/Lib/interpolate/interpolate.py	2006-09-23 22:50:16 UTC (rev 2220)
@@ -1,14 +1,15 @@
-""" Class for interpolating values
-
-    !! Need to find argument for keeping initialize.  If it isn't
-    !! found, get rid of it!
+""" Classes for interpolating values.
 """
 
+# !! Need to find argument for keeping initialize.  If it isn't
+# !! found, get rid of it!
+
 __all__ = ['interp1d', 'interp2d']
 
 from numpy import shape, sometrue, rank, array, transpose, \
      swapaxes, searchsorted, clip, take, ones, putmask, less, greater, \
      logical_or, atleast_1d, atleast_2d
+import numpy as np
 
 import fitpack
 
@@ -18,50 +19,81 @@
         all = sometrue(all,axis=0)
     return all
 
-class interp2d:
-    def __init__(self, x, y, z, kind='linear',
-                 copy=True, bounds_error=False, fill_value=None):
+class interp2d(object):
+    """ Interpolate over a 2D grid.
+
+    See Also
+    --------
+    bisplrep, bisplev - spline interpolation based on FITPACK
+    BivariateSpline - a more recent wrapper of the FITPACK routines
+    """
+
+    def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
+        fill_value=np.nan):
+        """ Initialize a 2D interpolator.
+
+        Parameters
+        ----------
+        x : 1D array or 2D meshgrid array
+        y : 1D array or 2D meshgrid array
+            Arrays defining the coordinates of a 2D grid.
+        z : 2D array
+            The values of the interpolated function on the grid points.
+        kind : 'linear', 'cubic', 'quintic'
+            The kind of interpolation to use.
+        copy : bool
+            Currently ignored: x, y and z are always copied.
+        bounds_error : bool
+            Currently ignored; the value is not stored or checked. If it were
+            honored, True would raise an error for interpolated values
+            requested outside the input domain, and False would use fill_value.
+        fill_value : number
+            Currently ignored; the value is not stored. Intended as the value
+            for points outside of the interpolation domain. Defaults to NaN.
+
+        Raises
+        ------
+        ValueError when inputs are invalid.
+
         """
-        Input:
-          x,y  - 1-d arrays defining 2-d grid (or 2-d meshgrid arrays)
-          z    - 2-d array of grid values
-          kind - interpolation type ('linear', 'cubic', 'quintic')
-          copy - if true then data is copied into class, otherwise only a
-                   reference is held.
-          bounds_error - if true, then when out_of_bounds occurs, an error is
-                          raised otherwise, the output is filled with
-                          fill_value.
-          fill_value - if None, then NaN, otherwise the value to fill in
-                        outside defined region.
-        """
+
         self.x = atleast_1d(x).copy()
         self.y = atleast_1d(y).copy()
         if rank(self.x) > 2 or rank(self.y) > 2:
-            raise ValueError, "One of the input arrays is not 1-d or 2-d."
+            raise ValueError("One of the input arrays is not 1-d or 2-d.")
         if rank(self.x) == 2:
             self.x = self.x[:,0]
         if rank(self.y) == 2:
             self.y = self.y[0]
         self.z = array(z,copy=True)
         if rank(z) != 2:
-            raise ValueError, "Grid values is not a 2-d array."
+            raise ValueError("Grid values is not a 2-d array.")
         try:
             kx = ky = {'linear' : 1,
                        'cubic' : 3,
                        'quintic' : 5}[kind]
-        except:
-            raise ValueError, "Unsupported interpolation type."
+        except KeyError:
+            raise ValueError("Unsupported interpolation type.")
         self.tck = fitpack.bisplrep(x, y, z, kx=kx, ky=ky, s=0.)
 
     def __call__(self,x,y,dx=0,dy=0):
+        """ Interpolate the function.
+
+        Parameters
+        ----------
+        x : 1D array
+        y : 1D array
+            The points to interpolate.
+        dx : int >= 0, < kx
+        dy : int >= 0, < ky
+            The order of partial derivatives in x and y, respectively.
+
+        Returns
+        -------
+        z : 2D array with shape (len(y), len(x))
+            The interpolated values.
         """
-        Input:
-          x,y   - 1-d arrays defining points to interpolate.
-          dx,dy - order of partial derivatives in x and y, respectively.
-                  0<=dx<kx, 0<=dy<ky
-        Output:
-          z     - 2-d array of interpolated values of shape (len(y), len(x)).
-        """
+
         x = atleast_1d(x)
         y = atleast_1d(y)
         z = fitpack.bisplev(x, y, self.tck, dx, dy)
@@ -71,161 +103,181 @@
             z = z[0]
         return array(z)
 
-class interp1d:
-    interp_axis = -1 # used to set which is default interpolation
-                     # axis.  DO NOT CHANGE OR CODE WILL BREAK.
 
-    def __init__(self,x,y,kind='linear',axis = -1,
-                 copy = 1,bounds_error=1, fill_value=None):
-        """Initialize a 1d linear interpolation class
+class interp1d(object):
+    """ Interpolate a 1D function.
+    
+    See Also
+    --------
+    splrep, splev - spline interpolation based on FITPACK
+    UnivariateSpline - a more recent wrapper of the FITPACK routines
+    """
 
-        Description:
-          x and y are arrays of values used to approximate some function f:
+    _interp_axis = -1 # used to set which is default interpolation
+                      # axis.  DO NOT CHANGE OR CODE WILL BREAK.
+
+    def __init__(self, x, y, kind='linear', axis=-1,
+                 copy=True, bounds_error=True, fill_value=np.nan):
+        """ Initialize a 1D linear interpolation class.
+
+        Description
+        -----------
+        x and y are arrays of values used to approximate some function f:
             y = f(x)
-          This class returns a function whose call method uses linear
-          interpolation to find the value of new points.
+        This class returns a function whose call method uses linear
+        interpolation to find the value of new points.
 
-        Inputs:
-            x -- a 1d array of monotonically increasing real values.
-                 x cannot include duplicate values. (otherwise f is
-                 overspecified)
-            y -- an nd array of real values.  y's length along the
-                 interpolation axis must be equal to the length
-                 of x.
-            kind -- specify the kind of interpolation: 'nearest', 'linear',
-                    'cubic', or 'spline'
-            axis -- specifies the axis of y along which to
-                    interpolate. Interpolation defaults to the last
-                    axis of y.  (default: -1)
-            copy -- If 1, the class makes internal copies of x and y.
-                    If 0, references to x and y are used. The default
-                    is to copy. (default: 1)
-            bounds_error -- If 1, an error is thrown any time interpolation
-                            is attempted on a value outside of the range
-                            of x (where extrapolation is necessary).
-                            If 0, out of bounds values are assigned the
-                            NaN (#INF) value.  By default, an error is
-                            raised, although this is prone to change.
-                            (default: 1)
+        Parameters
+        ----------
+        x : array
+            A 1D array of monotonically increasing real values.  x cannot
+            include duplicate values (otherwise f is overspecified).
+        y : array
+            An N-D array of real values.  y's length along the interpolation
+            axis must be equal to the length of x.
+        kind : str
+            Specifies the kind of interpolation. Only 'linear' is implemented;
+            any other value raises NotImplementedError.
+        axis : int
+            Specifies the axis of y along which to interpolate. Interpolation
+            defaults to the last axis of y.
+        copy : bool
+            If True, the class makes internal copies of x and y.
+            If False, references to x and y are used.
+            The default is to copy.
+        bounds_error : bool
+            If True, a ValueError is raised any time interpolation is attempted
+            on a value outside of the range of x (where extrapolation is
+            necessary).
+            If False, out of bounds values are assigned fill_value.
+            By default, an error is raised.
+        fill_value : float
+            If provided, then this value will be used to fill in for requested
+            points outside of the data range.
+            If not provided, then the default is NaN.
         """
-        self.axis = axis
+
         self.copy = copy
         self.bounds_error = bounds_error
-        if fill_value is None:
-            self.fill_value = array(0.0) / array(0.0)
-        else:
-            self.fill_value = fill_value
+        self.fill_value = fill_value
 
         if kind != 'linear':
-            raise NotImplementedError, "Only linear supported for now. Use "\
-                                       "fitpack routines for other types."
+            raise NotImplementedError("Only linear supported for now. Use "
+                "fitpack routines for other types.")
 
-        # Check that both x and y are at least 1 dimensional.
-        if len(shape(x)) == 0 or len(shape(y)) == 0:
-            raise ValueError, "x and y arrays must have at least one dimension."
-        # make a "view" of the y array that is rotated to the
-        # interpolation axis.
-        oriented_x = x
-        oriented_y = swapaxes(y,self.interp_axis,axis)
-        interp_axis = self.interp_axis
-        len_x = shape(oriented_x)[interp_axis]
-        len_y = shape(oriented_y)[interp_axis]
+        x = array(x, copy=self.copy)
+        y = array(y, copy=self.copy)
+
+        if len(x.shape) != 1:
+            raise ValueError("the x array must have exactly one dimension.")
+        if len(y.shape) == 0:
+            raise ValueError("the y array must have at least one dimension.")
+
+        # Normalize the axis so that it is non-negative.
+        self.axis = axis % len(y.shape)
+
+        # Make a "view" of the y array that is rotated to the interpolation
+        # axis; the raw axis is passed so an out-of-range value is rejected.
+        oriented_y = y.swapaxes(self._interp_axis, axis)
+        len_x = len(x)
+        len_y = oriented_y.shape[self._interp_axis]
         if len_x != len_y:
-            raise ValueError, "x and y arrays must be equal in length along "\
-                              "interpolation axis."
+            raise ValueError("x and y arrays must be equal in length along "
+                "interpolation axis.")
         if len_x < 2 or len_y < 2:
-            raise ValueError, "x and y arrays must have more than 1 entry"
-        self.x = array(oriented_x,copy=self.copy)
-        self.y = array(oriented_y,copy=self.copy)
+            raise ValueError("x and y arrays must have more than 1 entry")
+        self.x = x
+        self.y = oriented_y
 
-    def __call__(self,x_new):
-        """Find linearly interpolated y_new = <name>(x_new).
+    def __call__(self, x_new):
+        """ Find linearly interpolated y_new = f(x_new).
 
-        Inputs:
-          x_new -- New independent variables.
+        Parameters
+        ----------
+        x_new : number or array
+            New independent variable(s).
 
-        Outputs:
-          y_new -- Linearly interpolated values corresponding to x_new.
+        Returns
+        -------
+        y_new : number or array
+            Linearly interpolated value(s) corresponding to x_new.
         """
+
         # 1. Handle values in x_new that are outside of x.  Throw error,
         #    or return a list of mask array indicating the outofbounds values.
         #    The behavior is set by the bounds_error variable.
         x_new = atleast_1d(x_new)
         out_of_bounds = self._check_bounds(x_new)
+
         # 2. Find where in the orignal data, the values to interpolate
         #    would be inserted.
-        #    Note: If x_new[n] = x[m], then m is returned by searchsorted.
-        x_new_indices = searchsorted(self.x,x_new)
+        #    Note: If x_new[n] == x[m], then m is returned by searchsorted.
+        x_new_indices = searchsorted(self.x, x_new)
+
         # 3. Clip x_new_indices so that they are within the range of
         #    self.x indices and at least 1.  Removes mis-interpolation
         #    of x_new[n] = x[0]
-        x_new_indices = clip(x_new_indices,1,len(self.x)-1).astype(int)
+        x_new_indices = x_new_indices.clip(1, len(self.x)-1).astype(int)
+
         # 4. Calculate the slope of regions that each x_new value falls in.
-        lo = x_new_indices - 1; hi = x_new_indices
+        lo = x_new_indices - 1
+        hi = x_new_indices
 
-        # !! take(,axis=0) should default to the last axis (IMHO) and remove
-        # !! the extra argument.
-        x_lo = take(self.x,lo,axis=self.interp_axis)
-        x_hi = take(self.x,hi,axis=self.interp_axis);
-        y_lo = take(self.y,lo,axis=self.interp_axis)
-        y_hi = take(self.y,hi,axis=self.interp_axis);
-        slope = (y_hi-y_lo)/(x_hi-x_lo)
+        x_lo = self.x[lo]
+        x_hi = self.x[hi]
+        y_lo = self.y[..., lo]
+        y_hi = self.y[..., hi]
+
+        # Note that the following two expressions rely on the specifics of the
+        # broadcasting semantics.
+        slope = (y_hi-y_lo) / (x_hi-x_lo)
+
         # 5. Calculate the actual value for each entry in x_new.
         y_new = slope*(x_new-x_lo) + y_lo
-        # 6. Fill any values that were out of bounds with NaN
-        # !! Need to think about how to do this efficiently for
-        # !! mutli-dimensional Cases.
-        yshape = y_new.shape
-        y_new = y_new.ravel()
-        new_shape = list(yshape)
-        new_shape[self.interp_axis] = 1
-        sec_shape = [1]*len(new_shape)
-        sec_shape[self.interp_axis] = len(out_of_bounds)
-        out_of_bounds.shape = sec_shape
-        new_out = ones(new_shape)*out_of_bounds
-        putmask(y_new, new_out.ravel(), self.fill_value)
-        y_new.shape = yshape
-        # Rotate the values of y_new back so that they coorespond to the
-        # correct x_new values.
-        result = swapaxes(y_new,self.interp_axis,self.axis)
-        try:
-            len(x_new)
-            return result
-        except TypeError:
-            return result[0]
+
+        # 6. Fill any values that were out of bounds with fill_value.
+        y_new[..., out_of_bounds] = self.fill_value
+
+        # Rotate the values of y_new back so that they correspond to the
+        # correct x_new values. For N-D x_new, take the last N axes from y_new
+        # and insert them where self.axis was in the list of axes.
+        nx = len(x_new.shape)
+        ny = len(y_new.shape)
+        axes = range(ny - nx)
+        axes[self.axis:self.axis] = range(ny - nx, ny)
+        result = y_new.transpose(axes)
+
         return result
 
-    def _check_bounds(self,x_new):
-        # If self.bounds_error = 1, we raise an error if any x_new values
+    def _check_bounds(self, x_new):
+        """ Check the inputs for being in the bounds of the interpolated data.
+
+        Parameters
+        ----------
+        x_new : array
+
+        Returns
+        -------
+        out_of_bounds : bool array
+            The mask on x_new of values that are out of the bounds.
+        """
+
+        # If self.bounds_error is True, we raise an error if any x_new values
         # fall outside the range of x.  Otherwise, we return an array indicating
         # which values are outside the boundary region.
-        # !! Needs some work for multi-dimensional x !!
-        below_bounds = less(x_new,self.x[0])
-        above_bounds = greater(x_new,self.x[-1])
-        #  Note: sometrue has been redefined to handle length 0 arrays
+        below_bounds = x_new < self.x[0]
+        above_bounds = x_new > self.x[-1]
+
         # !! Could provide more information about which values are out of bounds
-        if self.bounds_error and sometrue(below_bounds,axis=0):
-            raise ValueError, " A value in x_new is below the"\
-                              " interpolation range."
-        if self.bounds_error and sometrue(above_bounds,axis=0):
-            raise ValueError, " A value in x_new is above the"\
-                              " interpolation range."
-        # !! Should we emit a warning if some values are out of bounds.
+        if self.bounds_error and below_bounds.any():
+            raise ValueError("A value in x_new is below the interpolation "
+                "range.")
+        if self.bounds_error and above_bounds.any():
+            raise ValueError("A value in x_new is above the interpolation "
+                "range.")
+
+        # !! Should we emit a warning if some values are out of bounds?
         # !! matlab does not.
-        out_of_bounds = logical_or(below_bounds,above_bounds)
+        out_of_bounds = logical_or(below_bounds, above_bounds)
         return out_of_bounds
 
-    def model_error(self,x_new,y_new):
-        # How well do x_new,yy points fit the model?
-        # Return an array of error values.
-        pass
-
-
-#assumes module test_xxx is in python path
-#def test():
-#    test_module = 'test_' + __name__ # __name__ is name of this module
-#    test_string = 'import %s;reload(%s);%s.test()' % ((test_module,)*3)
-#    exec(test_string)
-
-#if __name__ == '__main__':
-#    test()

Modified: trunk/Lib/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/Lib/interpolate/tests/test_interpolate.py	2006-09-23 21:56:21 UTC (rev 2219)
+++ trunk/Lib/interpolate/tests/test_interpolate.py	2006-09-23 22:50:16 UTC (rev 2220)
@@ -1,9 +1,12 @@
 from numpy.testing import *
 from numpy import mgrid, pi, sin, ogrid
+import numpy as np
+
 set_package_path()
 from interpolate import interp1d, interp2d
 restore_path()
 
+
 class test_interp2d(ScipyTestCase):
     def test_interp2d(self):
         y, x = mgrid[0:pi:20j, 0:pi:21j]
@@ -14,5 +17,182 @@
 	v,u = ogrid[0:pi:24j, 0:pi:25j]
         assert_almost_equal(I(u.ravel(), v.ravel()), sin(v+u), decimal=2)
 
+
+class test_interp1d(ScipyTestCase):
+
+    def setUp(self):
+        self.x10 = np.arange(10.)
+        self.y10 = np.arange(10.)
+        self.x25 = self.x10.reshape((2,5))
+        self.x2 = np.arange(2.)
+        self.y2 = np.arange(2.)
+        self.x1 = np.array([0.])
+        self.y1 = np.array([0.])
+
+        self.y210 = np.arange(20.).reshape((2, 10))
+        self.y102 = np.arange(20.).reshape((10, 2))
+
+        self.fill_value = -100.0
+
+    def test_validation(self):
+        """ Make sure that appropriate exceptions are raised when invalid values
+        are given to the constructor.
+        """
+
+        # Only kind='linear' is implemented.
+        self.assertRaises(NotImplementedError, interp1d, self.x10, self.y10, kind='cubic')
+        interp1d(self.x10, self.y10, kind='linear')
+
+        # x array must be 1D.
+        self.assertRaises(ValueError, interp1d, self.x25, self.y10)
+
+        # y array cannot be a scalar.
+        self.assertRaises(ValueError, interp1d, self.x10, np.array(0))
+
+        # Check for x and y arrays having the same length.
+        self.assertRaises(ValueError, interp1d, self.x10, self.y2)
+        self.assertRaises(ValueError, interp1d, self.x2, self.y10)
+        self.assertRaises(ValueError, interp1d, self.x10, self.y102)
+        interp1d(self.x10, self.y210)
+        interp1d(self.x10, self.y102, axis=0)
+
+        # Check that x and y must have at least 2 elements.
+        self.assertRaises(ValueError, interp1d, self.x1, self.y10)
+        self.assertRaises(ValueError, interp1d, self.x10, self.y1)
+        self.assertRaises(ValueError, interp1d, self.x1, self.y1)
+
+
+    def test_init(self):
+        """ Check that the attributes are initialized appropriately by the
+        constructor.
+        """
+
+        self.assert_(interp1d(self.x10, self.y10).copy)
+        self.assert_(not interp1d(self.x10, self.y10, copy=False).copy)
+        self.assert_(interp1d(self.x10, self.y10).bounds_error)
+        self.assert_(not interp1d(self.x10, self.y10, bounds_error=False).bounds_error)
+        self.assert_(np.isnan(interp1d(self.x10, self.y10).fill_value))
+        self.assertEqual(
+            interp1d(self.x10, self.y10, fill_value=3.0).fill_value, 
+            3.0,
+        )
+        self.assertEqual(
+            interp1d(self.x10, self.y10).axis,
+            0,
+        )
+        self.assertEqual(
+            interp1d(self.x10, self.y210).axis,
+            1,
+        )
+        self.assertEqual(
+            interp1d(self.x10, self.y102, axis=0).axis,
+            0,
+        )
+        assert_array_equal(
+            interp1d(self.x10, self.y10).x,
+            self.x10,
+        )
+        assert_array_equal(
+            interp1d(self.x10, self.y10).y,
+            self.y10,
+        )
+        assert_array_equal(
+            interp1d(self.x10, self.y210).y,
+            self.y210,
+        )
+
+
+    def test_linear(self):
+        """ Check the actual implementation of linear interpolation.
+        """
+
+        interp10 = interp1d(self.x10, self.y10)
+        assert_array_almost_equal(
+            interp10(self.x10),
+            self.y10,
+        )
+        assert_array_almost_equal(
+            interp10(1.2),
+            np.array([1.2]),
+        )
+        assert_array_almost_equal(
+            interp10([2.4, 5.6, 6.0]),
+            np.array([2.4, 5.6, 6.0]),
+        )
+
+
+    def test_bounds(self):
+        """ Test that our handling of out-of-bounds input is correct.
+        """
+
+        extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
+            bounds_error=False)
+        assert_array_equal(
+            extrap10(11.2),
+            np.array([self.fill_value]),
+        )
+        assert_array_equal(
+            extrap10(-3.4),
+            np.array([self.fill_value]),
+        )
+        assert_array_equal(
+            extrap10._check_bounds(np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
+            np.array([True, False, False, False, True]),
+        )
+
+        raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True)
+        self.assertRaises(ValueError, raises_bounds_error, -1.0)
+        self.assertRaises(ValueError, raises_bounds_error, 11.0)
+        raises_bounds_error([0.0, 5.0, 9.0])
+
+
+    def test_nd(self):
+        """ Check the behavior when the inputs and outputs are multidimensional.
+        """
+
+        # Multidimensional input.
+        interp10 = interp1d(self.x10, self.y10)
+        assert_array_almost_equal(
+            interp10(np.array([[3.4, 5.6], [2.4, 7.8]])),
+            np.array([[3.4, 5.6], [2.4, 7.8]]),
+        )
+
+        # Multidimensional outputs.
+        interp210 = interp1d(self.x10, self.y210)
+        assert_array_almost_equal(
+            interp210(1.5),
+            np.array([[1.5], [11.5]]),
+        )
+        assert_array_almost_equal(
+            interp210(np.array([1.5, 2.4])),
+            np.array([[1.5, 2.4],
+                      [11.5, 12.4]]),
+        )
+        
+        interp102 = interp1d(self.x10, self.y102, axis=0)
+        assert_array_almost_equal(
+            interp102(1.5),
+            np.array([[3.0, 4.0]]),
+        )
+        assert_array_almost_equal(
+            interp102(np.array([1.5, 2.4])),
+            np.array([[3.0, 4.0],
+                      [4.8, 5.8]]),
+        )
+
+        # Both at the same time!
+        x_new = np.array([[3.4, 5.6], [2.4, 7.8]])
+        assert_array_almost_equal(
+            interp210(x_new),
+            np.array([[[3.4, 5.6], [2.4, 7.8]],
+                      [[13.4, 15.6], [12.4, 17.8]]]),
+        )
+        assert_array_almost_equal(
+            interp102(x_new),
+            np.array([[[6.8, 7.8], [11.2, 12.2]],
+                      [[4.8, 5.8], [15.6, 16.6]]]),
+        )
+
+
 if __name__ == "__main__":
     ScipyTest().run()




More information about the Scipy-svn mailing list