[Scipy-svn] r3215 - in trunk/Lib/sparse: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Wed Aug 1 09:41:55 EDT 2007


Author: stefan
Date: 2007-08-01 08:41:33 -0500 (Wed, 01 Aug 2007)
New Revision: 3215

Modified:
   trunk/Lib/sparse/sparse.py
   trunk/Lib/sparse/tests/test_sparse.py
Log:
Add scalar addition and point-wise multiplication to lil_matrix.


Modified: trunk/Lib/sparse/sparse.py
===================================================================
--- trunk/Lib/sparse/sparse.py	2007-08-01 13:13:03 UTC (rev 3214)
+++ trunk/Lib/sparse/sparse.py	2007-08-01 13:41:33 UTC (rev 3215)
@@ -10,7 +10,6 @@
             'spdiags','speye','spidentity', 
             'isspmatrix','issparse','isspmatrix_csc','isspmatrix_csr',
             'isspmatrix_lil','isspmatrix_dok' ]
-                        
 
 import warnings
 
@@ -2413,13 +2412,87 @@
         else:
             return self.dot(other)
 
+    def multiply(self, other):
+        """Point-wise (element-wise) multiplication.
+
+        Parameters
+        ----------
+        other : scalar or lil_matrix
+            A scalar is handled by ``self.__mul__``.  A lil_matrix must
+            have the same shape as ``self``.
+
+        Returns
+        -------
+        lil_matrix
+            Only positions stored in both operands can be non-zero, so
+            each output row holds the products at the columns the two
+            input rows have in common.
+
+        Raises
+        ------
+        ValueError
+            If the shapes differ, or `other` is neither a scalar nor a
+            lil_matrix.
+
+        """
+        if isscalarlike(other):
+            return self.__mul__(other)
+
+        if not isspmatrix_lil(other):
+            raise ValueError("Point-wise multiplication only allowed "
+                             "with another lil_matrix.")
+
+        if self.shape != other.shape:
+            raise ValueError("Dimensions do not match.")
+
+        # Use the dtype numpy gives to the product of the two input
+        # dtypes (e.g. float * complex -> complex), so the result's
+        # declared dtype can represent every stored product.
+        dtype = (numpy.zeros(1, dtype=self.dtype) *
+                 numpy.zeros(1, dtype=other.dtype)).dtype
+        new = lil_matrix(self.shape, dtype=dtype)
+
+        for r in range(len(self.rows)):
+            arow, adata = self.rows[r], self.data[r]
+            brow, bdata = other.rows[r], other.data[r]
+            if len(arow) > len(brow):
+                # Walk the shorter row and binary-search the longer one.
+                arow, adata, brow, bdata = brow, bdata, arow, adata
+            nb = len(brow)
+            newrow, newdata = new.rows[r], new.data[r]
+            for c, column in enumerate(arow):
+                # bisect_left requires each row's column list to be sorted.
+                ix = bisect_left(brow, column)
+                if ix < nb and brow[ix] == column:
+                    newrow.append(column)
+                    newdata.append(adata[c] * bdata[ix])
+        return new
+
     def copy(self):
         new = lil_matrix(self.shape, dtype=self.dtype)
         new.data = copy.deepcopy(self.data)
         new.rows = copy.deepcopy(self.rows)
         return new
-    
-    
+
+    def __add__(self, other):
+        """Return ``self + other``.
+
+        A scalar is added to every *stored* entry only; implicit zeros stay
+        zero, so the result keeps the sparsity pattern of ``self``.  Any
+        other operand is handled by spmatrix.__add__.
+
+        """
+        if isscalarlike(other):
+            new = self.copy()
+            # Rebind each element of the copied object array to a new list.
+            # Building the array with numpy.array(nested_list) instead
+            # would give a 2-d array (ndarray rows, no .append) whenever
+            # every row stores the same number of entries.
+            for r in range(len(new.data)):
+                new.data[r] = [val + other for val in new.data[r]]
+            return new
+        return spmatrix.__add__(self, other)
+
     def __rmul__(self, other):          # other * self
         if isscalarlike(other):
             # Multiplication by a scalar is symmetric

Modified: trunk/Lib/sparse/tests/test_sparse.py
===================================================================
--- trunk/Lib/sparse/tests/test_sparse.py	2007-08-01 13:13:03 UTC (rev 3214)
+++ trunk/Lib/sparse/tests/test_sparse.py	2007-08-01 13:41:33 UTC (rev 3215)
@@ -779,7 +779,7 @@
 
 class test_lil(_test_cs, _test_horiz_slicing, NumpyTestCase):
     spmatrix = lil_matrix
-    def check_mult(self):
+    def check_dot(self):
         A = matrix(zeros((10,10)))
         A[0,3] = 10
         A[5,6] = 20
@@ -829,7 +829,61 @@
         D = lil_matrix(C)
         assert_array_equal(C.A, D.A)
 
+    def check_scalar_add(self):
+        a = lil_matrix((3,3))
+        a[0,0] = 1
+        a[0,1] = 2
+        a[1,1] = 3
+        a[2,1] = 4
+        a[2,2] = 5
 
+        # Adding or subtracting a scalar only shifts the stored entries;
+        # implicit zeros stay zero.
+        assert_array_equal((a+5).todense(),
+                           [[6,7, 0],
+                            [0,8, 0],
+                            [0,9,10]])
+        assert_array_equal((a-5).todense(),
+                           [[-4,-3,0],
+                            [ 0,-2,0],
+                            [ 0,-1,0]])
+
+        # `a` itself must not be modified by either operation.
+        assert_array_equal(a.todense(),
+                           [[1,2,0],
+                            [0,3,0],
+                            [0,4,5]])
+
+    def check_point_wise_multiply(self):
+        l = lil_matrix((4,3))
+        l[0,0] = 1
+        l[1,1] = 2
+        l[2,2] = 3
+        l[3,1] = 4
+
+        m = lil_matrix((4,3))
+        m[0,0] = 1
+        m[0,1] = 2
+        m[2,2] = 3
+        m[3,1] = 4
+        m[3,2] = 4
+
+        assert_array_equal(l.multiply(m).todense(),
+                           m.multiply(l).todense())
+
+        assert_array_equal(l.multiply(m).todense(),
+                           [[1,0,0],
+                            [0,0,0],
+                            [0,0,9],
+                            [0,16,0]])
+
+        # Operands of different shapes are rejected.
+        self.assertRaises(ValueError, l.multiply, lil_matrix((3,4)))
+
+        # Only scalars and lil_matrix operands are supported.
+        self.assertRaises(ValueError, l.multiply, l.todense())
+
+
 class test_construct_utils(NumpyTestCase):
     def check_identity(self):
         a = spidentity(3)




More information about the Scipy-svn mailing list