[Scipy-svn] r3685 - in trunk/scipy/sparse: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Dec 18 14:54:40 EST 2007


Author: wnbell
Date: 2007-12-18 13:54:33 -0600 (Tue, 18 Dec 2007)
New Revision: 3685

Modified:
   trunk/scipy/sparse/base.py
   trunk/scipy/sparse/compressed.py
   trunk/scipy/sparse/tests/test_base.py
   trunk/scipy/sparse/tests/test_sparse.py
Log:
Allow matvec() to write its result into preallocated storage (new optional 'output' argument).
Add associated unit tests.


Modified: trunk/scipy/sparse/base.py
===================================================================
--- trunk/scipy/sparse/base.py	2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/base.py	2007-12-18 19:54:33 UTC (rev 3685)
@@ -4,7 +4,7 @@
 
 from warnings import warn
 
-from numpy import asarray, asmatrix, ones
+from numpy import asarray, asmatrix, asanyarray, ones
 
 from sputils import isdense, isscalarlike 
 
@@ -321,7 +321,7 @@
             other.shape
         except AttributeError:
             # If it's a list or whatever, treat it like a matrix
-            other = asmatrix(other)
+            other = asanyarray(other)
 
         if isdense(other) and asarray(other).squeeze().ndim <= 1:
             # it's a dense row or column vector

Modified: trunk/scipy/sparse/compressed.py
===================================================================
--- trunk/scipy/sparse/compressed.py	2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/compressed.py	2007-12-18 19:54:33 UTC (rev 3685)
@@ -333,11 +333,7 @@
             fn( M, N, self.indptr, self.indices, self.data, \
                       other.indptr, other.indices, other.data, \
                       indptr, indices, data)
-
-            nnz = indptr[-1] #may have changed
-            #indices = indices[:nnz]
-            #data    = indices[:nnz]
-
+            
             return self.__class__((data,indices,indptr),shape=(M,N))
 
 
@@ -349,17 +345,40 @@
         else:
             raise TypeError, "need a dense or sparse matrix"
 
-    def matvec(self, other):
+    def matvec(self, other, output=None):
+        """Sparse matrix vector product (self * other)
+
+        'other' may be a rank 1 array of length N or a rank 2 array 
+        or matrix with shape (N,1).  
+        
+        If the optional 'output' parameter is defined, it will
+        be used to store the result.  Otherwise, a new vector
+        will be allocated.
+             
+        """
         if isdense(other):
-            if other.size != self.shape[1] or \
-                    (other.ndim == 2 and self.shape[1] != other.shape[0]):
+            M,N = self.shape
+
+            if other.shape != (N,) and other.shape != (N,1):
                 raise ValueError, "dimension mismatch"
 
             # csrmux, cscmux
             fn = getattr(sparsetools,self.format + '_matvec')
     
             #output array
-            y = empty( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
+            if output is None:
+                y = empty( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
+            else:
+                if output.shape != (M,) and output.shape != (M,1):
+                    print "self ",self.shape,"other",other.shape
+                    raise ValueError, "output array has improper dimensions"
+                if not output.flags.c_contiguous:
+                    raise ValueError, "output array must be contiguous"
+                if output.dtype != upcast(self.dtype,other.dtype):
+                    raise ValueError, "output array has dtype=%s "\
+                            "dtype=%s is required" % \
+                            (output.dtype,upcast(self.dtype,other.dtype))
+                y = output
 
             fn(self.shape[0], self.shape[1], \
                 self.indptr, self.indices, self.data, numpy.ravel(other), y)

Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py	2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/tests/test_base.py	2007-12-18 19:54:33 UTC (rev 3685)
@@ -428,7 +428,39 @@
         assert_array_equal(self.dat/17.3,a.todense())
 
 
+class _TestMatvecOutput:
+    """test using the matvec() output parameter"""
+    def check_matvec_output(self): 
+        #flat array
+        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+        y = zeros(3,dtype='d')
+        
+        self.datsp.matvec(x,y)
+        assert_array_equal(self.datsp*x,y)
+    
+        #column vector
+        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+        x = x.reshape(4,1)
+        y = zeros((3,1),dtype='d')
 
+        self.datsp.matvec(x,y)
+        assert_array_equal(self.datsp*x,y)
+   
+        # improper output type
+        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+        y = zeros(3,dtype='i')
+        
+        self.assertRaises( ValueError, self.datsp.matvec, x, y )
+        
+        # proper upcast output type
+        x = array([1.25, -6.5, 0.125, -3.75],dtype='complex64')
+        x.imag = [1,2,3,4]
+        y = zeros(3,dtype='complex128')
+       
+        self.datsp.matvec(x,y)
+        assert_array_equal(self.datsp*x,y)
+        assert_equal((self.datsp*x).dtype,y.dtype)
+
 class _TestGetSet:
     def check_setelement(self):
         a = self.datsp - self.datsp
@@ -674,7 +706,7 @@
 
 
 class TestCSR(_TestCommon, _TestGetSet, _TestSolve,
-        _TestInplaceArithmetic, _TestArithmetic,  
+        _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
         _TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
         NumpyTestCase):
     spmatrix = csr_matrix
@@ -771,7 +803,7 @@
         assert_equal( ab, aa[i0,i1[0]:i1[1]] )
 
 class TestCSC(_TestCommon, _TestGetSet, _TestSolve,
-        _TestInplaceArithmetic, _TestArithmetic,  
+        _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
         _TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
         NumpyTestCase):
     spmatrix = csc_matrix

Modified: trunk/scipy/sparse/tests/test_sparse.py
===================================================================
--- trunk/scipy/sparse/tests/test_sparse.py	2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/tests/test_sparse.py	2007-12-18 19:54:33 UTC (rev 3685)
@@ -59,7 +59,11 @@
             start = time.clock()
             iter = 0
             while iter < 5 or time.clock() < start + 1:
-                y = A*x
+                try:
+                    #avoid creating y if possible
+                    A.matvec(x,y)
+                except:
+                    y = A*x
                 iter += 1
             end = time.clock()
 
@@ -91,7 +95,7 @@
                 start = time.clock()
                 
                 iter = 0
-                while time.clock() < start + 0.1:
+                while time.clock() < start + 0.5:
                     T = eval(format + '_matrix')(A.shape)
                     for i,j,v in zip(A.row,A.col,A.data):
                         T[i,j] = v




More information about the Scipy-svn mailing list