[Scipy-svn] r4679 - in trunk/scipy/sparse/linalg/isolve: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Fri Aug 29 19:27:19 EDT 2008
Author: wnbell
Date: 2008-08-29 18:27:17 -0500 (Fri, 29 Aug 2008)
New Revision: 4679
Modified:
trunk/scipy/sparse/linalg/isolve/iterative.py
trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
Log:
(mostly) fixes ticket #728
Modified: trunk/scipy/sparse/linalg/isolve/iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/iterative.py 2008-08-29 07:14:57 UTC (rev 4678)
+++ trunk/scipy/sparse/linalg/isolve/iterative.py 2008-08-29 23:27:17 UTC (rev 4679)
@@ -400,8 +400,8 @@
have a typecode attribute use xtype=0 for the same type as
b or use xtype='f','d','F',or 'D'
callback -- an optional user-supplied function to call after each
- iteration. It is called as callback(xk), where xk is the
- current parameter vector.
+ iteration. It is called as callback(rk), where rk is the
+ current relative residual.
"""
A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
@@ -427,31 +427,53 @@
ftflag = True
bnrm2 = -1.0
iter_ = maxiter
+ old_ijob = ijob
+ first_pass = True
+ resid_ready = False
+ iter_num = 1
while True:
olditer = iter_
x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
revcom(b, x, restrt, work, work2, iter_, resid, info, ndx1, ndx2, ijob)
- if callback is not None and iter_ > olditer:
- callback(x)
+ #if callback is not None and iter_ > olditer:
+ # callback(x)
slice1 = slice(ndx1-1, ndx1-1+n)
slice2 = slice(ndx2-1, ndx2-1+n)
- if (ijob == -1):
+ if (ijob == -1): # gmres success, update last residual
+ if resid_ready and callback is not None:
+ callback(resid)
+ resid_ready = False
+
break
elif (ijob == 1):
work[slice2] *= sclr2
work[slice2] += sclr1*matvec(x)
elif (ijob == 2):
work[slice1] = psolve(work[slice2])
+ if not first_pass and old_ijob==3:
+ resid_ready = True
+
+ first_pass = False
elif (ijob == 3):
work[slice2] *= sclr2
work[slice2] += sclr1*matvec(work[slice1])
+ if resid_ready and callback is not None:
+ callback(resid)
+ resid_ready = False
+ iter_num = iter_num+1
+
elif (ijob == 4):
if ftflag:
info = -1
ftflag = False
bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
+
+ old_ijob = ijob
ijob = 2
+ if iter_num > maxiter:
+ break
+
return postprocess(x), info
Modified: trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py
===================================================================
--- trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py 2008-08-29 07:14:57 UTC (rev 4678)
+++ trunk/scipy/sparse/linalg/isolve/tests/test_iterative.py 2008-08-29 23:27:17 UTC (rev 4679)
@@ -4,10 +4,10 @@
from numpy.testing import *
-from numpy import zeros, dot, diag, ones, arange, array
+from numpy import zeros, dot, diag, ones, arange, array, abs, max
from numpy.random import rand
from scipy.linalg import norm
-from scipy.sparse import spdiags
+from scipy.sparse import spdiags, csr_matrix
from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres
@@ -21,11 +21,12 @@
#TODO test complex matrices
#TODO test both preconditioner methods
-data = ones((3,10))
+N = 40
+data = ones((3,N))
data[0,:] = 2
data[1,:] = -1
data[2,:] = -1
-Poisson1D = spdiags( data, [0,-1,1], 10, 10, format='csr')
+Poisson1D = spdiags( data, [0,-1,1], N, N, format='csr')
data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
RandDiag = spdiags( data, [0], 10, 10, format='csr' )
@@ -61,9 +62,28 @@
#data[1,:] = -1
#A = spdiags( data, [0,-1], 10, 10, format='csr')
#self.cases.append( (A,False,True) )
+
+ def test_maxiter(self):
+ """test whether maxiter is respected"""
+ A = Poisson1D
+ for solver,req_sym,req_pos in self.solvers:
+ b = arange(A.shape[0], dtype=float)
+ x0 = 0*b
+ residuals = []
+ def callback(x):
+ residuals.append( norm(b - A*x) )
+
+ x, info = solver(A, b, x0=x0, tol=1e-8, maxiter=3, callback=callback)
+
+ assert(len(residuals) in [2,3])
+
+ # TODO enforce this condition instead!
+ #assert_equal(len(residuals), 2)
+
+
def test_convergence(self):
"""test whether all methods converge"""
@@ -149,5 +169,23 @@
assert( norm(b - A*x) < 1e-8*norm(b) )
+class TestGMRES(TestCase):
+ def test_callback(self):
+
+ def store_residual(r, rvec):
+ rvec[rvec.nonzero()[0].max()+1] = r
+
+ #Define, A,b
+ A = csr_matrix(array([[-2,1,0,0,0,0],[1,-2,1,0,0,0],[0,1,-2,1,0,0],[0,0,1,-2,1,0],[0,0,0,1,-2,1],[0,0,0,0,1,-2]]))
+ b = ones((A.shape[0],))
+ maxiter=1
+ rvec = zeros(maxiter+1)
+ rvec[0] = 1.0
+ callback = lambda r:store_residual(r, rvec)
+ x,flag = gmres(A, b, x0=zeros(A.shape[0]), tol=1e-16, maxiter=maxiter, callback=callback)
+ diff = max(abs((rvec - array([1.0, 0.81649658092772603]))))
+ assert(diff < 1e-5)
+
+
if __name__ == "__main__":
nose.run(argv=['', __file__])
More information about the Scipy-svn mailing list