[Scipy-svn] r5275 - in trunk/scipy/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Thu Dec 18 15:32:27 EST 2008


Author: ptvirtan
Date: 2008-12-18 14:32:14 -0600 (Thu, 18 Dec 2008)
New Revision: 5275

Modified:
   trunk/scipy/interpolate/interpolate.py
   trunk/scipy/interpolate/tests/test_interpolate.py
Log:
interp1d/spleval: enable spline interpolation of complex-valued data by evaluating the real and imaginary parts separately

Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py	2008-12-18 02:51:45 UTC (rev 5274)
+++ trunk/scipy/interpolate/interpolate.py	2008-12-18 20:32:14 UTC (rev 5275)
@@ -248,9 +248,6 @@
             self._call = self._call_spline
             self._spline = splmake(x,oriented_y,order=order)
 
-            if issubclass(y.dtype.type, np.complexfloating):
-                raise ValueError("Input data must be real for spline interpolation")
-
         len_x = len(x)
         if len_x != len_y:
             raise ValueError("x and y arrays must be equal in length along "
@@ -765,10 +762,14 @@
     oldshape = np.shape(xnew)
     xx = np.ravel(xnew)
     sh = cvals.shape[1:]
-    res = np.empty(xx.shape + sh)
+    res = np.empty(xx.shape + sh, dtype=cvals.dtype)
     for index in np.ndindex(*sh):
         sl = (slice(None),)+index
-        res[sl] = _fitpack._bspleval(xx,xj,cvals[sl],k,deriv)
+        if issubclass(cvals.dtype.type, np.complexfloating):
+            res[sl].real = _fitpack._bspleval(xx,xj,cvals.real[sl],k,deriv)
+            res[sl].imag = _fitpack._bspleval(xx,xj,cvals.imag[sl],k,deriv)
+        else:
+            res[sl] = _fitpack._bspleval(xx,xj,cvals[sl],k,deriv)
     res.shape = oldshape + sh
     return res
 

Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py	2008-12-18 02:51:45 UTC (rev 5274)
+++ trunk/scipy/interpolate/tests/test_interpolate.py	2008-12-18 20:32:14 UTC (rev 5275)
@@ -292,35 +292,28 @@
             yield self._nd_check_interp, kind
             yield self._nd_check_shape, kind
 
-    def _check_complex(self, dtype=np.complex_, kind='linear', fail=False):
-        x = np.arange(10).astype(np.int_)
-        y = np.arange(10).astype(np.int_) * (1 + 2j)
+    def _check_complex(self, dtype=np.complex_, kind='linear'):
+        x = np.array([1, 2.5, 3, 3.1, 4, 6.4, 7.9, 8.0, 9.5, 10])
+        y = x * x ** (1 + 2j)
         y = y.astype(dtype)
-        if fail:
-            assert_raises(ValueError, interp1d, x, y, kind=kind)
-        else:
-            c = interp1d(x, y, kind=kind)
-            assert_array_almost_equal(y[:-1], c(x)[:-1])
-        assert (y.dtype == dtype) or not issubclass(y.dtype.type, np.inexact)
 
+        # simple test
+        c = interp1d(x, y, kind=kind)
+        assert_array_almost_equal(y[:-1], c(x)[:-1])
+
+        # check against interpolating real+imag separately
+        xi = np.linspace(1, 10, 31)
+        cr = interp1d(x, y.real, kind=kind)
+        ci = interp1d(x, y.imag, kind=kind)
+        assert_array_almost_equal(c(xi).real, cr(xi))
+        assert_array_almost_equal(c(xi).imag, ci(xi))
+
     def test_complex(self):
-        for kind in ('linear', 'nearest'):
+        for kind in ('linear', 'nearest', 'cubic', 'slinear', 'quadratic',
+                     'zero'):
             yield self._check_complex, np.complex64, kind
             yield self._check_complex, np.complex128, kind
-            yield self._check_complex, np.float32, kind
-            yield self._check_complex, np.float64, kind
 
-        # The spline methods can't handle complex values, because the code
-        # for _fitpack.bispleval is written in C and is not type-agnostic.
-        #
-        # Check that a ValueError is raised if one attempts to interpolate
-        # complex data using these routines.
-        for kind in ('cubic', 'slinear', 'quadratic', 'zero'):
-            yield self._check_complex, np.complex64, kind, True
-            yield self._check_complex, np.complex128, kind, True
-            yield self._check_complex, np.float32, kind
-            yield self._check_complex, np.float64, kind
-
     @dec.knownfailureif(True, "zero-order splines fail for the last point")
     def test_nd_zero_spline(self):
         # zero-order splines don't get the last point right,




More information about the Scipy-svn mailing list