[Scipy-svn] r5034 - in trunk/scipy/sparse/linalg/isolve: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Nov 9 22:06:01 EST 2008


Author: wnbell
Date: 2008-11-09 21:05:57 -0600 (Sun, 09 Nov 2008)
New Revision: 5034

Modified:
   trunk/scipy/sparse/linalg/isolve/iterative.py
   trunk/scipy/sparse/linalg/isolve/minres.py
   trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
Log:
test 'info' output of iterative methods
make GMRES use a default restart of 20, capped at n (previously the default was n)
resolves ticket #666 (the mark of the beast!)


Modified: trunk/scipy/sparse/linalg/isolve/iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/iterative.py	2008-11-10 01:57:55 UTC (rev 5033)
+++ trunk/scipy/sparse/linalg/isolve/iterative.py	2008-11-10 03:05:57 UTC (rev 5034)
@@ -1,14 +1,5 @@
-## Automatically adapted for scipy Oct 18, 2005 by
+"""Iterative methods for solving linear systems"""
 
-
-# Iterative methods using reverse-communication raw material
-#   These methods solve
-#   Ax = b  for x
-
-#   where A must have A.matvec(x,*args) defined
-#    or be a numeric array
-
-
 __all__ = ['bicg','bicgstab','cg','cgs','gmres','qmr']
 
 import _iterative
@@ -106,6 +97,10 @@
                 ftflag = False
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
+    
+    if info > 0 and iter_ == maxiter and resid > tol:
+        #info isn't set appropriately otherwise
+        info = iter_
 
     return postprocess(x), info
 
@@ -197,6 +192,10 @@
                 ftflag = False
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
+    
+    if info > 0 and iter_ == maxiter and resid > tol:
+        #info isn't set appropriately otherwise
+        info = iter_
 
     return postprocess(x), info
 
@@ -284,6 +283,11 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
+
+    if info > 0 and iter_ == maxiter and resid > tol:
+        #info isn't set appropriately otherwise
+        info = iter_
+
     return postprocess(x), info
 
 
@@ -369,11 +373,15 @@
                 ftflag = False
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
+    
+    if info > 0 and iter_ == maxiter and resid > tol:
+        #info isn't set appropriately otherwise
+        info = iter_
 
     return postprocess(x), info
 
 
-def gmres(A, b, x0=None, tol=1e-5, restrt=None, maxiter=None, xtype=None, M=None, callback=None):
+def gmres(A, b, x0=None, tol=1e-5, restrt=20, maxiter=None, xtype=None, M=None, callback=None):
     """Use Generalized Minimal RESidual iteration to solve A x = b
 
     Inputs:
@@ -397,7 +405,7 @@
 
     x0  -- (0) default starting guess.
     tol -- (1e-5) relative tolerance to achieve
-    restrt -- (n) When to restart (change this to get faster performance -- but
+    restrt -- (10) When to restart (change this to get faster performance -- but
                    may not converge).
     maxiter -- (10*n) maximum number of iterations
     xtype  --  The type of the result.  If None, then it will be
@@ -416,14 +424,14 @@
     if maxiter is None:
         maxiter = n*10
 
+    restrt = min(restrt, n)        
+
     matvec = A.matvec
     psolve = M.matvec
     ltr = _type_conv[x.dtype.char]
     revcom   = getattr(_iterative, ltr + 'gmresrevcom')
     stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-    if restrt is None:
-        restrt = n
     resid = tol
     ndx1 = 1
     ndx2 = -1
@@ -480,6 +488,10 @@
 
         if iter_num > maxiter:
             break
+    
+    if info >= 0 and resid > tol:
+        #info isn't set appropriately otherwise
+        info = maxiter
 
     return postprocess(x), info
 
@@ -593,6 +605,10 @@
                 ftflag = False
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
+    
+    if info > 0 and iter_ == maxiter and resid > tol:
+        #info isn't set appropriately otherwise
+        info = iter_
 
     return postprocess(x), info
 

Modified: trunk/scipy/sparse/linalg/isolve/minres.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/minres.py	2008-11-10 01:57:55 UTC (rev 5033)
+++ trunk/scipy/sparse/linalg/isolve/minres.py	2008-11-10 03:05:57 UTC (rev 5034)
@@ -67,7 +67,7 @@
     istop = 0;   itn   = 0;   Anorm = 0;    Acond = 0;
     rnorm = 0;   ynorm = 0;
 
-    xtype = A.dtype #TODO update
+    xtype = x.dtype
 
     eps = finfo(xtype).eps
 
@@ -273,9 +273,14 @@
         print last + ' Arnorm  =  %12.4e'                       %  (Arnorm,)
         print last + msg[istop+1]
 
-    return (postprocess(x),0)
+    if istop == 6:
+        info = maxiter
+    else:
+        info = 0
 
+    return (postprocess(x),info)
 
+
 if __name__ == '__main__':
     from scipy import ones, arange
     from scipy.linalg import norm

Modified: trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py	2008-11-10 01:57:55 UTC (rev 5033)
+++ trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py	2008-11-10 03:05:57 UTC (rev 5034)
@@ -8,6 +8,7 @@
 from scipy.linalg import norm
 from scipy.sparse import spdiags, csr_matrix
 
+from scipy.sparse.linalg.interface import LinearOperator
 from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres
 
 #def callback(x):
@@ -66,6 +67,7 @@
         """test whether maxiter is respected"""
 
         A = Poisson1D
+        tol = 1e-12
 
         for solver,req_sym,req_pos in self.solvers:
             b  = arange(A.shape[0], dtype=float)
@@ -75,11 +77,11 @@
             def callback(x):
                 residuals.append( norm(b - A*x) )
 
-            x, info = solver(A, b, x0=x0, tol=1e-8, maxiter=3, callback=callback)
+            x, info = solver(A, b, x0=x0, tol=tol, maxiter=3, callback=callback)
 
             assert_equal(len(residuals), 3)
+            assert_equal(info, 3)
 
-
     def test_convergence(self):
         """test whether all methods converge"""
 
@@ -101,33 +103,44 @@
                 assert( norm(b - A*x) < tol*norm(b) )
 
     def test_precond(self):
-        """test whether all methods accept a preconditioner"""
+        """test whether all methods accept a trivial preconditioner"""
 
         tol = 1e-8
+        
+        def identity(b,which=None):
+            """trivial preconditioner"""
+            return b
 
         for solver,req_sym,req_pos in self.solvers:
+
             for A,sym,pos in self.cases:
                 if req_sym and not sym: continue
                 if req_pos and not pos: continue
 
                 M,N = A.shape
-                D = spdiags( [abs(1.0/A.diagonal())], [0], M, N)
-                def precond(b,which=None):
-                    return D*b
+                D = spdiags( [1.0/A.diagonal()], [0], M, N)
 
-                A = A.copy()
-                A.psolve  = precond
-                A.rpsolve = precond
-
                 b  = arange(A.shape[0], dtype=float)
                 x0 = 0*b
 
-                x, info = solver(A, b, x0=x0, tol=tol)
+                precond = LinearOperator(A.shape, identity, rmatvec=identity)
 
+                if solver == qmr:
+                    x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
+                else:
+                    x, info = solver(A, b, M=precond, x0=x0, tol=tol)
                 assert_equal(info,0)
                 assert( norm(b - A*x) < tol*norm(b) )
+                
+                A = A.copy()
+                A.psolve  = identity 
+                A.rpsolve = identity
 
+                x, info = solver(A, b, x0=x0, tol=tol)
+                assert_equal(info,0)
+                assert( norm(b - A*x) < tol*norm(b) )
 
+
 class TestQMR(TestCase):
     def test_leftright_precond(self):
         """Check that QMR works with left and right preconditioners"""




More information about the Scipy-svn mailing list