[Scipy-svn] r3485 - trunk/scipy/sparse

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Nov 2 00:18:07 EDT 2007


Author: wnbell
Date: 2007-11-01 23:18:04 -0500 (Thu, 01 Nov 2007)
New Revision: 3485

Modified:
   trunk/scipy/sparse/sparse.py
Log:
refactoring of csr_matrix and csc_matrix constructors


Modified: trunk/scipy/sparse/sparse.py
===================================================================
--- trunk/scipy/sparse/sparse.py	2007-11-02 00:23:31 UTC (rev 3484)
+++ trunk/scipy/sparse/sparse.py	2007-11-02 04:18:04 UTC (rev 3485)
@@ -506,6 +506,17 @@
 
 class _cs_matrix(spmatrix):
     """base matrix class for compressed row and column oriented matrices"""
+    def _set_self(self, other, copy=False):
+        if copy:
+            other = other.copy()
+        
+        self.data    = other.data
+        self.indices = other.indices
+        self.indptr  = other.indptr
+        self.shape   = other.shape
+        self.dtype   = other.data.dtype
+          
+
     def _check_format(self, orientation, full_check):
         # some functions pass floats
         self.shape = tuple([int(x) for x in self.shape])
@@ -926,32 +937,12 @@
             else:
                 raise ValueError, "dense array must have rank 1 or 2"
         elif isspmatrix(arg1):
-            s = arg1
-            self.dtype = getdtype(dtype, s)
-            if isinstance(s, csc_matrix):
-                # do nothing but copy information
-                self.shape = s.shape
-                if copy:
-                    self.data = s.data.copy()
-                    self.indices = s.indices.copy()
-                    self.indptr = s.indptr.copy()
-                else:
-                    self.data = s.data
-                    self.indices = s.indices
-                    self.indptr = s.indptr
-            elif isinstance(s, csr_matrix):
-                self.shape = s.shape
-                self.indptr, self.indices, self.data = csrtocsc(s.shape[0],
-                                                               s.shape[1],
-                                                               s.indptr,
-                                                               s.indices,
-                                                               s.data)
+            try:
+                other = arg1.tocsc(copy=copy)
+            except AttributeError:
+                raise AttributeError,'all sparse matrices must have .tocsc()'
             else:
-                temp = s.tocsc()
-                self.data = temp.data
-                self.indices = temp.indices
-                self.indptr = temp.indptr
-                self.shape = temp.shape
+                self._set_self( other )
         elif isinstance(arg1, tuple):
             if isshape(arg1):
                 self.dtype = getdtype(dtype, default=float)
@@ -964,34 +955,30 @@
             else:
                 try:
                     # Try interpreting it as (data, ij)
-                    (s, ij) = arg1
+                    (data, ij) = arg1
                     assert isinstance(ij, ndarray) and (rank(ij) == 2) \
-                            and (shape(ij) == (2, len(s)))
+                            and (shape(ij) == (2, len(data)))
                 except (AssertionError, TypeError, ValueError):
                     try:
-                        # Try interpreting it as (data, rowind, indptr)
-                        (s, rowind, indptr) = arg1
-                        self.dtype = getdtype(dtype, s)
+                        # Try interpreting it as (data, indices, indptr)
+                        (data, indices, indptr) = arg1
+                        self.dtype = getdtype(dtype, data)
                         if copy:
-                            self.data = array(s)
-                            self.indices = array(rowind)
-                            self.indptr = array(indptr)
+                            self.data    = array(data)
+                            self.indices = array(indices)
+                            self.indptr  = array(indptr)
                         else:
-                            self.data = asarray(s)
-                            self.indices = asarray(rowind)
-                            self.indptr = asarray(indptr)
+                            self.data    = asarray(data)
+                            self.indices = asarray(indices)
+                            self.indptr  = asarray(indptr)
                     except:
                         raise ValueError, "unrecognized form for csc_matrix constructor"
                 else:
                     # (data, ij) format
-                    self.dtype = getdtype(dtype, s)
+                    self.dtype = getdtype(dtype, data)
                     ijnew = array(ij, copy=copy)
-                    temp = coo_matrix((s, ijnew), dims=dims, \
-                                      dtype=self.dtype).tocsc()
-                    self.shape = temp.shape
-                    self.data = temp.data
-                    self.indices = temp.indices
-                    self.indptr = temp.indptr
+                    self._set_self( coo_matrix((data, ijnew), dims=dims, \
+                                         dtype=self.dtype).tocsc() )
         else:
             raise ValueError, "unrecognized form for csc_matrix constructor"
 
@@ -1238,6 +1225,7 @@
     """
     def __init__(self, arg1, dims=None, nzmax=NZMAX, dtype=None, copy=False):
         _cs_matrix.__init__(self)
+
         if isdense(arg1):
             self.dtype = getdtype(dtype, arg1)
             # Convert the dense array or matrix arg1 to CSR format
@@ -1251,28 +1239,12 @@
             else:
                 raise ValueError, "dense array must have rank 1 or 2"
         elif isspmatrix(arg1):
-            s = arg1
-            self.dtype = getdtype(dtype, s)
-            if isinstance(s, csr_matrix):
-                # do nothing but copy information
-                self.shape = s.shape
-                if copy:
-                    self.data = s.data.copy()
-                    self.indices = s.indices.copy()
-                    self.indptr = s.indptr.copy()
-                else:
-                    self.data = s.data
-                    self.indices = s.indices
-                    self.indptr = s.indptr
+            try:
+                other = arg1.tocsr(copy=copy)
+            except AttributeError:
+                raise AttributeError,'all sparse matrices must have .tocsr()'
             else:
-                try:
-                    temp = s.tocsr()
-                except AttributeError:
-                    temp = csr_matrix(s.tocsc())
-                self.data = temp.data
-                self.indices = temp.indices
-                self.indptr = temp.indptr
-                self.shape = temp.shape
+                self._set_self( other )
         elif isinstance(arg1, tuple):
             if isshape(arg1):
                 # It's a tuple of matrix dimensions (M, N)
@@ -1285,35 +1257,31 @@
             else:
                 try:
                     # Try interpreting it as (data, ij)
-                    (s, ij) = arg1
+                    (data, ij) = arg1
                     assert isinstance(ij, ndarray) and (rank(ij) == 2) \
-                           and (shape(ij) == (2, len(s)))
+                           and (shape(ij) == (2, len(data)))
                 except (AssertionError, TypeError, ValueError, AttributeError):
                     try:
-                        # Try interpreting it as (data, colind, indptr)
-                        (s, colind, indptr) = arg1
+                        # Try interpreting it as (data, indices, indptr)
+                        (data, indices, indptr) = arg1
                     except (TypeError, ValueError):
                         raise ValueError, "unrecognized form for csr_matrix constructor"
                     else:
-                        self.dtype = getdtype(dtype, s)
+                        self.dtype = getdtype(dtype, data)
                         if copy:
-                            self.data = array(s, dtype=self.dtype)
-                            self.indices = array(colind)
+                            self.data    = array(data, dtype=self.dtype)
+                            self.indices = array(indices)
                             self.indptr  = array(indptr)
                         else:
-                            self.data = asarray(s, dtype=self.dtype)
-                            self.indices = asarray(colind)
-                            self.indptr = asarray(indptr)
+                            self.data    = asarray(data, dtype=self.dtype)
+                            self.indices = asarray(indices)
+                            self.indptr  = asarray(indptr)
                 else:
                     # (data, ij) format
-                    self.dtype = getdtype(dtype, s)
+                    self.dtype = getdtype(dtype, data)
                     ijnew = array(ij, copy=copy)
-                    temp = coo_matrix((s, ijnew), dims=dims, \
-                                      dtype=self.dtype).tocsr()
-                    self.shape = temp.shape
-                    self.data = temp.data
-                    self.indices = temp.indices
-                    self.indptr = temp.indptr
+                    self._set_self( coo_matrix((data, ijnew), dims=dims, \
+                                      dtype=self.dtype).tocsr() )
         else:
             raise ValueError, "unrecognized form for csr_matrix constructor"
 




More information about the Scipy-svn mailing list