[Scipy-svn] r3215 - in trunk/Lib/sparse: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Aug 1 09:41:55 EDT 2007
Author: stefan
Date: 2007-08-01 08:41:33 -0500 (Wed, 01 Aug 2007)
New Revision: 3215
Modified:
trunk/Lib/sparse/sparse.py
trunk/Lib/sparse/tests/test_sparse.py
Log:
Add scalar addition and point-wise multiplication to lil_matrix.
Modified: trunk/Lib/sparse/sparse.py
===================================================================
--- trunk/Lib/sparse/sparse.py 2007-08-01 13:13:03 UTC (rev 3214)
+++ trunk/Lib/sparse/sparse.py 2007-08-01 13:41:33 UTC (rev 3215)
@@ -10,7 +10,6 @@
'spdiags','speye','spidentity',
'isspmatrix','issparse','isspmatrix_csc','isspmatrix_csr',
'isspmatrix_lil','isspmatrix_dok' ]
-
import warnings
@@ -2413,13 +2412,53 @@
else:
return self.dot(other)
+ def multiply(self, other):
+ """Point-wise multiplication by another lil_matrix.
+
+ """
+ if isscalar(other):
+ return self.__mul__(other)
+
+ if isspmatrix_lil(other):
+ reference,target = self,other
+
+ if reference.shape != target.shape:
+ raise ValueError("Dimensions do not match.")
+
+ if len(reference.data) > len(target.data):
+ reference,target = target,reference
+
+ new = lil_matrix(reference.shape)
+ for r,row in enumerate(reference.rows):
+ tr = target.rows[r]
+ td = target.data[r]
+ rd = reference.data[r]
+ L = len(tr)
+ for c,column in enumerate(row):
+ ix = bisect_left(tr,column)
+ if ix < L and tr[ix] == column:
+ new.rows[r].append(column)
+ new.data[r].append(rd[c] * td[ix])
+ return new
+ else:
+ raise ValueError("Point-wise multiplication only allowed "
+ "with another lil_matrix.")
+
def copy(self):
new = lil_matrix(self.shape, dtype=self.dtype)
new.data = copy.deepcopy(self.data)
new.rows = copy.deepcopy(self.rows)
return new
-
-
+
+ def __add__(self, other):
+ if isscalar(other):
+ new = self.copy()
+ new.data = numpy.array([[val+other for val in rowvals] for
+ rowvals in new.data], dtype=object)
+ return new
+ else:
+ return spmatrix.__add__(self, other)
+
def __rmul__(self, other): # other * self
if isscalarlike(other):
# Multiplication by a scalar is symmetric
Modified: trunk/Lib/sparse/tests/test_sparse.py
===================================================================
--- trunk/Lib/sparse/tests/test_sparse.py 2007-08-01 13:13:03 UTC (rev 3214)
+++ trunk/Lib/sparse/tests/test_sparse.py 2007-08-01 13:41:33 UTC (rev 3215)
@@ -779,7 +779,7 @@
class test_lil(_test_cs, _test_horiz_slicing, NumpyTestCase):
spmatrix = lil_matrix
- def check_mult(self):
+ def check_dot(self):
A = matrix(zeros((10,10)))
A[0,3] = 10
A[5,6] = 20
@@ -829,7 +829,43 @@
D = lil_matrix(C)
assert_array_equal(C.A, D.A)
+ def check_scalar_add(self):
+ a = lil_matrix((3,3))
+ a[0,0] = 1
+ a[0,1] = 2
+ a[1,1] = 3
+ a[2,1] = 4
+ a[2,2] = 5
+ assert_array_equal((a-5).todense(),
+ [[-4,-3,0],
+ [ 0,-2,0],
+ [ 0,-1,0]])
+
+ def check_point_wise_multiply(self):
+ l = lil_matrix((4,3))
+ l[0,0] = 1
+ l[1,1] = 2
+ l[2,2] = 3
+ l[3,1] = 4
+
+ m = lil_matrix((4,3))
+ m[0,0] = 1
+ m[0,1] = 2
+ m[2,2] = 3
+ m[3,1] = 4
+ m[3,2] = 4
+
+ assert_array_equal(l.multiply(m).todense(),
+ m.multiply(l).todense())
+
+ assert_array_equal(l.multiply(m).todense(),
+ [[1,0,0],
+ [0,0,0],
+ [0,0,9],
+ [0,16,0]])
+
+
class test_construct_utils(NumpyTestCase):
def check_identity(self):
a = spidentity(3)
More information about the Scipy-svn mailing list