[Scipy-svn] r6905 - in trunk/scipy/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Wed Nov 17 15:49:13 EST 2010


Author: ptvirtan
Date: 2010-11-17 14:49:13 -0600 (Wed, 17 Nov 2010)
New Revision: 6905

Modified:
   trunk/scipy/interpolate/interpnd.pyx
   trunk/scipy/interpolate/ndgriddata.py
   trunk/scipy/interpolate/tests/test_interpnd.py
   trunk/scipy/interpolate/tests/test_ndgriddata.py
Log:
ENH: interpolate: allow more natural __call__ usage in N-D interpolants

Modified: trunk/scipy/interpolate/interpnd.pyx
===================================================================
--- trunk/scipy/interpolate/interpnd.pyx	2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/interpnd.pyx	2010-11-17 20:49:13 UTC (rev 6905)
@@ -109,7 +109,7 @@
             raise ValueError("number of dimensions in xi does not match x")
         return xi
 
-    def __call__(self, xi):
+    def __call__(self, *args):
         """
         interpolator(xi)
 
@@ -121,7 +121,7 @@
             Points where to interpolate data at.
 
         """
-        xi = _ndim_coords_from_arrays(xi)
+        xi = _ndim_coords_from_arrays(args)
         xi = self._check_call_shape(xi)
         xi = np.ascontiguousarray(xi.astype(np.double))
         shape = xi.shape
@@ -139,8 +139,10 @@
     Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.
 
     """
-    if (isinstance(points, tuple) or isinstance(points, list)) \
-           and points and isinstance(points[0], np.ndarray):
+    if isinstance(points, tuple) and len(points) == 1:
+        # handle argument tuple
+        points = points[0]
+    if isinstance(points, tuple):
         p = np.broadcast_arrays(*points)
         for j in xrange(1, len(p)):
             if p[j].shape != p[0].shape:
@@ -150,6 +152,8 @@
             points[...,j] = item
     else:
         points = np.asanyarray(points)
+        if points.ndim == 1:
+            points = points.reshape(-1, 1)
     return points
 
 #------------------------------------------------------------------------------
@@ -189,7 +193,7 @@
 
     def __init__(self, points, values, fill_value=np.nan):
         NDInterpolatorBase.__init__(self, points, values, fill_value=fill_value)
-        self.tri = qhull.Delaunay(points)
+        self.tri = qhull.Delaunay(self.points)
 
 % for DTYPE, CDTYPE in zip(["double", "complex"], ["double", "double complex"]):
     @cython.boundscheck(False)
@@ -798,7 +802,7 @@
                  tol=1e-6, maxiter=400):
         NDInterpolatorBase.__init__(self, points, values, ndim=2,
                                     fill_value=fill_value)
-        self.tri = qhull.Delaunay(points)
+        self.tri = qhull.Delaunay(self.points)
         self.grad = estimate_gradients_2d_global(self.tri, self.values,
                                                  tol=tol, maxiter=maxiter)
 

Modified: trunk/scipy/interpolate/ndgriddata.py
===================================================================
--- trunk/scipy/interpolate/ndgriddata.py	2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/ndgriddata.py	2010-11-17 20:49:13 UTC (rev 6905)
@@ -45,7 +45,7 @@
         self.points = x
         self.values = y
 
-    def __call__(self, xi):
+    def __call__(self, *args):
         """
         Evaluate interpolator at given points.
 
@@ -55,6 +55,7 @@
             Points where to interpolate data at.
 
         """
+        xi = _ndim_coords_from_arrays(args)
         xi = self._check_call_shape(xi)
         dist, i = self.tree.query(xi)
         return self.values[i]
@@ -164,8 +165,9 @@
     if ndim == 1 and method in ('nearest', 'linear', 'cubic'):
         from interpolate import interp1d
         points = points.ravel()
-        if (isinstance(xi, tuple) or isinstance(xi, list)) \
-               and xi and isinstance(xi[0], np.ndarray):
+        if isinstance(xi, tuple):
+            if len(xi) != 1:
+                raise ValueError("invalid number of dimensions in xi")
             xi, = xi
         ip = interp1d(points, values, kind=method, axis=0, bounds_error=False,
                       fill_value=fill_value)

Modified: trunk/scipy/interpolate/tests/test_interpnd.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpnd.py	2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/tests/test_interpnd.py	2010-11-17 20:49:13 UTC (rev 6905)
@@ -15,6 +15,15 @@
         yi = interpnd.LinearNDInterpolator(x, y)(x)
         assert_almost_equal(y, yi)
 
+    def test_smoketest_alternate(self):
+        # Test at single points, alternate calling convention
+        x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
+                     dtype=np.double)
+        y = np.arange(x.shape[0], dtype=np.double)
+
+        yi = interpnd.LinearNDInterpolator((x[:,0], x[:,1]), y)(x[:,0], x[:,1])
+        assert_almost_equal(y, yi)
+
     def test_complex_smoketest(self):
         # Test at single points
         x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
@@ -95,7 +104,7 @@
 
 class TestCloughTocher2DInterpolator(object):
 
-    def _check_accuracy(self, func, x=None, tol=1e-6, **kw):
+    def _check_accuracy(self, func, x=None, tol=1e-6, alternate=False, **kw):
         np.random.seed(1234)
         if x is None:
             x = np.array([(0, 0), (0, 1),
@@ -103,11 +112,20 @@
                           (0.5, 0.2)],
                          dtype=float)
 
-        ip = interpnd.CloughTocher2DInterpolator(x, func(x[:,0], x[:,1]),
-                                                 tol=1e-6)
+        if not alternate:
+            ip = interpnd.CloughTocher2DInterpolator(x, func(x[:,0], x[:,1]),
+                                                     tol=1e-6)
+        else:
+            ip = interpnd.CloughTocher2DInterpolator((x[:,0], x[:,1]),
+                                                     func(x[:,0], x[:,1]),
+                                                     tol=1e-6)
+
         p = np.random.rand(50, 2)
 
-        a = ip(p)
+        if not alternate:
+            a = ip(p)
+        else:
+            a = ip(p[:,0], p[:,1])
         b = func(p[:,0], p[:,1])
 
         try:
@@ -129,6 +147,9 @@
         for j, func in enumerate(funcs):
             self._check_accuracy(func, tol=1e-13, atol=1e-7, rtol=1e-7,
                                  err_msg="Function %d" % j)
+            self._check_accuracy(func, tol=1e-13, atol=1e-7, rtol=1e-7,
+                                 alternate=True,
+                                 err_msg="Function (alternate) %d" % j)
 
     def test_quadratic_smoketest(self):
         # Should be reasonably accurate for quadratic functions

Modified: trunk/scipy/interpolate/tests/test_ndgriddata.py
===================================================================
--- trunk/scipy/interpolate/tests/test_ndgriddata.py	2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/tests/test_ndgriddata.py	2010-11-17 20:49:13 UTC (rev 6905)
@@ -11,7 +11,7 @@
         y = [1, 2, 3]
 
         yi = griddata(x, y, [(1,1), (1,2), (0,0)], fill_value=-1)
-        assert_array_equal(yi, [-1, -1, 1])
+        assert_array_equal(yi, [-1., -1, 1])
 
         yi = griddata(x, y, [(1,1), (1,2), (0,0)])
         assert_array_equal(yi, [np.nan, np.nan, 1])




More information about the Scipy-svn mailing list