[Scipy-svn] r4195 - in trunk/scipy/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Apr 28 00:59:05 EDT 2008
Author: peridot
Date: 2008-04-27 23:59:03 -0500 (Sun, 27 Apr 2008)
New Revision: 4195
Modified:
trunk/scipy/interpolate/polyint.py
trunk/scipy/interpolate/tests/test_polyint.py
Log:
Fix a bug introduced in r4181. PiecewisePolynomial now also correctly distinguishes scalar values from vectors of length 1.
Modified: trunk/scipy/interpolate/polyint.py
===================================================================
--- trunk/scipy/interpolate/polyint.py 2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/polyint.py 2008-04-28 04:59:03 UTC (rev 4195)
@@ -432,16 +432,20 @@
derivatives needed is odd, it will prefer the rightmost endpoint. If
not enough derivatives are available, an exception is raised.
"""
+ yi0 = np.asarray(yi[0])
+ if len(yi0.shape)==2:
+ self.vector_valued = True
+ self.r = yi0.shape[1]
+ elif len(yi0.shape)==1:
+ self.vector_valued = False
+ self.r = 1
+ else:
+ raise ValueError, "Each derivative must be a vector, not a higher-rank array"
+
self.xi = [xi[0]]
- self.yi = [yi[0]]
+ self.yi = [yi0]
self.n = 1
- try:
- self.r = len(yi[0][0])
- except TypeError:
- self.r = 1
-
- self.n = 1
self.direction = direction
self.orders = []
self.polynomials = []
@@ -468,7 +472,11 @@
assert n2<=len(y2)
xi = np.zeros(n)
- yi = np.zeros((n,self.r))
+ if self.vector_valued:
+ yi = np.zeros((n,self.r))
+ else:
+ yi = np.zeros((n,))
+
xi[:n1] = x1
yi[:n1] = y1[:n1]
xi[n1:] = x2
@@ -488,19 +496,23 @@
a polynomial order, or instructions to use the highest
possible order
"""
+
+ yi = np.asarray(yi)
+ if self.vector_valued:
+ if (len(yi.shape)!=2 or yi.shape[1]!=self.r):
+ raise ValueError, "Each derivative must be a vector of length %d" % self.r
+ else:
+ if len(yi.shape)!=1:
+ raise ValueError, "Each derivative must be a scalar"
+
if self.direction is None:
self.direction = np.sign(xi-self.xi[-1])
elif (xi-self.xi[-1])*self.direction < 0:
raise ValueError, "x coordinates must be in the %d direction: %s" % (self.direction, self.xi)
+
self.xi.append(xi)
self.yi.append(yi)
- for y in yi:
- if np.shape(y) != (self.r,):
- if self.r>1:
- raise ValueError, "Each derivative must be a vector of length %d" % self.r
- else:
- raise ValueError, "Each derivative must be a scalar"
if order is None:
n1 = len(self.yi[-2])
@@ -558,7 +570,7 @@
x = np.asarray(x)
m = len(x)
pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
- if self.r>1:
+ if self.vector_valued:
y = np.zeros((m,self.r))
else:
y = np.zeros(m)
@@ -611,7 +623,7 @@
x = np.asarray(x)
m = len(x)
pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
- if self.r>1:
+ if self.vector_valued:
y = np.zeros((der,m,self.r))
else:
y = np.zeros((der,m))
Modified: trunk/scipy/interpolate/tests/test_polyint.py
===================================================================
--- trunk/scipy/interpolate/tests/test_polyint.py 2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/tests/test_polyint.py 2008-04-28 04:59:03 UTC (rev 4195)
@@ -219,6 +219,13 @@
assert_array_equal(np.shape(P([0])), (1,3))
assert_array_equal(np.shape(P([0,1])), (2,3))
+ def test_shapes_vectorvalue_1d(self):
+ yi = np.multiply.outer(np.asarray(self.yi),np.arange(1))
+ P = PiecewisePolynomial(self.xi,yi,4)
+ assert_array_equal(np.shape(P(0)), (1,))
+ assert_array_equal(np.shape(P([0])), (1,1))
+ assert_array_equal(np.shape(P([0,1])), (2,1))
+
def test_shapes_vectorvalue_derivative(self):
P = PiecewisePolynomial(self.xi,np.multiply.outer(self.yi,np.arange(3)),4)
n = 4
More information about the Scipy-svn
mailing list