[Scipy-svn] r6905 - in trunk/scipy/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Nov 17 15:49:13 EST 2010
Author: ptvirtan
Date: 2010-11-17 14:49:13 -0600 (Wed, 17 Nov 2010)
New Revision: 6905
Modified:
trunk/scipy/interpolate/interpnd.pyx
trunk/scipy/interpolate/ndgriddata.py
trunk/scipy/interpolate/tests/test_interpnd.py
trunk/scipy/interpolate/tests/test_ndgriddata.py
Log:
ENH: interpolate: allow more natural __call__ usage in N-D interpolants
Modified: trunk/scipy/interpolate/interpnd.pyx
===================================================================
--- trunk/scipy/interpolate/interpnd.pyx 2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/interpnd.pyx 2010-11-17 20:49:13 UTC (rev 6905)
@@ -109,7 +109,7 @@
raise ValueError("number of dimensions in xi does not match x")
return xi
- def __call__(self, xi):
+ def __call__(self, *args):
"""
interpolator(xi)
@@ -121,7 +121,7 @@
Points where to interpolate data at.
"""
- xi = _ndim_coords_from_arrays(xi)
+ xi = _ndim_coords_from_arrays(args)
xi = self._check_call_shape(xi)
xi = np.ascontiguousarray(xi.astype(np.double))
shape = xi.shape
@@ -139,8 +139,10 @@
Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.
"""
- if (isinstance(points, tuple) or isinstance(points, list)) \
- and points and isinstance(points[0], np.ndarray):
+ if isinstance(points, tuple) and len(points) == 1:
+ # handle argument tuple
+ points = points[0]
+ if isinstance(points, tuple):
p = np.broadcast_arrays(*points)
for j in xrange(1, len(p)):
if p[j].shape != p[0].shape:
@@ -150,6 +152,8 @@
points[...,j] = item
else:
points = np.asanyarray(points)
+ if points.ndim == 1:
+ points = points.reshape(-1, 1)
return points
#------------------------------------------------------------------------------
@@ -189,7 +193,7 @@
def __init__(self, points, values, fill_value=np.nan):
NDInterpolatorBase.__init__(self, points, values, fill_value=fill_value)
- self.tri = qhull.Delaunay(points)
+ self.tri = qhull.Delaunay(self.points)
% for DTYPE, CDTYPE in zip(["double", "complex"], ["double", "double complex"]):
@cython.boundscheck(False)
@@ -798,7 +802,7 @@
tol=1e-6, maxiter=400):
NDInterpolatorBase.__init__(self, points, values, ndim=2,
fill_value=fill_value)
- self.tri = qhull.Delaunay(points)
+ self.tri = qhull.Delaunay(self.points)
self.grad = estimate_gradients_2d_global(self.tri, self.values,
tol=tol, maxiter=maxiter)
Modified: trunk/scipy/interpolate/ndgriddata.py
===================================================================
--- trunk/scipy/interpolate/ndgriddata.py 2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/ndgriddata.py 2010-11-17 20:49:13 UTC (rev 6905)
@@ -45,7 +45,7 @@
self.points = x
self.values = y
- def __call__(self, xi):
+ def __call__(self, *args):
"""
Evaluate interpolator at given points.
@@ -55,6 +55,7 @@
Points where to interpolate data at.
"""
+ xi = _ndim_coords_from_arrays(args)
xi = self._check_call_shape(xi)
dist, i = self.tree.query(xi)
return self.values[i]
@@ -164,8 +165,9 @@
if ndim == 1 and method in ('nearest', 'linear', 'cubic'):
from interpolate import interp1d
points = points.ravel()
- if (isinstance(xi, tuple) or isinstance(xi, list)) \
- and xi and isinstance(xi[0], np.ndarray):
+ if isinstance(xi, tuple):
+ if len(xi) != 1:
+ raise ValueError("invalid number of dimensions in xi")
xi, = xi
ip = interp1d(points, values, kind=method, axis=0, bounds_error=False,
fill_value=fill_value)
Modified: trunk/scipy/interpolate/tests/test_interpnd.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpnd.py 2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/tests/test_interpnd.py 2010-11-17 20:49:13 UTC (rev 6905)
@@ -15,6 +15,15 @@
yi = interpnd.LinearNDInterpolator(x, y)(x)
assert_almost_equal(y, yi)
+ def test_smoketest_alternate(self):
+ # Test at single points, alternate calling convention
+ x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
+ dtype=np.double)
+ y = np.arange(x.shape[0], dtype=np.double)
+
+ yi = interpnd.LinearNDInterpolator((x[:,0], x[:,1]), y)(x[:,0], x[:,1])
+ assert_almost_equal(y, yi)
+
def test_complex_smoketest(self):
# Test at single points
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
@@ -95,7 +104,7 @@
class TestCloughTocher2DInterpolator(object):
- def _check_accuracy(self, func, x=None, tol=1e-6, **kw):
+ def _check_accuracy(self, func, x=None, tol=1e-6, alternate=False, **kw):
np.random.seed(1234)
if x is None:
x = np.array([(0, 0), (0, 1),
@@ -103,11 +112,20 @@
(0.5, 0.2)],
dtype=float)
- ip = interpnd.CloughTocher2DInterpolator(x, func(x[:,0], x[:,1]),
- tol=1e-6)
+ if not alternate:
+ ip = interpnd.CloughTocher2DInterpolator(x, func(x[:,0], x[:,1]),
+ tol=1e-6)
+ else:
+ ip = interpnd.CloughTocher2DInterpolator((x[:,0], x[:,1]),
+ func(x[:,0], x[:,1]),
+ tol=1e-6)
+
p = np.random.rand(50, 2)
- a = ip(p)
+ if not alternate:
+ a = ip(p)
+ else:
+ a = ip(p[:,0], p[:,1])
b = func(p[:,0], p[:,1])
try:
@@ -129,6 +147,9 @@
for j, func in enumerate(funcs):
self._check_accuracy(func, tol=1e-13, atol=1e-7, rtol=1e-7,
err_msg="Function %d" % j)
+ self._check_accuracy(func, tol=1e-13, atol=1e-7, rtol=1e-7,
+ alternate=True,
+ err_msg="Function (alternate) %d" % j)
def test_quadratic_smoketest(self):
# Should be reasonably accurate for quadratic functions
Modified: trunk/scipy/interpolate/tests/test_ndgriddata.py
===================================================================
--- trunk/scipy/interpolate/tests/test_ndgriddata.py 2010-11-17 20:48:53 UTC (rev 6904)
+++ trunk/scipy/interpolate/tests/test_ndgriddata.py 2010-11-17 20:49:13 UTC (rev 6905)
@@ -11,7 +11,7 @@
y = [1, 2, 3]
yi = griddata(x, y, [(1,1), (1,2), (0,0)], fill_value=-1)
- assert_array_equal(yi, [-1, -1, 1])
+ assert_array_equal(yi, [-1., -1, 1])
yi = griddata(x, y, [(1,1), (1,2), (0,0)])
assert_array_equal(yi, [np.nan, np.nan, 1])
More information about the Scipy-svn mailing list