[Scipy-svn] r6269 - trunk/scipy/sparse/linalg/eigen/arpack

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Mar 26 01:34:58 EDT 2010


Author: cdavid
Date: 2010-03-26 00:34:57 -0500 (Fri, 26 Mar 2010)
New Revision: 6269

Modified:
   trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
Log:
REF: use _ArpackParams in symmetric solver.

Modified: trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
===================================================================
--- trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:34:48 UTC (rev 6268)
+++ trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:34:57 UTC (rev 6269)
@@ -84,15 +84,15 @@
         if ncv > n or ncv < k:
             raise ValueError("ncv must be k<=ncv<=n, ncv=%s" % ncv)
 
-        if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
-            raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
-
         ltr = _type_conv[tp]
 
         self.v = np.zeros((n, ncv), tp) # holds Ritz vectors
         self.rwork = None # Only used for unsymmetric, complex solver
 
         if mode == "unsymmetric":
+            if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
+                raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
+
             self.workd = np.zeros(3 * n, tp)
             self.workl = np.zeros(3 * ncv * ncv + 6 * ncv, tp)
             self.solver = _arpack.__dict__[ltr + 'naupd']
@@ -103,6 +103,9 @@
 
             self.ipntr = np.zeros(14, "int")
         elif mode == "symmetric":
+            if not which in ['LM','SM','LA','SA','BE']:
+                raise ValueError("which must be one of %s" % ' '.join(whiches))
+
             self.workd = np.zeros(3 * n, tp)
             self.workl = np.zeros(ncv * (ncv + 8), tp)
             self.solver = _arpack.__dict__[ltr + 'saupd']
@@ -413,107 +416,54 @@
         raise ValueError('expected square matrix (shape=%s)' % shape)
     n = A.shape[0]
 
-    # guess type
-    typ = A.dtype.char
-    if typ not in 'fd':
-        raise ValueError("matrix must be real valued (type must be 'f' or 'd')")
-
     if M is not None:
         raise NotImplementedError("generalized eigenproblem not supported yet")
-    if sigma is not None:
-        raise NotImplementedError("shifted eigenproblem not supported yet")
 
-    if ncv is None:
-        ncv=2*k+1
-    ncv=min(ncv,n)
-    if maxiter==None:
-        maxiter=n*10
-    # assign starting vector
-    if v0 is not None:
-        resid=v0
-        info=1
-    else:
-        resid = np.zeros(n,typ)
-        info=0
+    params = _ArpackParams(n, k, A.dtype.char, "symmetric", sigma,
+                           ncv, v0, maxiter, which, tol)
 
-    # some sanity checks
-    if k <= 0:
-        raise ValueError("k must be positive, k=%d"%k)
-    if k == n:
-        raise ValueError("k must be less than rank(A), k=%d"%k)
-    if maxiter <= 0:
-        raise ValueError("maxiter must be positive, maxiter=%d"%maxiter)
-    whiches=['LM','SM','LA','SA','BE']
-    if which not in whiches:
-        raise ValueError("which must be one of %s"%' '.join(whiches))
-    if ncv > n or ncv < k:
-        raise ValueError("ncv must be k<=ncv<=n, ncv=%s"%ncv)
-
-    # assign solver and postprocessor
-    ltr = _type_conv[typ]
-    eigsolver = _arpack.__dict__[ltr+'saupd']
-    eigextract = _arpack.__dict__[ltr+'seupd']
-
-    # set output arrays, parameters, and workspace
-    v = np.zeros((n,ncv),typ)
-    workd = np.zeros(3*n,typ)
-    workl = np.zeros(ncv*(ncv+8),typ)
-    iparam = np.zeros(11,'int')
-    ipntr = np.zeros(11,'int')
     ido = 0
-
-    # set solver mode and parameters
-    # only supported mode is 1: Ax=lx
-    ishfts = 1
-    mode1 = 1
-    bmat='I'
-    iparam[0] = ishfts
-    iparam[2] = maxiter
-    iparam[6] = mode1
-
     while True:
-        ido,resid,v,iparam,ipntr,info =\
-            eigsolver(ido,bmat,which,k,tol,resid,v,
-                      iparam,ipntr,workd,workl,info)
+        ido, params.resid, params.v, params.iparam, params.ipntr, params.info = \
+            params.solver(ido, params.bmat, params.which, params.k, params.tol,
+                    params.resid, params.v, params.iparam, params.ipntr,
+                    params.workd, params.workl, params.info)
 
-        xslice = slice(ipntr[0]-1, ipntr[0]-1+n)
-        yslice = slice(ipntr[1]-1, ipntr[1]-1+n)
+        xslice = slice(params.ipntr[0]-1, params.ipntr[0]-1+n)
+        yslice = slice(params.ipntr[1]-1, params.ipntr[1]-1+n)
         if ido == -1:
             # initialization
-            workd[yslice]=A.matvec(workd[xslice])
+            params.workd[yslice] = A.matvec(params.workd[xslice])
         elif ido == 1:
             # compute y=Ax
-            workd[yslice]=A.matvec(workd[xslice])
+            params.workd[yslice] = A.matvec(params.workd[xslice])
         else:
             break
 
-    if info < -1 :
-        raise RuntimeError("Error info=%d in arpack" % info)
-        return None
+    if params.info < -1 :
+        raise RuntimeError("Error info=%d in arpack" % params.info)
+    elif params.info == 1:
+        warnings.warn("Maximum number of iterations taken: %s" % params.iparam[2])
 
-    if info == 1:
-        warnings.warn("Maximum number of iterations taken: %s" % iparam[2])
+    if params.iparam[4] < k:
+        warnings.warn("Only %d/%d eigenvectors converged" % (params.iparam[4], k))
 
-    if iparam[4] < k:
-        warnings.warn("Only %d/%d eigenvectors converged" % (iparam[4], k))
-
     # now extract eigenvalues and (optionally) eigenvectors
     rvec = return_eigenvectors
     ierr = 0
     howmny = 'A' # return all eigenvectors
-    sselect = np.zeros(ncv,'int') # unused
+    sselect = np.zeros(params.ncv, 'int') # unused
     sigma = 0.0 # no shifts, not implemented
 
-    d,z,info =\
-             eigextract(rvec,howmny,sselect,sigma,
-                        bmat,which, k,tol,resid,v,iparam[0:7],ipntr,
-                        workd[0:2*n],workl,ierr)
+    d, z, info = params.extract(rvec, howmny, sselect, sigma, params.bmat,
+            params.which, k, params.tol, params.resid, params.v,
+            params.iparam[0:7], params.ipntr, params.workd[0:2*n],
+            params.workl,ierr)
 
     if ierr != 0:
-        raise RuntimeError("Error info=%d in arpack"%info)
-        return None
+        raise RuntimeError("Error info=%d in arpack" % params.info)
     if return_eigenvectors:
-        return d,z
+        return d, z
     return d
 
 def svd(A, k=6):




More information about the Scipy-svn mailing list