[Scipy-svn] r3835 - trunk/scipy/sparse

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Jan 15 06:51:12 EST 2008


Author: wnbell
Date: 2008-01-15 05:51:10 -0600 (Tue, 15 Jan 2008)
New Revision: 3835

Modified:
   trunk/scipy/sparse/bsr.py
Log:
Use the CSR matrix-matrix product (csr_matmat_pass2) for BSR products when all block sizes are 1x1


Modified: trunk/scipy/sparse/bsr.py
===================================================================
--- trunk/scipy/sparse/bsr.py	2008-01-14 21:31:13 UTC (rev 3834)
+++ trunk/scipy/sparse/bsr.py	2008-01-15 11:51:10 UTC (rev 3835)
@@ -7,13 +7,14 @@
 from numpy import zeros, intc, array, asarray, arange, diff, tile, rank, \
         prod, ravel, empty, matrix, asmatrix, empty_like, hstack
 
-import sparsetools
-from sparsetools import bsr_matvec, csr_matmat_pass1, bsr_matmat_pass2
 from data import _data_matrix
 from compressed import _cs_matrix
 from base import isspmatrix, _formats
 from sputils import isshape, getdtype, to_native, isscalarlike, isdense, \
         upcast
+import sparsetools
+from sparsetools import bsr_matvec, csr_matmat_pass1, csr_matmat_pass2, \
+        bsr_matmat_pass2
 
 class bsr_matrix(_cs_matrix):
     """Block Sparse Row matrix
@@ -329,13 +330,21 @@
             indptr = empty_like( self.indptr )
             
             R,n = self.blocksize
-
+            
+            #convert to this format
             if isspmatrix_bsr(other):
                 C = other.blocksize[1]
             else:
                 C = 1
-            other = other.tobsr(blocksize=(n,C)) #convert to this format
 
+            from csr import isspmatrix_csr
+
+            if isspmatrix_csr(other) and n == 1:
+                other = other.tobsr(blocksize=(n,C),copy=False) #convert to this format
+            else:
+                other = other.tobsr(blocksize=(n,C))
+
+
             csr_matmat_pass1( M/R, N/C, \
                     self.indptr,  self.indices, \
                     other.indptr, other.indices, \
@@ -345,10 +354,17 @@
             indices = empty( bnnz, dtype=intc)
             data    = empty( R*C*bnnz, dtype=upcast(self.dtype,other.dtype))
 
-            bsr_matmat_pass2( M/R, N/C, R, C, n, \
-                    self.indptr,  self.indices,  ravel(self.data), \
-                    other.indptr, other.indices, ravel(other.data), \
-                    indptr,       indices,       data)
+            if (R,C,n) == (1,1,1):
+                #use CSR * CSR when possible
+                csr_matmat_pass2( M, N, \
+                        self.indptr,  self.indices,  ravel(self.data), \
+                        other.indptr, other.indices, ravel(other.data), \
+                        indptr,       indices,       data)
+            else:
+                bsr_matmat_pass2( M/R, N/C, R, C, n, \
+                        self.indptr,  self.indices,  ravel(self.data), \
+                        other.indptr, other.indices, ravel(other.data), \
+                        indptr,       indices,       data)
             
             data = data.reshape(-1,R,C)
             #TODO eliminate zeros




More information about the Scipy-svn mailing list