[Scipy-svn] r4195 - in trunk/scipy/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Apr 28 00:59:05 EDT 2008


Author: peridot
Date: 2008-04-27 23:59:03 -0500 (Sun, 27 Apr 2008)
New Revision: 4195

Modified:
   trunk/scipy/interpolate/polyint.py
   trunk/scipy/interpolate/tests/test_polyint.py
Log:
Fix a bug introduced in r4181; PiecewisePolynomial now also correctly distinguishes between scalar values and vectors of length 1.


Modified: trunk/scipy/interpolate/polyint.py
===================================================================
--- trunk/scipy/interpolate/polyint.py	2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/polyint.py	2008-04-28 04:59:03 UTC (rev 4195)
@@ -432,16 +432,20 @@
         derivatives needed is odd, it will prefer the rightmost endpoint. If 
         not enough derivatives are available, an exception is raised.
         """
+        yi0 = np.asarray(yi[0])
+        if len(yi0.shape)==2:
+            self.vector_valued = True
+            self.r = yi0.shape[1]
+        elif len(yi0.shape)==1:
+            self.vector_valued = False
+            self.r = 1
+        else:
+            raise ValueError, "Each derivative must be a vector, not a higher-rank array"
+
         self.xi = [xi[0]]
-        self.yi = [yi[0]]
+        self.yi = [yi0]
         self.n = 1
         
-        try:
-            self.r = len(yi[0][0])
-        except TypeError:
-            self.r = 1
-
-        self.n = 1
         self.direction = direction
         self.orders = []
         self.polynomials = []
@@ -468,7 +472,11 @@
         assert n2<=len(y2)
 
         xi = np.zeros(n)
-        yi = np.zeros((n,self.r))
+        if self.vector_valued:
+            yi = np.zeros((n,self.r))
+        else:
+            yi = np.zeros((n,))
+
         xi[:n1] = x1
         yi[:n1] = y1[:n1]
         xi[n1:] = x2
@@ -488,19 +496,23 @@
             a polynomial order, or instructions to use the highest 
             possible order
         """
+
+        yi = np.asarray(yi)
+        if self.vector_valued:
+            if (len(yi.shape)!=2 or yi.shape[1]!=self.r):
+                raise ValueError, "Each derivative must be a vector of length %d" % self.r
+        else:
+            if len(yi.shape)!=1:
+                raise ValueError, "Each derivative must be a scalar"
+
         if self.direction is None:
             self.direction = np.sign(xi-self.xi[-1])
         elif (xi-self.xi[-1])*self.direction < 0: 
             raise ValueError, "x coordinates must be in the %d direction: %s" % (self.direction, self.xi)
+
         self.xi.append(xi)
         self.yi.append(yi)
 
-        for y in yi:
-            if np.shape(y) != (self.r,):
-                if self.r>1:
-                    raise ValueError, "Each derivative must be a vector of length %d" % self.r
-                else:
-                    raise ValueError, "Each derivative must be a scalar"
 
         if order is None:
             n1 = len(self.yi[-2])
@@ -558,7 +570,7 @@
             x = np.asarray(x)
             m = len(x)
             pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
-            if self.r>1:
+            if self.vector_valued:
                 y = np.zeros((m,self.r))
             else:
                 y = np.zeros(m)
@@ -611,7 +623,7 @@
             x = np.asarray(x)
             m = len(x)
             pos = np.clip(np.searchsorted(self.xi, x) - 1, 0, self.n-2)
-            if self.r>1:
+            if self.vector_valued:
                 y = np.zeros((der,m,self.r))
             else:
                 y = np.zeros((der,m))

Modified: trunk/scipy/interpolate/tests/test_polyint.py
===================================================================
--- trunk/scipy/interpolate/tests/test_polyint.py	2008-04-27 13:29:06 UTC (rev 4194)
+++ trunk/scipy/interpolate/tests/test_polyint.py	2008-04-28 04:59:03 UTC (rev 4195)
@@ -219,6 +219,13 @@
         assert_array_equal(np.shape(P([0])), (1,3))
         assert_array_equal(np.shape(P([0,1])), (2,3))
 
+    def test_shapes_vectorvalue_1d(self):
+        yi = np.multiply.outer(np.asarray(self.yi),np.arange(1))
+        P = PiecewisePolynomial(self.xi,yi,4)
+        assert_array_equal(np.shape(P(0)), (1,))
+        assert_array_equal(np.shape(P([0])), (1,1))
+        assert_array_equal(np.shape(P([0,1])), (2,1))
+
     def test_shapes_vectorvalue_derivative(self):
         P = PiecewisePolynomial(self.xi,np.multiply.outer(self.yi,np.arange(3)),4)
         n = 4




More information about the Scipy-svn mailing list