[Scipy-svn] r5117 - in trunk/scipy/sparse/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Nov 15 16:09:31 EST 2008


Author: wnbell
Date: 2008-11-15 15:09:27 -0600 (Sat, 15 Nov 2008)
New Revision: 5117

Modified:
   trunk/scipy/sparse/linalg/interface.py
   trunk/scipy/sparse/linalg/tests/test_interface.py
Log:
added __mul__ to LinearOperator


Modified: trunk/scipy/sparse/linalg/interface.py
===================================================================
--- trunk/scipy/sparse/linalg/interface.py	2008-11-15 12:28:35 UTC (rev 5116)
+++ trunk/scipy/sparse/linalg/interface.py	2008-11-15 21:09:27 UTC (rev 5117)
@@ -48,6 +48,8 @@
     <2x2 LinearOperator with unspecified dtype>
     >>> A.matvec( ones(2) )
     array([ 2.,  3.])
+    >>> A * ones(2)
+    array([ 2.,  3.])
 
     """
     def __init__( self, shape, matvec, rmatvec=None, matmat=None, dtype=None ):
@@ -79,6 +81,26 @@
         if dtype is not None:
             self.dtype = numpy.dtype(dtype)
 
+    def __mul__(self, x):
+        """Return A*x.
+
+        A rank-1 array of shape (N,) or a column of shape (N,1) is
+        passed to matvec; a rank-2 array of shape (N,K) is passed to
+        matmat.  Inputs of any other rank raise ValueError.
+        """
+        x = numpy.asarray(x)
+
+        # Dispatch on the real shape instead of x.squeeze(): squeezing
+        # turns a length-1 vector into a 0-d array (sent to matmat) and
+        # a (1,K) matrix into a rank-1 array (sent to matvec).
+        if x.ndim == 1 or (x.ndim == 2 and x.shape[1] == 1):
+            return self.matvec(x)
+        elif x.ndim == 2:
+            return self.matmat(x)
+        else:
+            raise ValueError('expected rank-1 or rank-2 array, got rank %d'
+                             % x.ndim)
+
     def __repr__(self):
         M,N = self.shape
         if hasattr(self,'dtype'):

Modified: trunk/scipy/sparse/linalg/tests/test_interface.py
===================================================================
--- trunk/scipy/sparse/linalg/tests/test_interface.py	2008-11-15 12:28:35 UTC (rev 5116)
+++ trunk/scipy/sparse/linalg/tests/test_interface.py	2008-11-15 21:09:27 UTC (rev 5117)
@@ -44,12 +44,18 @@

             assert_equal(A.matvec(array([1,2,3])),       [14,32])
             assert_equal(A.matvec(array([[1],[2],[3]])), [[14],[32]])
+
+            # A * x must agree with matvec for vectors and column vectors
+            assert_equal(A * array([1,2,3]),       [14,32])
+            assert_equal(A * array([[1],[2],[3]]), [[14],[32]])
 
             assert_equal(A.rmatvec(array([1,2])),     [9,12,15])
             assert_equal(A.rmatvec(array([[1],[2]])), [[9],[12],[15]])
 
-            assert_equal(A.matmat(array([[1,4],[2,5],[3,6]])), \
-                    [[14,32],[32,77]] )
+            assert_equal(A.matmat(array([[1,4],[2,5],[3,6]])), [[14,32],[32,77]] )
+
+            # ... and with matmat for multi-column input
+            assert_equal(A * array([[1,4],[2,5],[3,6]]), [[14,32],[32,77]] )
 
             if hasattr(M,'dtype'):
                 assert_equal(A.dtype, M.dtype)




More information about the Scipy-svn mailing list