[Scipy-svn] r3000 - trunk/Lib/interpolate

scipy-svn at scipy.org scipy-svn at scipy.org
Mon May 14 22:56:35 EDT 2007


Author: oliphant
Date: 2007-05-14 21:56:16 -0500 (Mon, 14 May 2007)
New Revision: 3000

Modified:
   trunk/Lib/interpolate/interpolate.py
Log:
Change spline names to allow for adding quadratic, quartic, and quintic splines.

Modified: trunk/Lib/interpolate/interpolate.py
===================================================================
--- trunk/Lib/interpolate/interpolate.py	2007-05-15 00:17:29 UTC (rev 2999)
+++ trunk/Lib/interpolate/interpolate.py	2007-05-15 02:56:16 UTC (rev 3000)
@@ -1,12 +1,9 @@
 """ Classes for interpolating values.
 """
 
-# !! Need to find argument for keeping initialize.  If it isn't
-# !! found, get rid of it!
+__all__ = ['interp1d', 'interp2d', 'spline3', 'sp3eval', 'sp3rep', 'sp3topp',
+           'ppform']
 
-__all__ = ['interp1d', 'interp2d', 'cspline', 'cspeval', 'csprep', 'csp2pp',
-           'ppval']
-
 from numpy import shape, sometrue, rank, array, transpose, \
      swapaxes, searchsorted, clip, take, ones, putmask, less, greater, \
      logical_or, atleast_1d, atleast_2d, meshgrid, ravel
@@ -20,6 +17,9 @@
         all = sometrue(all,axis=0)
     return all
 
+# !! Need to find argument for keeping initialize.  If it isn't
+# !! found, get rid of it!
+
 class interp2d(object):
     """ Interpolate over a 2D grid.
 
@@ -151,8 +151,8 @@
             An N-D array of real values.  y's length along the interpolation
             axis must be equal to the length of x.
         kind : str
-            Specifies the kind of interpolation. At the moment, only 'linear' is
-            implemented.
+            Specifies the kind of interpolation. At the moment,
+            only 'linear' and 'cubic' are implemented.
         axis : int
             Specifies the axis of y along which to interpolate. Interpolation
             defaults to the last axis of y.
@@ -176,54 +176,48 @@
         self.bounds_error = bounds_error
         self.fill_value = fill_value
 
-        if kind != 'linear':
-            raise NotImplementedError("Only linear supported for now. Use "
-                "fitpack routines for other types.")
-
+        if kind not in ['linear', 'cubic']:
+            raise NotImplementedError("Only linear and cubic supported for " \
+                                      "now. Use fitpack routines "\
+                                      "for other types.")
         x = array(x, copy=self.copy)
         y = array(y, copy=self.copy)
 
-        if len(x.shape) != 1:
+        if x.ndim != 1:
             raise ValueError("the x array must have exactly one dimension.")
-        if len(y.shape) == 0:
+        if y.ndim == 0:
             raise ValueError("the y array must have at least one dimension.")
 
         # Normalize the axis to ensure that it is positive.
         self.axis = axis % len(y.shape)
-
-        # Make a "view" of the y array that is rotated to the interpolation
-        # axis.
-        oriented_y = y.swapaxes(self._interp_axis, axis)
+        self._kind = kind
+        
+        if kind == 'linear':
+            # Make a "view" of the y array that is rotated to the interpolation
+            # axis.
+            oriented_y = y.swapaxes(self._interp_axis, axis)
+            minval = 2
+            len_y = oriented_y.shape[self._interp_axis]
+            self._call = self._call_linear
+        else:
+            oriented_y = y.swapaxes(0, axis)
+            minval = 4
+            len_y = oriented_y.shape[0]
+            self._call = self._call_cubic
+            self._spline = sp3rep(x,oriented_y)
+        
         len_x = len(x)
-        len_y = oriented_y.shape[self._interp_axis]
         if len_x != len_y:
-            raise ValueError("x and y arrays must be equal in length along"
+                raise ValueError("x and y arrays must be equal in length along"
                 "interpolation axis.")
-        if len_x < 2 or len_y < 2:
-            raise ValueError("x and y arrays must have more than 1 entry")
+        if len_x < minval:
+            raise ValueError("x and y arrays must have at " \
+                             "least %d entries" % minval)
         self.x = x
         self.y = oriented_y
 
-    def __call__(self, x_new):
-        """ Find linearly interpolated y_new = f(x_new).
+    def _call_linear(self, x_new):
 
-        Parameters
-        ----------
-        x_new : number or array
-            New independent variable(s).
-
-        Returns
-        -------
-        y_new : number or array
-            Linearly interpolated value(s) corresponding to x_new.
-        """
-
-        # 1. Handle values in x_new that are outside of x.  Throw error,
-        #    or return a list of mask array indicating the outofbounds values.
-        #    The behavior is set by the bounds_error variable.
-        x_new = atleast_1d(x_new)
-        out_of_bounds = self._check_bounds(x_new)
-
         # 2. Find where in the original data, the values to interpolate
         #    would be inserted.
         #    Note: If x_new[n] == x[m], then m is returned by searchsorted.
@@ -250,19 +244,49 @@
         # 5. Calculate the actual value for each entry in x_new.
         y_new = slope*(x_new-x_lo) + y_lo
 
-        # 6. Fill any values that were out of bounds with fill_value.
-        y_new[..., out_of_bounds] = self.fill_value
+        return y_new
 
+    def _call_cubic(self, x_new):
+        return sp3eval(self._spline,x_new)
+
+    def __call__(self, x_new):
+        """ Find linearly interpolated y_new = f(x_new).
+
+        Parameters
+        ----------
+        x_new : number or array
+            New independent variable(s).
+
+        Returns
+        -------
+        y_new : number or array
+            Linearly interpolated value(s) corresponding to x_new.
+        """        
+
+        # 1. Handle values in x_new that are outside of x.  Throw error,
+        #    or return a list of mask array indicating the outofbounds values.
+        #    The behavior is set by the bounds_error variable.
+        x_new = atleast_1d(x_new)
+        out_of_bounds = self._check_bounds(x_new)
+
+        y_new = self._call(x_new)
+
+        if self._kind == 'linear':
+            # 6. Fill any values that were out of bounds with fill_value.
+            y_new[..., out_of_bounds] = self.fill_value
+        else:
+            y_new[out_of_bounds] = self.fill_value
+
         # Rotate the values of y_new back so that they correspond to the
         # correct x_new values. For N-D x_new, take the last N axes from y_new
         # and insert them where self.axis was in the list of axes.
-        nx = len(x_new.shape)
-        ny = len(y_new.shape)
+        nx = x_new.ndim
+        ny = y_new.ndim
         axes = range(ny - nx)
         axes[self.axis:self.axis] = range(ny - nx, ny)
         result = y_new.transpose(axes)
-
         return result
+                
 
     def _check_bounds(self, x_new):
         """ Check the inputs for being in the bounds of the interpolated data.
@@ -296,8 +320,24 @@
         out_of_bounds = logical_or(below_bounds, above_bounds)
         return out_of_bounds
 
+class ppform(object):
+    def __init__(self, coeffs, breaks, dosort=False):
+        self.coeffs = np.asarray(coeffs)
+        if dosort:
+            self.breaks = np.sort(breaks)
+        else:
+            self.breaks = np.asarray(breaks)
+        self.N = self.coeffs.shape[0]
+    def __call__(self, xnew):
+        indxs = np.searchsorted(self.breaks, xnew)-1
+        indxs[indxs<0]=0
+        pp = self.coeffs
+        V = np.vander(xnew,N=self.N)
+        # res = np.diag(np.dot(V,pp[:,indxs]))
+        res = array([np.dot(V[k,:],pp[:,indxs[k]]) for k in xrange(len(xnew))])
+        return res
 
-def _get_cspline_Bb(xk, yk, kind, conds):
+def _get_spline3_Bb(xk, yk, kind, conds):
     # internal function to compute different tri-diagonal system
     # depending on the kind of spline requested.
     # conds is only used for 'second' and 'first' 
@@ -365,7 +405,7 @@
             B[-1,-3:] = [dN,-dN1-dN,dN1]
         elif kind == 'runout':
             B[0,:3] = [1,-2,1]
-            b[-1,-3:] = [1,-2,1]
+            B[-1,-3:] = [1,-2,1]
         elif kind == 'parabolic':
             B[0,:2] = [1,-1]
             B[-1,-2:] = [-1,1]
@@ -396,7 +436,7 @@
         raise ValueError, "%s not supported" % kind
         
 
-def cspeval((mk,xk,yk),xnew):
+def sp3eval((mk,xk,yk),xnew):
     """Evaluate a cubic-spline representation of the points (xk,yk)
     at the new values xnew.  The mk values are the second derivatives at xk
     The xk vector must be sorted.
@@ -420,7 +460,7 @@
     val += (yk[indxs]/dk - mk0*dk/6.)*dm1
     return val
 
-def csp2pp(mk,xk,yk):
+def sp3topp(mk,xk,yk):
     """Return an N-d array providing the piece-wise polynomial form.
 
     mk - second derivative at the knots
@@ -442,17 +482,9 @@
     c0 = (mk[:-1]*xk[1:]**3 - mk[1:]*xk[:-1]**3)/(6*dk)
     c0 += temp2*dk/6.
     c0 += (yk[:-1]*xk[1:] - yk[1:]*xk[:-1])/dk
-    return np.array([c3,c2,c1,c0])    
+    return ppform([c3,c2,c1,c0], xk)
 
-def ppval(pp, xk, xnew):
-    """Compute a piece-wise polynomial defined by the array of
-    coefficents pp and the break-points xk on the grid xnew
-    """
-    indxs = numpy.searchsorted(xk, xnew)-1
-    indxs[indxs<0]=0
-    return array([numpy.polyval(pp[:,k],xnew[i]) for i,k in enumerate(indxs)])
-
-def csprep(xk,yk,kind='not-a-knot',conds=None):
+def sp3rep(xk,yk,kind='not-a-knot',conds=None):
     """Return a (Spp,xk,yk) representation of a cubic spline given
     data-points
 
@@ -469,7 +501,7 @@
     """
     yk = np.asanyarray(yk)
     N = yk.shape[0]-1
-    B,b,first,last = _get_cspline_Bb(xk, yk, kind, conds)
+    B,b,first,last = _get_spline3_Bb(xk, yk, kind, conds)
     mk = np.dual.solve(B,b)
     if first is not None:
         mk = np.concatenate((first, mk), axis=0)
@@ -477,5 +509,13 @@
         mk = np.concatenate((mk, last), axis=0)
     return mk, xk, yk
 
-def cspline(xk,yk,xnew,kind='not-a-knot',conds=None):
-    return cspeval(csprep(xk,yk,kind=kind,conds=conds),xnew)
+def spline3(xk,yk,xnew,kind='not-a-knot',conds=None):
+    return sp3eval(sp3rep(xk,yk,kind=kind,conds=conds),xnew)
+
+def sp2rep(xk,yk):
+    pass
+
+def sp2eval(xk,yk):
+    pass
+
+def 




More information about the Scipy-svn mailing list