[Scipy-svn] r4679 - in trunk/scipy/sparse/linalg/isolve: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Aug 29 19:27:19 EDT 2008


Author: wnbell
Date: 2008-08-29 18:27:17 -0500 (Fri, 29 Aug 2008)
New Revision: 4679

Modified:
   trunk/scipy/sparse/linalg/isolve/iterative.py
   trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
Log:
(mostly) fixes ticket #728


Modified: trunk/scipy/sparse/linalg/isolve/iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/iterative.py	2008-08-29 07:14:57 UTC (rev 4678)
+++ trunk/scipy/sparse/linalg/isolve/iterative.py	2008-08-29 23:27:17 UTC (rev 4679)
@@ -400,8 +400,8 @@
                have a typecode attribute use xtype=0 for the same type as
                b or use xtype='f','d','F',or 'D'
     callback -- an optional user-supplied function to call after each
-                iteration.  It is called as callback(xk), where xk is the
-                current parameter vector.
+                iteration.  It is called as callback(rk), where rk is the
+                the current relative residual 
     """
     A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
 
@@ -427,31 +427,53 @@
     ftflag = True
     bnrm2 = -1.0
     iter_ = maxiter
+    old_ijob = ijob
+    first_pass = True
+    resid_ready = False
+    iter_num = 1
     while True:
         olditer = iter_
         x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
            revcom(b, x, restrt, work, work2, iter_, resid, info, ndx1, ndx2, ijob)
-        if callback is not None and iter_ > olditer:
-            callback(x)
+        #if callback is not None and iter_ > olditer:
+        #    callback(x)
         slice1 = slice(ndx1-1, ndx1-1+n)
         slice2 = slice(ndx2-1, ndx2-1+n)
-        if (ijob == -1):
+        if (ijob == -1): # gmres success, update last residual
+            if resid_ready and callback is not None:
+                callback(resid)
+                resid_ready = False
+            
             break
         elif (ijob == 1):
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(x)
         elif (ijob == 2):
             work[slice1] = psolve(work[slice2])
+            if not first_pass and old_ijob==3:
+                resid_ready = True
+
+            first_pass = False    
         elif (ijob == 3):
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(work[slice1])
+            if resid_ready and callback is not None:
+                callback(resid)
+                resid_ready = False
+                iter_num = iter_num+1
+
         elif (ijob == 4):
             if ftflag:
                 info = -1
                 ftflag = False
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
+        
+        old_ijob = ijob
         ijob = 2
 
+        if iter_num > maxiter:
+            break
+
     return postprocess(x), info
 
 

Modified: trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py	2008-08-29 07:14:57 UTC (rev 4678)
+++ trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py	2008-08-29 23:27:17 UTC (rev 4679)
@@ -4,10 +4,10 @@
 
 from numpy.testing import *
 
-from numpy import zeros, dot, diag, ones, arange, array
+from numpy import zeros, dot, diag, ones, arange, array, abs, max
 from numpy.random import rand
 from scipy.linalg import norm
-from scipy.sparse import spdiags
+from scipy.sparse import spdiags, csr_matrix
 
 from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres
 
@@ -21,11 +21,12 @@
 #TODO test complex matrices
 #TODO test both preconditioner methods
 
-data = ones((3,10))
+N = 40
+data = ones((3,N))
 data[0,:] =  2
 data[1,:] = -1
 data[2,:] = -1
-Poisson1D = spdiags( data, [0,-1,1], 10, 10, format='csr')
+Poisson1D = spdiags( data, [0,-1,1], N, N, format='csr')
 
 data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
 RandDiag = spdiags( data, [0], 10, 10, format='csr' )
@@ -61,9 +62,28 @@
         #data[1,:] = -1
         #A = spdiags( data, [0,-1], 10, 10, format='csr')
         #self.cases.append( (A,False,True) )
+    
+    def test_maxiter(self):
+        """test whether maxiter is respected"""
 
+        A = Poisson1D
 
+        for solver,req_sym,req_pos in self.solvers:
+            b  = arange(A.shape[0], dtype=float)
+            x0 = 0*b
 
+            residuals = []
+            def callback(x):
+                residuals.append( norm(b - A*x) )
+
+            x, info = solver(A, b, x0=x0, tol=1e-8, maxiter=3, callback=callback)
+           
+            assert(len(residuals) in [2,3])
+
+            # TODO enforce this condition instead!
+            #assert_equal(len(residuals), 2)
+
+
     def test_convergence(self):
         """test whether all methods converge"""
 
@@ -149,5 +169,23 @@
         assert( norm(b - A*x) < 1e-8*norm(b) )
 
 
+class TestGMRES(TestCase):
+    def test_callback(self):  
+        
+        def store_residual(r, rvec):
+            rvec[rvec.nonzero()[0].max()+1] = r
+
+        #Define, A,b
+        A = csr_matrix(array([[-2,1,0,0,0,0],[1,-2,1,0,0,0],[0,1,-2,1,0,0],[0,0,1,-2,1,0],[0,0,0,1,-2,1],[0,0,0,0,1,-2]]))
+        b = ones((A.shape[0],))
+        maxiter=1
+        rvec = zeros(maxiter+1)
+        rvec[0] = 1.0
+        callback = lambda r:store_residual(r, rvec)
+        x,flag = gmres(A, b, x0=zeros(A.shape[0]), tol=1e-16, maxiter=maxiter, callback=callback)
+        diff = max(abs((rvec - array([1.0,   0.81649658092772603]))))
+        assert(diff < 1e-5)
+
+
 if __name__ == "__main__":
     nose.run(argv=['', __file__])




More information about the Scipy-svn mailing list