[Scipy-svn] r5110 - in trunk/scipy/sparse: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Nov 14 22:45:27 EST 2008


Author: wnbell
Date: 2008-11-14 21:45:23 -0600 (Fri, 14 Nov 2008)
New Revision: 5110

Modified:
   trunk/scipy/sparse/compressed.py
   trunk/scipy/sparse/dok.py
   trunk/scipy/sparse/lil.py
   trunk/scipy/sparse/tests/test_base.py
Log:
Added more rigorous bounds checking to __getitem__ and __setitem__.


Modified: trunk/scipy/sparse/compressed.py
===================================================================
--- trunk/scipy/sparse/compressed.py	2008-11-14 18:42:37 UTC (rev 5109)
+++ trunk/scipy/sparse/compressed.py	2008-11-15 03:45:23 UTC (rev 5110)
@@ -371,7 +371,7 @@
         if (col < 0):
             col += N
         if not (0<=row<M) or not (0<=col<N):
-            raise IndexError, "index out of bounds"
+            raise IndexError("index out of bounds")
 
         major_index, minor_index = self._swap((row,col))
 
@@ -387,7 +387,7 @@
         elif num_matches == 1:
             return self.data[start:end][indxs[0]]
         else:
-            raise ValueError,'nonzero entry (%d,%d) occurs more than once' % (row,col)
+            raise ValueError('nonzero entry (%d,%d) occurs more than once' % (row,col))
 
     def _get_slice(self, i, start, stop, stride, shape):
         """Returns a copy of the elements
@@ -493,15 +493,18 @@
                         'lil_matrix is more efficient.' % self.format, \
                         SparseEfficiencyWarning)
 
-                self.sort_indices()
+                if self.has_sorted_indices:
+                    # preserve sorted order
+                    newindx = start + self.indices[start:end].searchsorted(minor_index)
+                else:
+                    newindx = start
 
-                newindx = self.indices[start:end].searchsorted(minor_index)
-                newindx += start
+                val         = np.array([val],         dtype=self.data.dtype)
+                minor_index = np.array([minor_index], dtype=self.indices.dtype)
 
-                val = np.array([val],dtype=self.data.dtype)
-                minor_index = np.array([minor_index],dtype=self.indices.dtype)
-                self.data    = np.concatenate((self.data[:newindx],val,self.data[newindx:]))
-                self.indices = np.concatenate((self.indices[:newindx],minor_index,self.indices[newindx:]))
+                self.data    = np.concatenate((self.data[:newindx],    val,         self.data[newindx:]))
+                self.indices = np.concatenate((self.indices[:newindx], minor_index, self.indices[newindx:]))
+                self.indptr  = self.indptr.copy()
 
                 self.indptr[major_index+1:] += 1
 

Modified: trunk/scipy/sparse/dok.py
===================================================================
--- trunk/scipy/sparse/dok.py	2008-11-14 18:42:37 UTC (rev 5109)
+++ trunk/scipy/sparse/dok.py	2008-11-15 03:45:23 UTC (rev 5110)
@@ -95,12 +95,11 @@
             i, j = key
             assert isintlike(i) and isintlike(j)
         except (AssertionError, TypeError, ValueError):
-            raise IndexError, "index must be a pair of integers"
+            raise IndexError('index must be a pair of integers')
         try:
-            assert not (i < 0 or i >= self.shape[0] or j < 0 or
-                     j >= self.shape[1])
+            assert not (i < 0 or i >= self.shape[0] or j < 0 or j >= self.shape[1])
         except AssertionError:
-            raise IndexError, "index out of bounds"
+            raise IndexError('index out of bounds')
         return dict.get(self, key, default)
 
     def  __getitem__(self, key):
@@ -111,7 +110,7 @@
         try:
             i, j = key
         except (ValueError, TypeError):
-            raise TypeError, "index must be a pair of integers or slices"
+            raise TypeError('index must be a pair of integers or slices')
 
 
         # Bounds checking
@@ -119,16 +118,17 @@
             if i < 0:
                 i += self.shape[0]
             if i < 0 or i >= self.shape[0]:
-                raise IndexError, "index out of bounds"
+                raise IndexError('index out of bounds')
+
         if isintlike(j):
             if j < 0:
                 j += self.shape[1]
             if j < 0 or j >= self.shape[1]:
-                raise IndexError, "index out of bounds"
+                raise IndexError('index out of bounds')
 
         # First deal with the case where both i and j are integers
         if isintlike(i) and isintlike(j):
-            return dict.get(self, key, 0.)
+            return dict.get(self, (i,j), 0.)
         else:
             # Either i or j is a slice, sequence, or invalid.  If i is a slice
             # or sequence, unfold it first and call __getitem__ recursively.
@@ -141,7 +141,7 @@
             else:
                 # Make sure i is an integer. (But allow it to be a subclass of int).
                 if not isintlike(i):
-                    raise TypeError, "index must be a pair of integers or slices"
+                    raise TypeError('index must be a pair of integers or slices')
                 seq = None
             if seq is not None:
                 # i is a seq
@@ -151,7 +151,7 @@
                     last = seq[-1]
                     if first < 0 or first >= self.shape[0] or last < 0 \
                                  or last >= self.shape[0]:
-                        raise IndexError, "index out of bounds"
+                        raise IndexError('index out of bounds')
                     newshape = (last-first+1, 1)
                     new = dok_matrix(newshape)
                     # ** This uses linear time in the size m of dimension 0:

Modified: trunk/scipy/sparse/lil.py
===================================================================
--- trunk/scipy/sparse/lil.py	2008-11-14 18:42:37 UTC (rev 5109)
+++ trunk/scipy/sparse/lil.py	2008-11-15 03:45:23 UTC (rev 5110)
@@ -155,14 +155,19 @@
         return new
 
     def _get1(self, i, j):
-        row = self.rows[i]
-        data = self.data[i]
+        
+        if i < 0:
+            i += self.shape[0]
+        if i < 0 or i >= self.shape[0]:
+            raise IndexError('row index out of bounds')
 
         if j < 0:
             j += self.shape[1]
-
-        if j < 0 or j > self.shape[1]:
+        if j < 0 or j >= self.shape[1]:
             raise IndexError('column index out of bounds')
+        
+        row  = self.rows[i]
+        data = self.data[i]
 
         pos = bisect_left(row, j)
         if pos != len(data) and row[pos] == j:

Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py	2008-11-14 18:42:37 UTC (rev 5109)
+++ trunk/scipy/sparse/tests/test_base.py	2008-11-15 03:45:23 UTC (rev 5110)
@@ -573,20 +573,34 @@
 
 class _TestGetSet:
     def test_setelement(self):
-        a = self.spmatrix((3,4))
-        a[1,2] = 4.0
-        a[0,1] = 3
-        a[2,0] = 2.0
-        a[0,-1] = 8
-        a[-1,-2] = 7
-        assert_array_equal(a.todense(),[[0,3,0,8],[0,0,4,0],[2,0,7,0]])
+        A = self.spmatrix((3,4))
+        A[ 1, 2] = 4.0
+        A[ 0, 1] = 3
+        A[ 2, 0] = 2.0
+        A[ 0,-1] = 8
+        A[-1,-2] = 7
+        A[ 0, 1] = 5
+        assert_array_equal(A.todense(),[[0,5,0,8],[0,0,4,0],[2,0,7,0]])
+        
+        for ij in [(0,4),(-1,4),(3,0),(3,4),(3,-1)]:
+            assert_raises(IndexError, A.__setitem__, ij, 123.0)
 
     def test_getelement(self):
-        assert_equal(self.datsp[0,0],1)
-        assert_equal(self.datsp[0,1],0)
-        assert_equal(self.datsp[1,0],3)
-        assert_equal(self.datsp[2,1],2)
+        D = array([[1,0,0],
+                   [4,3,0],
+                   [0,2,0],
+                   [0,0,0]])
+        A = self.spmatrix(D)
 
+        M,N = D.shape
+
+        for i in range(-M, M):
+            for j in range(-N, N):
+                assert_equal(A[i,j], D[i,j])
+         
+        for ij in [(0,3),(-1,3),(4,0),(4,3),(4,-1)]:
+            assert_raises(IndexError, A.__getitem__, ij)
+
 class _TestSolve:
     def test_solve(self):
         """ Test whether the lu_solve command segfaults, as reported by Nils




More information about the Scipy-svn mailing list