[Scipy-svn] r4840 - in trunk/scipy/sparse: . linalg tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Oct 25 21:16:06 EDT 2008


Author: wnbell
Date: 2008-10-25 20:16:01 -0500 (Sat, 25 Oct 2008)
New Revision: 4840

Modified:
   trunk/scipy/sparse/base.py
   trunk/scipy/sparse/compressed.py
   trunk/scipy/sparse/coo.py
   trunk/scipy/sparse/dia.py
   trunk/scipy/sparse/dok.py
   trunk/scipy/sparse/lil.py
   trunk/scipy/sparse/linalg/interface.py
   trunk/scipy/sparse/tests/test_base.py
Log:
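cleaned up numpy imports in the sparse matrix modules (now `import numpy as np`)
cleaned up multiplication handlers (matvec, matmat, dot and rmatvec are deprecated in favor of `*`; __rmul__ moved into the spmatrix base class)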


Modified: trunk/scipy/sparse/base.py
===================================================================
--- trunk/scipy/sparse/base.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/base.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -5,9 +5,7 @@
 
 from warnings import warn
 
-import numpy
-from numpy import asarray, asmatrix, asanyarray, ones, deprecate, ravel, \
-        matrix
+import numpy as np
 
 from sputils import isdense, isscalarlike, isintlike
 
@@ -100,7 +98,7 @@
             return self
         else:
             for fp_type in fp_types:
-                if self.dtype <= numpy.dtype(fp_type):
+                if self.dtype <= np.dtype(fp_type):
                     return self.astype(fp_type)
 
             raise TypeError,'cannot upcast [%s] to a floating \
@@ -137,15 +135,15 @@
             format = 'und'
         return format
 
-    @deprecate
+    @np.deprecate
     def rowcol(self, num):
         return (None, None)
 
-    @deprecate
+    @np.deprecate
     def getdata(self, num):
         return None
 
-    @deprecate
+    @np.deprecate
     def listprint(self, start, stop):
         """Provides a way to print over a single index.
         """
@@ -212,10 +210,12 @@
         else:
             return getattr(self,'to' + format)()
 
-    # default operations use the CSR format as a base
-    #   and operations return in csr format
-    #  thus, a new sparse matrix format just needs to define
-    #  a tocsr method
+    ###################################################################
+    #  NOTE: All arithmetic operations use csr_matrix by default.  
+    # Therefore a new sparse matrix format just needs to define a 
+    # .tocsr() method to provide arithmetic support.  Any of these
+    # methods can be overridden for efficiency.
+    ####################################################################
 
     def multiply(self, other):
         """Point-wise multiplication by another matrix
@@ -239,13 +239,34 @@
         return self.tocsr().__rsub__(other)
 
     # old __mul__ interfaces
-    def matvec(self, other):
+    @np.deprecate
+    def matvec(self,other):
         return self * other
-    def matmat(self, other):
+
+    @np.deprecate
+    def matmat(self,other):
         return self * other
+
+    @np.deprecate
     def dot(self, other):
         return self * other
 
+    @np.deprecate
+    def rmatvec(self, other, conjugate=True):
+        """Multiplies the vector 'other' by the sparse matrix, returning a
+        dense vector as a result.
+
+        If 'conjugate' is True:
+            - returns A.transpose().conj() * other
+        Otherwise:
+            - returns A.transpose() * other.
+
+        """
+        if conjugate:
+            return self.conj().transpose() * other
+        else:
+            return self.transpose() * other
+
     def __mul__(self, other):
         """interpret other and call one of the following
 
@@ -268,18 +289,18 @@
             other.shape
         except AttributeError:
             # If it's a list or whatever, treat it like a matrix
-            other = asanyarray(other)
+            other = np.asanyarray(other)
 
-        if isdense(other) and asarray(other).squeeze().ndim <= 1:
+        if isdense(other) and np.asarray(other).squeeze().ndim <= 1:
             ##
             # dense row or column vector
             if other.shape != (N,) and other.shape != (N,1):
                 raise ValueError('dimension mismatch')
 
-            result = self._mul_vector(ravel(other))
+            result = self._mul_vector(np.ravel(other))
 
-            if isinstance(other, matrix):
-                result = asmatrix(result)
+            if isinstance(other, np.matrix):
+                result = np.asmatrix(result)
 
             if other.ndim == 2 and other.shape[1] == 1:
                 # If 'other' was an (nx1) column vector, reshape the result
@@ -294,10 +315,10 @@
             if other.shape[0] != self.shape[1]:
                 raise ValueError('dimension mismatch')
 
-            result = self._mul_dense_matrix(asarray(other))
+            result = self._mul_dense_matrix(np.asarray(other))
 
-            if isinstance(other, matrix):
-                result = asmatrix(result)
+            if isinstance(other, np.matrix):
+                result = np.asmatrix(result)
 
             return result
         else:
@@ -316,9 +337,21 @@
     def _mul_sparse_matrix(self, other):
         return self.tocsr()._mul_sparse_matrix(other)
 
-    def __rmul__(self, other):
-        return self.tocsr().__rmul__(other)
+    def __rmul__(self, other): # other * self
+        if isscalarlike(other):
+            return self.__mul__(other)
+        else:
+            # Don't use asarray unless we have to
+            try:
+                tr = other.transpose()
+            except AttributeError:
+                tr = np.asarray(other).transpose()
+            return (self.transpose() * tr).transpose()
 
+    ####################
+    # Other Arithmetic #
+    ####################
+
     def __truediv__(self, other):
         if isscalarlike(other):
             return self * (1./other)
@@ -349,12 +382,12 @@
 
     def __pow__(self, other):
         if self.shape[0] != self.shape[1]:
-            raise TypeError,'matrix is not square'
+            raise TypeError('matrix is not square')
 
         if isintlike(other):
             other = int(other)
             if other < 0:
-                raise ValueError,'exponent must be >= 0'
+                raise ValueError('exponent must be >= 0')
 
             if other == 0:
                 from construct import identity
@@ -367,7 +400,7 @@
                     result = result*self
                 return result
         elif isscalarlike(other):
-            raise ValueError,'exponent must be an integer'
+            raise ValueError('exponent must be an integer')
         elif isspmatrix(other):
             warn('Using ** for elementwise multiplication is deprecated.'\
                     'Use .multiply() instead', DeprecationWarning)
@@ -460,36 +493,11 @@
         a[0, i] = 1
         return a * self
 
-
-    def rmatvec(self, other, conjugate=True):
-        """Multiplies the vector 'other' by the sparse matrix, returning a
-        dense vector as a result.
-
-        If 'conjugate' is True:
-            - returns A.transpose().conj() * other
-        Otherwise:
-            - returns A.transpose() * other.
-
-        """
-        return self.tocsr().rmatvec(other, conjugate=conjugate)
-
-    #def rmatmat(self, other, conjugate=True):
-    #    """ If 'conjugate' is True:
-    #        returns other * A.transpose().conj(),
-    #    where 'other' is a matrix.  Otherwise:
-    #        returns other * A.transpose().
-    #    """
-    #    other = csc_matrix(other)
-    #    if conjugate:
-    #        return other.matmat(self.transpose()).conj()
-    #    else:
-    #        return other.matmat(self.transpose())
-
     #def __array__(self):
     #    return self.toarray()
 
     def todense(self):
-        return asmatrix(self.toarray())
+        return np.asmatrix(self.toarray())
 
     def toarray(self):
         return self.tocoo().toarray()
@@ -522,13 +530,13 @@
         m, n = self.shape
         if axis == 0:
             # sum over columns
-            return asmatrix(ones((1, m), dtype=self.dtype)) * self
+            return np.asmatrix(np.ones((1, m), dtype=self.dtype)) * self
         elif axis == 1:
             # sum over rows
-            return self * asmatrix(ones((n, 1), dtype=self.dtype))
+            return self * np.asmatrix(np.ones((n, 1), dtype=self.dtype))
         elif axis is None:
             # sum over rows and columns
-            return ( self * asmatrix(ones((n, 1), dtype=self.dtype)) ).sum()
+            return ( self * np.asmatrix(np.ones((n, 1), dtype=self.dtype)) ).sum()
         else:
             raise ValueError, "axis out of bounds"
 

Modified: trunk/scipy/sparse/compressed.py
===================================================================
--- trunk/scipy/sparse/compressed.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/compressed.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -5,8 +5,7 @@
 
 from warnings import warn
 
-from numpy import array, asarray, zeros, rank, intc, empty, isscalar, \
-                  empty_like, where, concatenate, deprecate, diff, multiply
+import numpy as np
 
 from base import spmatrix, isspmatrix, SparseEfficiencyWarning
 from data import _data_matrix
@@ -15,7 +14,6 @@
         isscalarlike, isintlike
 
 
-
 class _cs_matrix(_data_matrix):
     """base matrix class for compressed row and column oriented matrices"""
 
@@ -43,9 +41,9 @@
                 # create empty matrix
                 self.shape = arg1   #spmatrix checks for errors here
                 M, N = self.shape
-                self.data    = zeros(0, getdtype(dtype, default=float))
-                self.indices = zeros(0, intc)
-                self.indptr  = zeros(self._swap((M,N))[0] + 1, dtype=intc)
+                self.data    = np.zeros(0, getdtype(dtype, default=float))
+                self.indices = np.zeros(0, np.intc)
+                self.indptr  = np.zeros(self._swap((M,N))[0] + 1, dtype=np.intc)
             else:
                 if len(arg1) == 2:
                     # (data, ij) format
@@ -55,9 +53,9 @@
                 elif len(arg1) == 3:
                     # (data, indices, indptr) format
                     (data, indices, indptr) = arg1
-                    self.indices = array(indices, copy=copy)
-                    self.indptr  = array(indptr, copy=copy)
-                    self.data    = array(data, copy=copy, dtype=getdtype(dtype, data))
+                    self.indices = np.array(indices, copy=copy)
+                    self.indptr  = np.array(indptr, copy=copy)
+                    self.data    = np.array(data, copy=copy, dtype=getdtype(dtype, data))
                 else:
                     raise ValueError, "unrecognized %s_matrix constructor usage" %\
                             self.format
@@ -65,7 +63,7 @@
         else:
             #must be dense
             try:
-                arg1 = asarray(arg1)
+                arg1 = np.asarray(arg1)
             except:
                 raise ValueError, "unrecognized %s_matrix constructor usage" % \
                         self.format
@@ -128,14 +126,13 @@
                     % self.indices.dtype.name )
 
         # only support 32-bit ints for now
-        self.indptr  = asarray(self.indptr,dtype=intc)
-        self.indices = asarray(self.indices,dtype=intc)
+        self.indptr  = np.asarray(self.indptr,  dtype=np.intc)
+        self.indices = np.asarray(self.indices, dtype=np.intc)
         self.data    = to_native(self.data)
 
         # check array shapes
-        if (rank(self.data) != 1) or (rank(self.indices) != 1) or \
-           (rank(self.indptr) != 1):
-            raise ValueError,"data, indices, and indptr should be rank 1"
+        if np.rank(self.data) != 1 or np.rank(self.indices) != 1 or np.rank(self.indptr) != 1:
+            raise ValueError('data, indices, and indptr should be rank 1')
 
         # check index pointer
         if (len(self.indptr) != major_dim + 1 ):
@@ -164,7 +161,7 @@
                 if self.indices.min() < 0:
                     raise ValueError, "%s index values must be >= 0" % \
                             minor_name
-                if diff(self.indptr).min() < 0:
+                if np.diff(self.indptr).min() < 0:
                     raise ValueError,'index pointer values must form a " \
                                         "non-decreasing sequence'
 
@@ -225,18 +222,6 @@
             raise NotImplementedError
 
 
-    def __rmul__(self, other): # other * self
-        if isscalarlike(other):
-            return self.__mul__(other)
-        else:
-            # Don't use asarray unless we have to
-            try:
-                tr = other.transpose()
-            except AttributeError:
-                tr = asarray(other).transpose()
-            return (self.transpose() * tr).transpose()
-
-
     def __truediv__(self,other):
         if isscalarlike(other):
             return self * (1./other)
@@ -258,7 +243,7 @@
             raise ValueError('inconsistent shapes')
 
         if isdense(other):
-            return multiply(self.todense(),other)
+            return np.multiply(self.todense(),other)
         else:
             other = self.__class__(other)
             return self._binopt(other,'_elmul_')
@@ -272,7 +257,7 @@
         M,N = self.shape
 
         #output array
-        result = zeros( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
+        result = np.zeros( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
 
         # csr_matvec or csc_matvec
         fn = getattr(sparsetools,self.format + '_matvec')
@@ -285,7 +270,7 @@
         M,N = self.shape
         n_vecs = other.shape[1] #number of column vectors
 
-        result = zeros( (M,n_vecs), dtype=upcast(self.dtype,other.dtype) )
+        result = np.zeros( (M,n_vecs), dtype=upcast(self.dtype,other.dtype) )
 
         # csr_matvecs or csc_matvecs
         fn = getattr(sparsetools,self.format + '_matvecs')
@@ -299,7 +284,7 @@
         K2, N = other.shape
 
         major_axis = self._swap((M,N))[0]
-        indptr = empty( major_axis + 1, dtype=intc )
+        indptr = np.empty(major_axis + 1, dtype=np.intc)
 
         other = self.__class__(other) #convert to this format
         fn = getattr(sparsetools, self.format + '_matmat_pass1')
@@ -308,8 +293,8 @@
                   indptr)
 
         nnz = indptr[-1]
-        indices = empty( nnz, dtype=intc)
-        data    = empty( nnz, dtype=upcast(self.dtype,other.dtype))
+        indices = np.empty(nnz, dtype=np.intc)
+        data    = np.empty(nnz, dtype=upcast(self.dtype,other.dtype))
 
         fn = getattr(sparsetools, self.format + '_matmat_pass2')
         fn( M, N, self.indptr, self.indices, self.data, \
@@ -318,125 +303,8 @@
 
         return self.__class__((data,indices,indptr),shape=(M,N))
 
-    def matvec(self,other):
-        return self * other
 
-    def matmat(self,other):
-        return self * other
-
-    #def matmat(self, other):
-    #    if isspmatrix(other):
-    #        M, K1 = self.shape
-    #        K2, N = other.shape
-    #        if (K1 != K2):
-    #            raise ValueError, "shape mismatch error"
-
-    #        #return self._binopt(other,'mu',in_shape=(M,N),out_shape=(M,N))
-
-    #        major_axis = self._swap((M,N))[0]
-    #        indptr = empty( major_axis + 1, dtype=intc )
-
-    #        other = self.__class__(other) #convert to this format
-    #        fn = getattr(sparsetools, self.format + '_matmat_pass1')
-    #        fn( M, N, self.indptr, self.indices, \
-    #                  other.indptr, other.indices, \
-    #                  indptr)
-
-    #        nnz = indptr[-1]
-    #        indices = empty( nnz, dtype=intc)
-    #        data    = empty( nnz, dtype=upcast(self.dtype,other.dtype))
-
-    #        fn = getattr(sparsetools, self.format + '_matmat_pass2')
-    #        fn( M, N, self.indptr, self.indices, self.data, \
-    #                  other.indptr, other.indices, other.data, \
-    #                  indptr, indices, data)
-
-    #        return self.__class__((data,indices,indptr),shape=(M,N))
-
-
-    #    elif isdense(other):
-    #        # TODO make sparse * dense matrix multiplication more efficient
-    #
-    #        # matvec each column of other
-    #        result = hstack( [ self * col.reshape(-1,1) for col in asarray(other).T ] )
-    #        if isinstance(other, matrix):
-    #            result = asmatrix(result)
-    #        return result
-
-    #    else:
-    #        raise TypeError, "need a dense or sparse matrix"
-
-
-    #def matvec(self, other):
-    #    """Sparse matrix vector product (self * other)
-
-    #    'other' may be a rank 1 array of length N or a rank 2 array
-    #    or matrix with shape (N,1).
-
-    #    """
-    #    #If the optional 'output' parameter is defined, it will
-    #    #be used to store the result.  Otherwise, a new vector
-    #    #will be allocated.
-
-    #    if isdense(other):
-    #        M,N = self.shape
-
-    #        if other.shape != (N,) and other.shape != (N,1):
-    #            raise ValueError, "dimension mismatch"
-
-    #        # csrmux, cscmux
-    #        fn = getattr(sparsetools,self.format + '_matvec')
-
-    #        #output array
-    #        y = zeros( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
-
-    #        #if output is None:
-    #        #    y = empty( self.shape[0], dtype=upcast(self.dtype,other.dtype) )
-    #        #else:
-    #        #    if output.shape != (M,) and output.shape != (M,1):
-    #        #        raise ValueError, "output array has improper dimensions"
-    #        #    if not output.flags.c_contiguous:
-    #        #        raise ValueError, "output array must be contiguous"
-    #        #    if output.dtype != upcast(self.dtype,other.dtype):
-    #        #        raise ValueError, "output array has dtype=%s "\
-    #        #                "dtype=%s is required" % \
-    #        #                (output.dtype,upcast(self.dtype,other.dtype))
-    #        #    y = output
-
-    #        fn(self.shape[0], self.shape[1], \
-    #            self.indptr, self.indices, self.data, numpy.ravel(other), y)
-
-    #        if isinstance(other, matrix):
-    #            y = asmatrix(y)
-
-    #        if other.ndim == 2 and other.shape[1] == 1:
-    #            # If 'other' was an (nx1) column vector, reshape the result
-    #            y = y.reshape(-1,1)
-
-    #        return y
-
-    #    elif isspmatrix(other):
-    #        raise TypeError, "use matmat() for sparse * sparse"
-
-    #    else:
-    #        raise TypeError, "need a dense vector"
-
-    def rmatvec(self, other, conjugate=True):
-        """Multiplies the vector 'other' by the sparse matrix, returning a
-        dense vector as a result.
-
-        If 'conjugate' is True:
-            - returns A.transpose().conj() * other
-        Otherwise:
-            - returns A.transpose() * other.
-
-        """
-        if conjugate:
-            return self.transpose().conj().matvec( other )
-        else:
-            return self.transpose().matvec( other )
-
-    @deprecate
+    @np.deprecate
     def getdata(self, ind):
         return self.data[ind]
 
@@ -445,7 +313,7 @@
         """
         #TODO support k-th diagonal
         fn = getattr(sparsetools, self.format + "_diagonal")
-        y = empty( min(self.shape), dtype=upcast(self.dtype) )
+        y = np.empty( min(self.shape), dtype=upcast(self.dtype) )
         fn(self.shape[0], self.shape[1], self.indptr, self.indices, self.data, y)
         return y
 
@@ -506,7 +374,7 @@
 
         start = self.indptr[major_index]
         end   = self.indptr[major_index+1]
-        indxs = where(minor_index == self.indices[start:end])[0]
+        indxs = np.where(minor_index == self.indices[start:end])[0]
 
         num_matches = len(indxs)
 
@@ -539,7 +407,7 @@
 
         index  = self.indices[indices] - start
         data   = self.data[indices]
-        indptr = array([0, len(indices)])
+        indptr = np.array([0, len(indices)])
         return self.__class__((data, index, indptr), shape=shape, \
                               dtype=self.dtype)
 
@@ -563,7 +431,7 @@
 
                 return i0, i1
 
-            elif isscalar( sl ):
+            elif np.isscalar( sl ):
                 if sl < 0:
                     sl += num
 
@@ -612,7 +480,7 @@
 
             start = self.indptr[major_index]
             end   = self.indptr[major_index+1]
-            indxs = where(minor_index == self.indices[start:end])[0]
+            indxs = np.where(minor_index == self.indices[start:end])[0]
 
             num_matches = len(indxs)
 
@@ -627,10 +495,10 @@
                 newindx = self.indices[start:end].searchsorted(minor_index)
                 newindx += start
 
-                val = array([val],dtype=self.data.dtype)
-                minor_index = array([minor_index],dtype=self.indices.dtype)
-                self.data    = concatenate((self.data[:newindx],val,self.data[newindx:]))
-                self.indices = concatenate((self.indices[:newindx],minor_index,self.indices[newindx:]))
+                val = np.array([val],dtype=self.data.dtype)
+                minor_index = np.array([minor_index],dtype=self.indices.dtype)
+                self.data    = np.concatenate((self.data[:newindx],val,self.data[newindx:]))
+                self.indices = np.concatenate((self.indices[:newindx],minor_index,self.indices[newindx:]))
 
                 self.indptr[major_index+1:] += 1
 
@@ -670,7 +538,7 @@
             data = data.copy()
             minor_indices = minor_indices.copy()
 
-        major_indices = empty(len(minor_indices),dtype=intc)
+        major_indices = np.empty(len(minor_indices), dtype=np.intc)
 
         sparsetools.expandptr(major_dim,self.indptr,major_indices)
 
@@ -814,9 +682,9 @@
         fn = getattr(sparsetools, self.format + op + self.format)
 
         maxnnz = self.nnz + other.nnz
-        indptr  = empty_like(self.indptr)
-        indices = empty( maxnnz, dtype=intc )
-        data    = empty( maxnnz, dtype=upcast(self.dtype,other.dtype) )
+        indptr  = np.empty_like(self.indptr)
+        indices = np.empty(maxnnz, dtype=np.intc)
+        data    = np.empty(maxnnz, dtype=upcast(self.dtype,other.dtype))
 
         fn(in_shape[0], in_shape[1], \
                 self.indptr,  self.indices,  self.data,

Modified: trunk/scipy/sparse/coo.py
===================================================================
--- trunk/scipy/sparse/coo.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/coo.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -8,9 +8,6 @@
 
 import numpy as np
 
-#from numpy import array, asarray, empty, intc, zeros, unique, searchsorted,\
-#                  atleast_2d, rank, deprecate, hstack
-
 from sparsetools import coo_tocsr, coo_todense, coo_matvec
 from base import isspmatrix
 from data import _data_matrix

Modified: trunk/scipy/sparse/dia.py
===================================================================
--- trunk/scipy/sparse/dia.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/dia.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -2,7 +2,7 @@
 
 __docformat__ = "restructuredtext en"
 
-__all__ = ['dia_matrix','isspmatrix_dia']
+__all__ = ['dia_matrix', 'isspmatrix_dia']
 
 import numpy as np
 

Modified: trunk/scipy/sparse/dok.py
===================================================================
--- trunk/scipy/sparse/dok.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/dok.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -476,7 +476,7 @@
                     base[newkey] = self[key]
         return base, ext
 
-
+# TODO update these w/ new multiplication handlers
 #    def matvec(self, other):
 #        if isdense(other):
 #            if other.shape[0] != self.shape[1]:

Modified: trunk/scipy/sparse/lil.py
===================================================================
--- trunk/scipy/sparse/lil.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/lil.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -382,21 +382,6 @@
                 new[new_r,new_c] = self[i,j]
         return new
 
-    def __add__(self, other):
-        if np.isscalar(other) and other != 0:
-            raise ValueError("Refusing to destroy sparsity. "
-                             "Use x.todense() + c instead.")
-        else:
-            return spmatrix.__add__(self, other)
-
-    def __rmul__(self, other):          # other * self
-        if isscalarlike(other):
-            # Multiplication by a scalar is symmetric
-            return self.__mul__(other)
-        else:
-            return spmatrix.__rmul__(self, other)
-
-
     def toarray(self):
         d = np.zeros(self.shape, dtype=self.dtype)
         for i, row in enumerate(self.rows):

Modified: trunk/scipy/sparse/linalg/interface.py
===================================================================
--- trunk/scipy/sparse/linalg/interface.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/linalg/interface.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -127,8 +127,14 @@
                               matmat=matmat, dtype=A.dtype)
 
     elif isspmatrix(A):
-        return LinearOperator(A.shape, A.matvec, rmatvec=A.rmatvec,
-                              matmat=A.dot, dtype=A.dtype)
+        def matvec(v):
+            return A * v
+        def rmatvec(v):
+            return A.conj().transpose() * v
+        def matmat(V):
+            return A * V
+        return LinearOperator(A.shape, matvec, rmatvec=rmatvec,
+                              matmat=matmat, dtype=A.dtype)
 
     else:
         if hasattr(A,'shape') and hasattr(A,'matvec'):

Modified: trunk/scipy/sparse/tests/test_base.py
===================================================================
--- trunk/scipy/sparse/tests/test_base.py	2008-10-25 23:30:38 UTC (rev 4839)
+++ trunk/scipy/sparse/tests/test_base.py	2008-10-26 01:16:01 UTC (rev 4840)
@@ -355,21 +355,18 @@
         csp = bsp.tocsc()
         c = b
         assert_array_almost_equal((asp*csp).todense(), a*c)
-        assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
         assert_array_almost_equal( asp*c, a*c)
 
         assert_array_almost_equal( a*csp, a*c)
         assert_array_almost_equal( a2*csp, a*c)
         csp = bsp.tocsr()
         assert_array_almost_equal((asp*csp).todense(), a*c)
-        assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
         assert_array_almost_equal( asp*c, a*c)
 
         assert_array_almost_equal( a*csp, a*c)
         assert_array_almost_equal( a2*csp, a*c)
         csp = bsp.tocoo()
         assert_array_almost_equal((asp*csp).todense(), a*c)
-        assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
         assert_array_almost_equal( asp*c, a*c)
 
         assert_array_almost_equal( a*csp, a*c)
@@ -526,47 +523,6 @@
         assert_array_equal(self.dat/17.3,a.todense())
 
 
-class _TestMatvecOutput:
-    """test using the matvec() output parameter"""
-    def test_matvec_output(self):
-        pass  #Currently disabled
-
-#        #flat array
-#        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
-#        y = zeros(3,dtype='d')
-#
-#        self.datsp.matvec(x,y)
-#        assert_array_equal(self.datsp*x,y)
-#
-#        #column vector
-#        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
-#        x = x.reshape(4,1)
-#        y = zeros((3,1),dtype='d')
-#
-#        self.datsp.matvec(x,y)
-#        assert_array_equal(self.datsp*x,y)
-#
-#        # improper output type
-#        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
-#        y = zeros(3,dtype='i')
-#
-#        self.assertRaises( ValueError, self.datsp.matvec, x, y )
-#
-#        # improper output shape
-#        x = array([1.25, -6.5, 0.125, -3.75],dtype='d')
-#        y = zeros(2,dtype='d')
-#
-#        self.assertRaises( ValueError, self.datsp.matvec, x, y )
-#
-#        # proper upcast output type
-#        x = array([1.25, -6.5, 0.125, -3.75],dtype='complex64')
-#        x.imag = [1,2,3,4]
-#        y = zeros(3,dtype='complex128')
-#
-#        self.datsp.matvec(x,y)
-#        assert_array_equal(self.datsp*x,y)
-#        assert_equal((self.datsp*x).dtype,y.dtype)
-
 class _TestGetSet:
     def test_setelement(self):
         a = self.spmatrix((3,4))
@@ -893,7 +849,7 @@
 
 
 class TestCSR(_TestCommon, _TestGetSet, _TestSolve,
-        _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
+        _TestInplaceArithmetic, _TestArithmetic,
         _TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
         _TestFancyIndexing, TestCase):
     spmatrix = csr_matrix
@@ -990,7 +946,7 @@
 
 
 class TestCSC(_TestCommon, _TestGetSet, _TestSolve,
-        _TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
+        _TestInplaceArithmetic, _TestArithmetic, 
         _TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,
         _TestFancyIndexing, TestCase):
     spmatrix = csc_matrix
@@ -1400,8 +1356,7 @@
 
 
 
-class TestBSR(_TestCommon, _TestArithmetic, _TestInplaceArithmetic,
-        _TestMatvecOutput, TestCase):
+class TestBSR(_TestCommon, _TestArithmetic, _TestInplaceArithmetic, TestCase):
     spmatrix = bsr_matrix
 
     def test_constructor1(self):




More information about the Scipy-svn mailing list