[Scipy-svn] r6643 - in trunk/scipy/integrate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Jul 28 18:50:46 EDT 2010
Author: ptvirtan
Date: 2010-07-28 17:50:46 -0500 (Wed, 28 Jul 2010)
New Revision: 6643
Modified:
trunk/scipy/integrate/quadrature.py
trunk/scipy/integrate/tests/test_quadrature.py
Log:
BUG: integrate: ensure quadrature evaluates the function at least once; also, add a relative tolerance to the stopping condition
Also, use warnings instead of printing to stdout in quadrature.
Modified: trunk/scipy/integrate/quadrature.py
===================================================================
--- trunk/scipy/integrate/quadrature.py 2010-07-28 13:57:25 UTC (rev 6642)
+++ trunk/scipy/integrate/quadrature.py 2010-07-28 22:50:46 UTC (rev 6643)
@@ -8,7 +8,11 @@
asarray, real, trapz, arange, empty
import numpy as np
import math
+import warnings
+class AccuracyWarning(Warning):
+ pass
+
def fixed_quad(func,a,b,args=(),n=5):
"""
Compute a definite integral using fixed-order Gaussian quadrature.
@@ -99,7 +103,8 @@
return output
return vfunc
-def quadrature(func,a,b,args=(),tol=1.49e-8,maxiter=50, vec_func=True):
+def quadrature(func, a, b, args=(), tol=1.49e-8, rtol=1.49e-8, maxiter=50,
+ vec_func=True):
"""
Compute a definite integral using fixed-tolerance Gaussian quadrature.
@@ -116,9 +121,9 @@
Upper limit of integration.
args : tuple, optional
Extra arguments to pass to function.
- tol : float, optional
+ tol, rtol : float, optional
Iteration stops when error between last two iterates is less than
- tolerance.
+ `tol` OR the relative change is less than `rtol`.
maxiter : int, optional
Maximum number of iterations.
vec_func : bool, optional
@@ -147,17 +152,20 @@
odeint: ODE integrator
"""
- err = 100.0
- val = err
- n = 1
vfunc = vectorize1(func, args, vec_func=vec_func)
- while (err > tol) and (n < maxiter):
+ val = np.inf
+ err = np.inf
+ for n in xrange(1, maxiter+1):
newval = fixed_quad(vfunc, a, b, (), n)[0]
err = abs(newval-val)
val = newval
- n = n + 1
- if n == maxiter:
- print "maxiter (%d) exceeded. Latest difference = %e" % (n,err)
+
+ if err < tol or err < rtol*abs(val):
+ break
+ else:
+ warnings.warn(
+ "maxiter (%d) exceeded. Latest difference = %e" % (maxiter, err),
+ AccuracyWarning)
return val, err
def tupleset(t, i, value):
Modified: trunk/scipy/integrate/tests/test_quadrature.py
===================================================================
--- trunk/scipy/integrate/tests/test_quadrature.py 2010-07-28 13:57:25 UTC (rev 6642)
+++ trunk/scipy/integrate/tests/test_quadrature.py 2010-07-28 22:50:46 UTC (rev 6643)
@@ -1,8 +1,7 @@
-
import numpy
from numpy import cos, sin, pi
from numpy.testing import TestCase, run_module_suite, assert_equal, \
- assert_almost_equal
+ assert_almost_equal, assert_allclose
from scipy.integrate import quadrature, romberg, romb, newton_cotes
@@ -18,6 +17,13 @@
table_val = 0.30614353532540296487
assert_almost_equal(val, table_val, decimal=7)
+ def test_quadrature_rtol(self):
+ def myfunc(x,n,z): # Bessel function integrand
+ return 1e90 * cos(n*x-z*sin(x))/pi
+ val, err = quadrature(myfunc,0,pi,(2,1.8),rtol=1e-10)
+ table_val = 1e90 * 0.30614353532540296487
+ assert_allclose(val, table_val, rtol=1e-10)
+
def test_romberg(self):
# Typical function with two extra arguments:
def myfunc(x, n, z): # Bessel function integrand
More information about the Scipy-svn
mailing list