[Scipy-svn] r5172 - in trunk/scipy/sparse/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Nov 22 18:14:50 EST 2008


Author: wnbell
Date: 2008-11-22 17:14:44 -0600 (Sat, 22 Nov 2008)
New Revision: 5172

Modified:
   trunk/scipy/sparse/linalg/interface.py
   trunk/scipy/sparse/linalg/tests/test_interface.py
Log:
Force rank-1 ndarray output when the input is a rank-1 ndarray.


Modified: trunk/scipy/sparse/linalg/interface.py
===================================================================
--- trunk/scipy/sparse/linalg/interface.py	2008-11-22 22:47:13 UTC (rev 5171)
+++ trunk/scipy/sparse/linalg/interface.py	2008-11-22 23:14:44 UTC (rev 5172)
@@ -121,14 +121,20 @@
             raise ValueError('dimension mismatch')
 
         y = self._matvec(x)
-
-        if x.ndim == 2:
-            # If 'x' is a column vector, reshape the result
-            y = y.reshape(-1,1)
-
+        
         if isinstance(x, np.matrix):
             y = np.asmatrix(y)
+        else:
+            y = np.asarray(y)
 
+        if x.ndim == 1:
+            y = y.reshape(M)
+        elif x.ndim == 2:
+            y = y.reshape(M,1)
+        else:
+            raise ValueError('invalid shape returned by user-defined matvec()')
+
+
         return y
 
 

Modified: trunk/scipy/sparse/linalg/tests/test_interface.py
===================================================================
--- trunk/scipy/sparse/linalg/tests/test_interface.py	2008-11-22 22:47:13 UTC (rev 5171)
+++ trunk/scipy/sparse/linalg/tests/test_interface.py	2008-11-22 23:14:44 UTC (rev 5172)
@@ -10,37 +10,47 @@
 
 
 class TestLinearOperator(TestCase):
-    def test_matvec(self):
-        def matvec(x):
-            # note, this matvec does not preserve type or shape
-            y = np.array([ 1*x[0] + 2*x[1] + 3*x[2],
-                           4*x[0] + 5*x[1] + 6*x[2]])
-            return y
+    def setUp(self):
+        self.matvecs = []
 
-        A = LinearOperator((2,3), matvec)
+        # these matvecs do not preserve type or shape
+        def matvec1(x):
+            return np.array([ 1*x[0] + 2*x[1] + 3*x[2],
+                              4*x[0] + 5*x[1] + 6*x[2]])
+        def matvec2(x):
+            return np.matrix(matvec1(x).reshape(2,1))
 
-        assert_equal(A.matvec(np.array([1,2,3])),       [14,32])
-        assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
-        assert_equal(A * np.array([1,2,3]),             [14,32])
-        assert_equal(A * np.array([[1],[2],[3]]),       [[14],[32]])
-        
-        assert_equal(A.matvec(np.matrix([[1],[2],[3]])), [[14],[32]])
-        assert_equal(A * np.matrix([[1],[2],[3]]),       [[14],[32]])
+        self.matvecs.append(matvec1)
+        self.matvecs.append(matvec2)
 
-        assert( isinstance(A.matvec(np.array([1,2,3])),       np.ndarray) )
-        assert( isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray) )
-        assert( isinstance(A * np.array([1,2,3]),             np.ndarray) )
-        assert( isinstance(A * np.array([[1],[2],[3]]),       np.ndarray) )
+    def test_matvec(self):
 
-        assert( isinstance(A.matvec(np.matrix([[1],[2],[3]])), np.ndarray) )
-        assert( isinstance(A * np.matrix([[1],[2],[3]]),       np.ndarray) )
+        for matvec in self.matvecs:
+            A = LinearOperator((2,3), matvec)
+    
+            assert_equal(A.matvec(np.array([1,2,3])),       [14,32])
+            assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
+            assert_equal(A * np.array([1,2,3]),             [14,32])
+            assert_equal(A * np.array([[1],[2],[3]]),       [[14],[32]])
+            
+            assert_equal(A.matvec(np.matrix([[1],[2],[3]])), [[14],[32]])
+            assert_equal(A * np.matrix([[1],[2],[3]]),       [[14],[32]])
+    
+            assert( isinstance(A.matvec(np.array([1,2,3])),       np.ndarray) )
+            assert( isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray) )
+            assert( isinstance(A * np.array([1,2,3]),             np.ndarray) )
+            assert( isinstance(A * np.array([[1],[2],[3]]),       np.ndarray) )
+    
+            assert( isinstance(A.matvec(np.matrix([[1],[2],[3]])), np.ndarray) )
+            assert( isinstance(A * np.matrix([[1],[2],[3]]),       np.ndarray) )
+    
+            assert_raises(ValueError, A.matvec, np.array([1,2]))
+            assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
+            assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
+            assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
+        
 
-        assert_raises(ValueError, A.matvec, np.array([1,2]))
-        assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
-        assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
-        assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
 
-
 class TestAsLinearOperator(TestCase):
     def setUp(self):
         self.cases = []




More information about the Scipy-svn mailing list