[Scipy-svn] r2220 - in trunk/Lib/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Sep 23 18:50:21 EDT 2006
Author: rkern
Date: 2006-09-23 17:50:16 -0500 (Sat, 23 Sep 2006)
New Revision: 2220
Modified:
trunk/Lib/interpolate/interpolate.py
trunk/Lib/interpolate/tests/test_interpolate.py
Log:
Cleaned up interp1d and added tests for it.
Modified: trunk/Lib/interpolate/interpolate.py
===================================================================
--- trunk/Lib/interpolate/interpolate.py 2006-09-23 21:56:21 UTC (rev 2219)
+++ trunk/Lib/interpolate/interpolate.py 2006-09-23 22:50:16 UTC (rev 2220)
@@ -1,14 +1,15 @@
-""" Class for interpolating values
-
- !! Need to find argument for keeping initialize. If it isn't
- !! found, get rid of it!
+""" Classes for interpolating values.
"""
+# !! Need to find argument for keeping initialize. If it isn't
+# !! found, get rid of it!
+
__all__ = ['interp1d', 'interp2d']
from numpy import shape, sometrue, rank, array, transpose, \
swapaxes, searchsorted, clip, take, ones, putmask, less, greater, \
logical_or, atleast_1d, atleast_2d
+import numpy as np
import fitpack
@@ -18,50 +19,81 @@
all = sometrue(all,axis=0)
return all
-class interp2d:
- def __init__(self, x, y, z, kind='linear',
- copy=True, bounds_error=False, fill_value=None):
+class interp2d(object):
+ """ Interpolate over a 2D grid.
+
+ See Also
+ --------
+ bisplrep, bisplev - spline interpolation based on FITPACK
+ BivariateSpline - a more recent wrapper of the FITPACK routines
+ """
+
+ def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
+ fill_value=np.nan):
+ """ Initialize a 2D interpolator.
+
+ Parameters
+ ----------
+ x : 1D array or 2D meshgrid array
+ y : 1D array or 2D meshgrid array
+ Arrays defining the coordinates of a 2D grid.
+ z : 2D array
+ The values of the interpolated function on the grid points.
+ kind : 'linear', 'cubic', 'quintic'
+ The kind of interpolation to use.
+ copy : bool
+ If True, then data is copied, otherwise only a reference is held.
+ bounds_error : bool
+ If True, when interpolated values are requested outside of the
+ domain of the input data, an error is raised.
+ If False, then fill_value is used.
+ fill_value : number
+ If provided, the value to use for points outside of the
+ interpolation domain. Defaults to NaN.
+
+ Raises
+ ------
+ ValueError when inputs are invalid.
+
"""
- Input:
- x,y - 1-d arrays defining 2-d grid (or 2-d meshgrid arrays)
- z - 2-d array of grid values
- kind - interpolation type ('linear', 'cubic', 'quintic')
- copy - if true then data is copied into class, otherwise only a
- reference is held.
- bounds_error - if true, then when out_of_bounds occurs, an error is
- raised otherwise, the output is filled with
- fill_value.
- fill_value - if None, then NaN, otherwise the value to fill in
- outside defined region.
- """
+
self.x = atleast_1d(x).copy()
self.y = atleast_1d(y).copy()
if rank(self.x) > 2 or rank(self.y) > 2:
- raise ValueError, "One of the input arrays is not 1-d or 2-d."
+ raise ValueError("One of the input arrays is not 1-d or 2-d.")
if rank(self.x) == 2:
self.x = self.x[:,0]
if rank(self.y) == 2:
self.y = self.y[0]
self.z = array(z,copy=True)
if rank(z) != 2:
- raise ValueError, "Grid values is not a 2-d array."
+ raise ValueError("Grid values is not a 2-d array.")
try:
kx = ky = {'linear' : 1,
'cubic' : 3,
'quintic' : 5}[kind]
- except:
- raise ValueError, "Unsupported interpolation type."
+ except KeyError:
+ raise ValueError("Unsupported interpolation type.")
self.tck = fitpack.bisplrep(x, y, z, kx=kx, ky=ky, s=0.)
def __call__(self,x,y,dx=0,dy=0):
+ """ Interpolate the function.
+
+ Parameters
+ ----------
+ x : 1D array
+ y : 1D array
+ The points to interpolate.
+ dx : int >= 0, < kx
+ dy : int >= 0, < ky
+ The order of partial derivatives in x and y, respectively.
+
+ Returns
+ -------
+ z : 2D array with shape (len(y), len(x))
+ The interpolated values.
"""
- Input:
- x,y - 1-d arrays defining points to interpolate.
- dx,dy - order of partial derivatives in x and y, respectively.
- 0<=dx<kx, 0<=dy<ky
- Output:
- z - 2-d array of interpolated values of shape (len(y), len(x)).
- """
+
x = atleast_1d(x)
y = atleast_1d(y)
z = fitpack.bisplev(x, y, self.tck, dx, dy)
@@ -71,161 +103,181 @@
z = z[0]
return array(z)
-class interp1d:
- interp_axis = -1 # used to set which is default interpolation
- # axis. DO NOT CHANGE OR CODE WILL BREAK.
- def __init__(self,x,y,kind='linear',axis = -1,
- copy = 1,bounds_error=1, fill_value=None):
- """Initialize a 1d linear interpolation class
+class interp1d(object):
+ """ Interpolate a 1D function.
+
+ See Also
+ --------
+ splrep, splev - spline interpolation based on FITPACK
+ UnivariateSpline - a more recent wrapper of the FITPACK routines
+ """
- Description:
- x and y are arrays of values used to approximate some function f:
+ _interp_axis = -1 # used to set which is default interpolation
+ # axis. DO NOT CHANGE OR CODE WILL BREAK.
+
+ def __init__(self, x, y, kind='linear', axis=-1,
+ copy=True, bounds_error=True, fill_value=np.nan):
+ """ Initialize a 1D linear interpolation class.
+
+ Description
+ -----------
+ x and y are arrays of values used to approximate some function f:
y = f(x)
- This class returns a function whose call method uses linear
- interpolation to find the value of new points.
+ This class returns a function whose call method uses linear
+ interpolation to find the value of new points.
- Inputs:
- x -- a 1d array of monotonically increasing real values.
- x cannot include duplicate values. (otherwise f is
- overspecified)
- y -- an nd array of real values. y's length along the
- interpolation axis must be equal to the length
- of x.
- kind -- specify the kind of interpolation: 'nearest', 'linear',
- 'cubic', or 'spline'
- axis -- specifies the axis of y along which to
- interpolate. Interpolation defaults to the last
- axis of y. (default: -1)
- copy -- If 1, the class makes internal copies of x and y.
- If 0, references to x and y are used. The default
- is to copy. (default: 1)
- bounds_error -- If 1, an error is thrown any time interpolation
- is attempted on a value outside of the range
- of x (where extrapolation is necessary).
- If 0, out of bounds values are assigned the
- NaN (#INF) value. By default, an error is
- raised, although this is prone to change.
- (default: 1)
+ Parameters
+ ----------
+ x : array
+ A 1D array of monotonically increasing real values. x cannot
+ include duplicate values (otherwise f is overspecified)
+ y : array
+ An N-D array of real values. y's length along the interpolation
+ axis must be equal to the length of x.
+ kind : str
+ Specifies the kind of interpolation. At the moment, only 'linear' is
+ implemented.
+ axis : int
+ Specifies the axis of y along which to interpolate. Interpolation
+ defaults to the last axis of y.
+ copy : bool
+ If True, the class makes internal copies of x and y.
+ If False, references to x and y are used.
+ The default is to copy.
+ bounds_error : bool
+ If True, an error is thrown any time interpolation is attempted on
+ a value outside of the range of x (where extrapolation is
+ necessary).
+ If False, out of bounds values are assigned fill_value.
+ By default, an error is raised.
+ fill_value : float
+ If provided, then this value will be used to fill in for requested
+ points outside of the data range.
+ If not provided, then the default is NaN.
"""
- self.axis = axis
+
self.copy = copy
self.bounds_error = bounds_error
- if fill_value is None:
- self.fill_value = array(0.0) / array(0.0)
- else:
- self.fill_value = fill_value
+ self.fill_value = fill_value
if kind != 'linear':
- raise NotImplementedError, "Only linear supported for now. Use "\
- "fitpack routines for other types."
+ raise NotImplementedError("Only linear supported for now. Use "
+ "fitpack routines for other types.")
- # Check that both x and y are at least 1 dimensional.
- if len(shape(x)) == 0 or len(shape(y)) == 0:
- raise ValueError, "x and y arrays must have at least one dimension."
- # make a "view" of the y array that is rotated to the
- # interpolation axis.
- oriented_x = x
- oriented_y = swapaxes(y,self.interp_axis,axis)
- interp_axis = self.interp_axis
- len_x = shape(oriented_x)[interp_axis]
- len_y = shape(oriented_y)[interp_axis]
+ x = array(x, copy=self.copy)
+ y = array(y, copy=self.copy)
+
+ if len(x.shape) != 1:
+ raise ValueError("the x array must have exactly one dimension.")
+ if len(y.shape) == 0:
+ raise ValueError("the y array must have at least one dimension.")
+
+ # Normalize the axis to ensure that it is positive.
+ self.axis = axis % len(y.shape)
+
+ # Make a "view" of the y array that is rotated to the interpolation
+ # axis.
+ oriented_y = y.swapaxes(self._interp_axis, axis)
+ len_x = len(x)
+ len_y = oriented_y.shape[self._interp_axis]
if len_x != len_y:
- raise ValueError, "x and y arrays must be equal in length along "\
- "interpolation axis."
+ raise ValueError("x and y arrays must be equal in length along"
+ "interpolation axis.")
if len_x < 2 or len_y < 2:
- raise ValueError, "x and y arrays must have more than 1 entry"
- self.x = array(oriented_x,copy=self.copy)
- self.y = array(oriented_y,copy=self.copy)
+ raise ValueError("x and y arrays must have more than 1 entry")
+ self.x = x
+ self.y = oriented_y
- def __call__(self,x_new):
- """Find linearly interpolated y_new = <name>(x_new).
+ def __call__(self, x_new):
+ """ Find linearly interpolated y_new = f(x_new).
- Inputs:
- x_new -- New independent variables.
+ Parameters
+ ----------
+ x_new : number or array
+ New independent variable(s).
- Outputs:
- y_new -- Linearly interpolated values corresponding to x_new.
+ Returns
+ -------
+ y_new : number or array
+ Linearly interpolated value(s) corresponding to x_new.
"""
+
# 1. Handle values in x_new that are outside of x. Throw error,
# or return a list of mask array indicating the outofbounds values.
# The behavior is set by the bounds_error variable.
x_new = atleast_1d(x_new)
out_of_bounds = self._check_bounds(x_new)
+
# 2. Find where in the orignal data, the values to interpolate
# would be inserted.
- # Note: If x_new[n] = x[m], then m is returned by searchsorted.
- x_new_indices = searchsorted(self.x,x_new)
+ # Note: If x_new[n] == x[m], then m is returned by searchsorted.
+ x_new_indices = searchsorted(self.x, x_new)
+
# 3. Clip x_new_indices so that they are within the range of
# self.x indices and at least 1. Removes mis-interpolation
# of x_new[n] = x[0]
- x_new_indices = clip(x_new_indices,1,len(self.x)-1).astype(int)
+ x_new_indices = x_new_indices.clip(1, len(self.x)-1).astype(int)
+
# 4. Calculate the slope of regions that each x_new value falls in.
- lo = x_new_indices - 1; hi = x_new_indices
+ lo = x_new_indices - 1
+ hi = x_new_indices
- # !! take(,axis=0) should default to the last axis (IMHO) and remove
- # !! the extra argument.
- x_lo = take(self.x,lo,axis=self.interp_axis)
- x_hi = take(self.x,hi,axis=self.interp_axis);
- y_lo = take(self.y,lo,axis=self.interp_axis)
- y_hi = take(self.y,hi,axis=self.interp_axis);
- slope = (y_hi-y_lo)/(x_hi-x_lo)
+ x_lo = self.x[lo]
+ x_hi = self.x[hi]
+ y_lo = self.y[..., lo]
+ y_hi = self.y[..., hi]
+
+ # Note that the following two expressions rely on the specifics of the
+ # broadcasting semantics.
+ slope = (y_hi-y_lo) / (x_hi-x_lo)
+
# 5. Calculate the actual value for each entry in x_new.
y_new = slope*(x_new-x_lo) + y_lo
- # 6. Fill any values that were out of bounds with NaN
- # !! Need to think about how to do this efficiently for
- # !! mutli-dimensional Cases.
- yshape = y_new.shape
- y_new = y_new.ravel()
- new_shape = list(yshape)
- new_shape[self.interp_axis] = 1
- sec_shape = [1]*len(new_shape)
- sec_shape[self.interp_axis] = len(out_of_bounds)
- out_of_bounds.shape = sec_shape
- new_out = ones(new_shape)*out_of_bounds
- putmask(y_new, new_out.ravel(), self.fill_value)
- y_new.shape = yshape
- # Rotate the values of y_new back so that they coorespond to the
- # correct x_new values.
- result = swapaxes(y_new,self.interp_axis,self.axis)
- try:
- len(x_new)
- return result
- except TypeError:
- return result[0]
+
+ # 6. Fill any values that were out of bounds with fill_value.
+ y_new[..., out_of_bounds] = self.fill_value
+
+ # Rotate the values of y_new back so that they correspond to the
+ # correct x_new values. For N-D x_new, take the last N axes from y_new
+ # and insert them where self.axis was in the list of axes.
+ nx = len(x_new.shape)
+ ny = len(y_new.shape)
+ axes = range(ny - nx)
+ axes[self.axis:self.axis] = range(ny - nx, ny)
+ result = y_new.transpose(axes)
+
return result
- def _check_bounds(self,x_new):
- # If self.bounds_error = 1, we raise an error if any x_new values
+ def _check_bounds(self, x_new):
+ """ Check the inputs for being in the bounds of the interpolated data.
+
+ Parameters
+ ----------
+ x_new : array
+
+ Returns
+ -------
+ out_of_bounds : bool array
+ The mask on x_new of values that are out of the bounds.
+ """
+
+ # If self.bounds_error is True, we raise an error if any x_new values
# fall outside the range of x. Otherwise, we return an array indicating
# which values are outside the boundary region.
- # !! Needs some work for multi-dimensional x !!
- below_bounds = less(x_new,self.x[0])
- above_bounds = greater(x_new,self.x[-1])
- # Note: sometrue has been redefined to handle length 0 arrays
+ below_bounds = x_new < self.x[0]
+ above_bounds = x_new > self.x[-1]
+
# !! Could provide more information about which values are out of bounds
- if self.bounds_error and sometrue(below_bounds,axis=0):
- raise ValueError, " A value in x_new is below the"\
- " interpolation range."
- if self.bounds_error and sometrue(above_bounds,axis=0):
- raise ValueError, " A value in x_new is above the"\
- " interpolation range."
- # !! Should we emit a warning if some values are out of bounds.
+ if self.bounds_error and below_bounds.any():
+ raise ValueError("A value in x_new is below the interpolation "
+ "range.")
+ if self.bounds_error and above_bounds.any():
+ raise ValueError("A value in x_new is above the interpolation "
+ "range.")
+
+ # !! Should we emit a warning if some values are out of bounds?
# !! matlab does not.
- out_of_bounds = logical_or(below_bounds,above_bounds)
+ out_of_bounds = logical_or(below_bounds, above_bounds)
return out_of_bounds
- def model_error(self,x_new,y_new):
- # How well do x_new,yy points fit the model?
- # Return an array of error values.
- pass
-
-
-#assumes module test_xxx is in python path
-#def test():
-# test_module = 'test_' + __name__ # __name__ is name of this module
-# test_string = 'import %s;reload(%s);%s.test()' % ((test_module,)*3)
-# exec(test_string)
-
-#if __name__ == '__main__':
-# test()
Modified: trunk/Lib/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/Lib/interpolate/tests/test_interpolate.py 2006-09-23 21:56:21 UTC (rev 2219)
+++ trunk/Lib/interpolate/tests/test_interpolate.py 2006-09-23 22:50:16 UTC (rev 2220)
@@ -1,9 +1,12 @@
from numpy.testing import *
from numpy import mgrid, pi, sin, ogrid
+import numpy as np
+
set_package_path()
from interpolate import interp1d, interp2d
restore_path()
+
class test_interp2d(ScipyTestCase):
def test_interp2d(self):
y, x = mgrid[0:pi:20j, 0:pi:21j]
@@ -14,5 +17,182 @@
v,u = ogrid[0:pi:24j, 0:pi:25j]
assert_almost_equal(I(u.ravel(), v.ravel()), sin(v+u), decimal=2)
+
+class test_interp1d(ScipyTestCase):
+
+ def setUp(self):
+ self.x10 = np.arange(10.)
+ self.y10 = np.arange(10.)
+ self.x25 = self.x10.reshape((2,5))
+ self.x2 = np.arange(2.)
+ self.y2 = np.arange(2.)
+ self.x1 = np.array([0.])
+ self.y1 = np.array([0.])
+
+ self.y210 = np.arange(20.).reshape((2, 10))
+ self.y102 = np.arange(20.).reshape((10, 2))
+
+ self.fill_value = -100.0
+
+ def test_validation(self):
+ """ Make sure that appropriate exceptions are raised when invalid values
+ are given to the constructor.
+ """
+
+ # Only kind='linear' is implemented.
+ self.assertRaises(NotImplementedError, interp1d, self.x10, self.y10, kind='cubic')
+ interp1d(self.x10, self.y10, kind='linear')
+
+ # x array must be 1D.
+ self.assertRaises(ValueError, interp1d, self.x25, self.y10)
+
+ # y array cannot be a scalar.
+ self.assertRaises(ValueError, interp1d, self.x10, np.array(0))
+
+ # Check for x and y arrays having the same length.
+ self.assertRaises(ValueError, interp1d, self.x10, self.y2)
+ self.assertRaises(ValueError, interp1d, self.x2, self.y10)
+ self.assertRaises(ValueError, interp1d, self.x10, self.y102)
+ interp1d(self.x10, self.y210)
+ interp1d(self.x10, self.y102, axis=0)
+
+ # Check for x and y having at least 2 elements.
+ self.assertRaises(ValueError, interp1d, self.x1, self.y10)
+ self.assertRaises(ValueError, interp1d, self.x10, self.y1)
+ self.assertRaises(ValueError, interp1d, self.x1, self.y1)
+
+
+ def test_init(self):
+ """ Check that the attributes are initialized appropriately by the
+ constructor.
+ """
+
+ self.assert_(interp1d(self.x10, self.y10).copy)
+ self.assert_(not interp1d(self.x10, self.y10, copy=False).copy)
+ self.assert_(interp1d(self.x10, self.y10).bounds_error)
+ self.assert_(not interp1d(self.x10, self.y10, bounds_error=False).bounds_error)
+ self.assert_(np.isnan(interp1d(self.x10, self.y10).fill_value))
+ self.assertEqual(
+ interp1d(self.x10, self.y10, fill_value=3.0).fill_value,
+ 3.0,
+ )
+ self.assertEqual(
+ interp1d(self.x10, self.y10).axis,
+ 0,
+ )
+ self.assertEqual(
+ interp1d(self.x10, self.y210).axis,
+ 1,
+ )
+ self.assertEqual(
+ interp1d(self.x10, self.y102, axis=0).axis,
+ 0,
+ )
+ assert_array_equal(
+ interp1d(self.x10, self.y10).x,
+ self.x10,
+ )
+ assert_array_equal(
+ interp1d(self.x10, self.y10).y,
+ self.y10,
+ )
+ assert_array_equal(
+ interp1d(self.x10, self.y210).y,
+ self.y210,
+ )
+
+
+ def test_linear(self):
+ """ Check the actual implementation of linear interpolation.
+ """
+
+ interp10 = interp1d(self.x10, self.y10)
+ assert_array_almost_equal(
+ interp10(self.x10),
+ self.y10,
+ )
+ assert_array_almost_equal(
+ interp10(1.2),
+ np.array([1.2]),
+ )
+ assert_array_almost_equal(
+ interp10([2.4, 5.6, 6.0]),
+ np.array([2.4, 5.6, 6.0]),
+ )
+
+
+ def test_bounds(self):
+ """ Test that our handling of out-of-bounds input is correct.
+ """
+
+ extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
+ bounds_error=False)
+ assert_array_equal(
+ extrap10(11.2),
+ np.array([self.fill_value]),
+ )
+ assert_array_equal(
+ extrap10(-3.4),
+ np.array([self.fill_value]),
+ )
+ assert_array_equal(
+ extrap10._check_bounds(np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
+ np.array([True, False, False, False, True]),
+ )
+
+ raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True)
+ self.assertRaises(ValueError, raises_bounds_error, -1.0)
+ self.assertRaises(ValueError, raises_bounds_error, 11.0)
+ raises_bounds_error([0.0, 5.0, 9.0])
+
+
+ def test_nd(self):
+ """ Check the behavior when the inputs and outputs are multidimensional.
+ """
+
+ # Multidimensional input.
+ interp10 = interp1d(self.x10, self.y10)
+ assert_array_almost_equal(
+ interp10(np.array([[3.4, 5.6], [2.4, 7.8]])),
+ np.array([[3.4, 5.6], [2.4, 7.8]]),
+ )
+
+ # Multidimensional outputs.
+ interp210 = interp1d(self.x10, self.y210)
+ assert_array_almost_equal(
+ interp210(1.5),
+ np.array([[1.5], [11.5]]),
+ )
+ assert_array_almost_equal(
+ interp210(np.array([1.5, 2.4])),
+ np.array([[1.5, 2.4],
+ [11.5, 12.4]]),
+ )
+
+ interp102 = interp1d(self.x10, self.y102, axis=0)
+ assert_array_almost_equal(
+ interp102(1.5),
+ np.array([[3.0, 4.0]]),
+ )
+ assert_array_almost_equal(
+ interp102(np.array([1.5, 2.4])),
+ np.array([[3.0, 4.0],
+ [4.8, 5.8]]),
+ )
+
+ # Both at the same time!
+ x_new = np.array([[3.4, 5.6], [2.4, 7.8]])
+ assert_array_almost_equal(
+ interp210(x_new),
+ np.array([[[3.4, 5.6], [2.4, 7.8]],
+ [[13.4, 15.6], [12.4, 17.8]]]),
+ )
+ assert_array_almost_equal(
+ interp102(x_new),
+ np.array([[[6.8, 7.8], [11.2, 12.2]],
+ [[4.8, 5.8], [15.6, 16.6]]]),
+ )
+
+
if __name__ == "__main__":
ScipyTest().run()
More information about the Scipy-svn
mailing list