[Scipy-svn] r3835 - trunk/scipy/sparse
scipy-svn at scipy.org
scipy-svn at scipy.org
Tue Jan 15 06:51:12 EST 2008
Author: wnbell
Date: 2008-01-15 05:51:10 -0600 (Tue, 15 Jan 2008)
New Revision: 3835
Modified:
trunk/scipy/sparse/bsr.py
Log:
use CSR matmat for 1x1 BSR
Modified: trunk/scipy/sparse/bsr.py
===================================================================
--- trunk/scipy/sparse/bsr.py 2008-01-14 21:31:13 UTC (rev 3834)
+++ trunk/scipy/sparse/bsr.py 2008-01-15 11:51:10 UTC (rev 3835)
@@ -7,13 +7,14 @@
from numpy import zeros, intc, array, asarray, arange, diff, tile, rank, \
prod, ravel, empty, matrix, asmatrix, empty_like, hstack
-import sparsetools
-from sparsetools import bsr_matvec, csr_matmat_pass1, bsr_matmat_pass2
from data import _data_matrix
from compressed import _cs_matrix
from base import isspmatrix, _formats
from sputils import isshape, getdtype, to_native, isscalarlike, isdense, \
upcast
+import sparsetools
+from sparsetools import bsr_matvec, csr_matmat_pass1, csr_matmat_pass2, \
+ bsr_matmat_pass2
class bsr_matrix(_cs_matrix):
"""Block Sparse Row matrix
@@ -329,13 +330,21 @@
indptr = empty_like( self.indptr )
R,n = self.blocksize
-
+
+ #convert to this format
if isspmatrix_bsr(other):
C = other.blocksize[1]
else:
C = 1
- other = other.tobsr(blocksize=(n,C)) #convert to this format
+ from csr import isspmatrix_csr
+
+ if isspmatrix_csr(other) and n == 1:
+ other = other.tobsr(blocksize=(n,C),copy=False) #convert to this format
+ else:
+ other = other.tobsr(blocksize=(n,C))
+
+
csr_matmat_pass1( M/R, N/C, \
self.indptr, self.indices, \
other.indptr, other.indices, \
@@ -345,10 +354,17 @@
indices = empty( bnnz, dtype=intc)
data = empty( R*C*bnnz, dtype=upcast(self.dtype,other.dtype))
- bsr_matmat_pass2( M/R, N/C, R, C, n, \
- self.indptr, self.indices, ravel(self.data), \
- other.indptr, other.indices, ravel(other.data), \
- indptr, indices, data)
+ if (R,C,n) == (1,1,1):
+ #use CSR * CSR when possible
+ csr_matmat_pass2( M, N, \
+ self.indptr, self.indices, ravel(self.data), \
+ other.indptr, other.indices, ravel(other.data), \
+ indptr, indices, data)
+ else:
+ bsr_matmat_pass2( M/R, N/C, R, C, n, \
+ self.indptr, self.indices, ravel(self.data), \
+ other.indptr, other.indices, ravel(other.data), \
+ indptr, indices, data)
data = data.reshape(-1,R,C)
#TODO eliminate zeros
More information about the Scipy-svn mailing list