[Scipy-svn] r4513 - in trunk/scipy/sparse: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Jul 1 03:41:41 EDT 2008


Author: wnbell
Date: 2008-07-01 02:40:59 -0500 (Tue, 01 Jul 2008)
New Revision: 4513

Modified:
   trunk/scipy/sparse/base.py
   trunk/scipy/sparse/tests/test_base.py
Log:
added .nonzero() to sparse matrices
resolves ticket #695


Modified: trunk/scipy/sparse/base.py
===================================================================
--- trunk/scipy/sparse/base.py	2008-07-01 04:58:48 UTC (rev 4512)
+++ trunk/scipy/sparse/base.py	2008-07-01 07:40:59 UTC (rev 4513)
@@ -335,6 +335,29 @@
     def _imag(self):
         return self.tocsr()._imag()
 
+
+    def nonzero(self):
+        """Return the (row, col) indices of the non-zero elements.
+
+        Returns a tuple of index arrays (row, col). Stored entries whose
+        value compares equal to zero (explicit zeros) are left out.
+
+        Examples
+        --------
+
+        >>> from scipy.sparse import csr_matrix
+        >>> A = csr_matrix([[1,2,0],[0,0,3],[4,0,5]])
+        >>> A.nonzero()
+        (array([0, 0, 1, 2, 2]), array([0, 1, 2, 0, 2]))
+
+        """
+
+        # convert to COO format; duplicate entries (if any) are not summed here
+        A = self.tocoo()
+        nz_mask = A.data != 0 
+        return (A.row[nz_mask],A.col[nz_mask])
+
+
     def getcol(self, j):
         """Returns a copy of column j of the matrix, as an (m x 1) sparse
         matrix (column vector).

Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py	2008-07-01 04:58:48 UTC (rev 4512)
+++ trunk/scipy/sparse/tests/test_base.py	2008-07-01 07:40:59 UTC (rev 4513)
@@ -105,6 +105,16 @@
             assert_equal(self.spmatrix(m).diagonal(),diag(m))
 
 
+    def test_nonzero(self):
+        A   = array([[1, 0, 1],[0, 1, 1],[ 0, 0, 1]])
+        Asp = self.spmatrix(A)
+        # compare as sets of (row, col) pairs so index ordering is ignored
+        A_nz   = set( [tuple(ij) for ij in transpose(A.nonzero())] )
+        Asp_nz = set( [tuple(ij) for ij in transpose(Asp.nonzero())] )
+
+        assert_equal(A_nz, Asp_nz)
+
+
     def test_getrow(self):
         assert_array_equal(self.datsp.getrow(1).todense(), self.dat[1,:])
     




More information about the Scipy-svn mailing list