[Scipy-svn] r3685 - in trunk/scipy/sparse: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Tue Dec 18 14:54:40 EST 2007
Author: wnbell
Date: 2007-12-18 13:54:33 -0600 (Tue, 18 Dec 2007)
New Revision: 3685
Modified:
trunk/scipy/sparse/base.py
trunk/scipy/sparse/compressed.py
trunk/scipy/sparse/tests/test_base.py
trunk/scipy/sparse/tests/test_sparse.py
Log:
allow matvec() to use preallocated storage
added associated unittests
Modified: trunk/scipy/sparse/base.py
===================================================================
--- trunk/scipy/sparse/base.py 2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/base.py 2007-12-18 19:54:33 UTC (rev 3685)
@@ -4,7 +4,7 @@
from warnings import warn
-from numpy import asarray, asmatrix, ones
+from numpy import asarray, asmatrix, asanyarray, ones
from sputils import isdense, isscalarlike
@@ -321,7 +321,7 @@
other.shape
except AttributeError:
# If it's a list or whatever, treat it like a matrix
- other = asmatrix(other)
+ other = asanyarray(other)
if isdense(other) and asarray(other).squeeze().ndim <= 1:
# it's a dense row or column vector
Modified: trunk/scipy/sparse/compressed.py
===================================================================
--- trunk/scipy/sparse/compressed.py 2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/compressed.py 2007-12-18 19:54:33 UTC (rev 3685)
@@ -333,11 +333,7 @@
fn( M, N, self.indptr, self.indices, self.data, \
other.indptr, other.indices, other.data, \
indptr, indices, data)
-
- nnz = indptr[-1] #may have changed
- #indices = indices[:nnz]
- #data = indices[:nnz]
-
+
return self.__class__((data,indices,indptr),shape=(M,N))
@@ -349,17 +345,40 @@
else:
raise TypeError, "need a dense or sparse matrix"
- def matvec(self, other):
+ def matvec(self, other, output=None):
+ """Sparse matrix vector product (self * other)
+
+ 'other' may be a rank 1 array of length N or a rank 2 array
+ or matrix with shape (N,1).
+
+ If the optional 'output' parameter is defined, it will
+ be used to store the result. Otherwise, a new vector
+ will be allocated.
+
+ """
if isdense(other):
- if other.size != self.shape[1] or \
- (other.ndim == 2 and self.shape[1] != other.shape[0]):
+ M,N = self.shape
+
+ if other.shape != (N,) and other.shape != (N,1):
raise ValueError, "dimension mismatch"
# csrmux, cscmux
fn = getattr(sparsetools,self.format + '_matvec')
#output array
- y = empty( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
+ if output is None:
+ y = empty( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
+ else:
+ if output.shape != (M,) and output.shape != (M,1):
+ print "self ",self.shape,"other",other.shape
+ raise ValueError, "output array has improper dimensions"
+ if not output.flags.c_contiguous:
+ raise ValueError, "output array must be contiguous"
+ if output.dtype != upcast(self.dtype,other.dtype):
+ raise ValueError, "output array has dtype=%s "\
+ "dtype=%s is required" % \
+ (output.dtype,upcast(self.dtype,other.dtype))
+ y = output
fn(self.shape[0], self.shape[1], \
self.indptr, self.indices, self.data, numpy.ravel(other), y)
Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py 2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/tests/test_base.py 2007-12-18 19:54:33 UTC (rev 3685)
@@ -428,7 +428,39 @@
assert_array_equal(self.dat/17.3,a.todense())
+class _TestMatvecOutput:
+ """test using the matvec() output parameter"""
+ def check_matvec_output(self):
+ #flat array
+ x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+ y = zeros(3,dtype='d')
+
+ self.datsp.matvec(x,y)
+ assert_array_equal(self.datsp*x,y)
+
+ #column vector
+ x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+ x = x.reshape(4,1)
+ y = zeros((3,1),dtype='d')
+ self.datsp.matvec(x,y)
+ assert_array_equal(self.datsp*x,y)
+
+ # improper output type
+ x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
+ y = zeros(3,dtype='i')
+
+ self.assertRaises( ValueError, self.datsp.matvec, x, y )
+
+ # proper upcast output type
+ x = array([1.25, -6.5, 0.125, -3.75],dtype='complex64')
+ x.imag = [1,2,3,4]
+ y = zeros(3,dtype='complex128')
+
+ self.datsp.matvec(x,y)
+ assert_array_equal(self.datsp*x,y)
+ assert_equal((self.datsp*x).dtype,y.dtype)
+
class _TestGetSet:
def check_setelement(self):
a = self.datsp - self.datsp
@@ -674,7 +706,7 @@
class TestCSR(_TestCommon, _TestGetSet, _TestSolve,
- _TestInplaceArithmetic, _TestArithmetic,
+ _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
_TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
NumpyTestCase):
spmatrix = csr_matrix
@@ -771,7 +803,7 @@
assert_equal( ab, aa[i0,i1[0]:i1[1]] )
class TestCSC(_TestCommon, _TestGetSet, _TestSolve,
- _TestInplaceArithmetic, _TestArithmetic,
+ _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
_TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
NumpyTestCase):
spmatrix = csc_matrix
Modified: trunk/scipy/sparse/tests/test_sparse.py
===================================================================
--- trunk/scipy/sparse/tests/test_sparse.py 2007-12-18 08:43:14 UTC (rev 3684)
+++ trunk/scipy/sparse/tests/test_sparse.py 2007-12-18 19:54:33 UTC (rev 3685)
@@ -59,7 +59,11 @@
start = time.clock()
iter = 0
while iter < 5 or time.clock() < start + 1:
- y = A*x
+ try:
+ #avoid creating y if possible
+ A.matvec(x,y)
+ except:
+ y = A*x
iter += 1
end = time.clock()
@@ -91,7 +95,7 @@
start = time.clock()
iter = 0
- while time.clock() < start + 0.1:
+ while time.clock() < start + 0.5:
T = eval(format + '_matrix')(A.shape)
for i,j,v in zip(A.row,A.col,A.data):
T[i,j] = v
More information about the Scipy-svn mailing list