[Scipy-svn] r5172 - in trunk/scipy/sparse/linalg: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Nov 22 18:14:50 EST 2008
Author: wnbell
Date: 2008-11-22 17:14:44 -0600 (Sat, 22 Nov 2008)
New Revision: 5172
Modified:
trunk/scipy/sparse/linalg/interface.py
trunk/scipy/sparse/linalg/tests/test_interface.py
Log:
force rank-1 ndarray output when input is rank-1 ndarray
Modified: trunk/scipy/sparse/linalg/interface.py
===================================================================
--- trunk/scipy/sparse/linalg/interface.py 2008-11-22 22:47:13 UTC (rev 5171)
+++ trunk/scipy/sparse/linalg/interface.py 2008-11-22 23:14:44 UTC (rev 5172)
@@ -121,14 +121,20 @@
raise ValueError('dimension mismatch')
y = self._matvec(x)
-
- if x.ndim == 2:
- # If 'x' is a column vector, reshape the result
- y = y.reshape(-1,1)
-
+
if isinstance(x, np.matrix):
y = np.asmatrix(y)
+ else:
+ y = np.asarray(y)
+ if x.ndim == 1:
+ y = y.reshape(M)
+ elif x.ndim == 2:
+ y = y.reshape(M,1)
+ else:
+ raise ValueError('invalid shape returned by user-defined matvec()')
+
+
return y
Modified: trunk/scipy/sparse/linalg/tests/test_interface.py
===================================================================
--- trunk/scipy/sparse/linalg/tests/test_interface.py 2008-11-22 22:47:13 UTC (rev 5171)
+++ trunk/scipy/sparse/linalg/tests/test_interface.py 2008-11-22 23:14:44 UTC (rev 5172)
@@ -10,37 +10,47 @@
class TestLinearOperator(TestCase):
- def test_matvec(self):
- def matvec(x):
- # note, this matvec does not preserve type or shape
- y = np.array([ 1*x[0] + 2*x[1] + 3*x[2],
- 4*x[0] + 5*x[1] + 6*x[2]])
- return y
+ def setUp(self):
+ self.matvecs = []
- A = LinearOperator((2,3), matvec)
+ # these matvecs do not preserve type or shape
+ def matvec1(x):
+ return np.array([ 1*x[0] + 2*x[1] + 3*x[2],
+ 4*x[0] + 5*x[1] + 6*x[2]])
+ def matvec2(x):
+ return np.matrix(matvec1(x).reshape(2,1))
- assert_equal(A.matvec(np.array([1,2,3])), [14,32])
- assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
- assert_equal(A * np.array([1,2,3]), [14,32])
- assert_equal(A * np.array([[1],[2],[3]]), [[14],[32]])
-
- assert_equal(A.matvec(np.matrix([[1],[2],[3]])), [[14],[32]])
- assert_equal(A * np.matrix([[1],[2],[3]]), [[14],[32]])
+ self.matvecs.append(matvec1)
+ self.matvecs.append(matvec2)
- assert( isinstance(A.matvec(np.array([1,2,3])), np.ndarray) )
- assert( isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray) )
- assert( isinstance(A * np.array([1,2,3]), np.ndarray) )
- assert( isinstance(A * np.array([[1],[2],[3]]), np.ndarray) )
+ def test_matvec(self):
- assert( isinstance(A.matvec(np.matrix([[1],[2],[3]])), np.ndarray) )
- assert( isinstance(A * np.matrix([[1],[2],[3]]), np.ndarray) )
+ for matvec in self.matvecs:
+ A = LinearOperator((2,3), matvec)
+
+ assert_equal(A.matvec(np.array([1,2,3])), [14,32])
+ assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
+ assert_equal(A * np.array([1,2,3]), [14,32])
+ assert_equal(A * np.array([[1],[2],[3]]), [[14],[32]])
+
+ assert_equal(A.matvec(np.matrix([[1],[2],[3]])), [[14],[32]])
+ assert_equal(A * np.matrix([[1],[2],[3]]), [[14],[32]])
+
+ assert( isinstance(A.matvec(np.array([1,2,3])), np.ndarray) )
+ assert( isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray) )
+ assert( isinstance(A * np.array([1,2,3]), np.ndarray) )
+ assert( isinstance(A * np.array([[1],[2],[3]]), np.ndarray) )
+
+ assert( isinstance(A.matvec(np.matrix([[1],[2],[3]])), np.ndarray) )
+ assert( isinstance(A * np.matrix([[1],[2],[3]]), np.ndarray) )
+
+ assert_raises(ValueError, A.matvec, np.array([1,2]))
+ assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
+ assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
+ assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
+
- assert_raises(ValueError, A.matvec, np.array([1,2]))
- assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
- assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
- assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
-
class TestAsLinearOperator(TestCase):
def setUp(self):
self.cases = []
More information about the Scipy-svn mailing list