[Scipy-svn] r5477 - in trunk/scipy/special: cephes tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Jan 17 14:46:43 EST 2009


Author: ptvirtan
Date: 2009-01-17 13:45:37 -0600 (Sat, 17 Jan 2009)
New Revision: 5477

Modified:
   trunk/scipy/special/cephes/hyperg.c
   trunk/scipy/special/tests/test_basic.py
Log:
Fix #640: bad output from scipy.special.iv(k,z) for large real z.

This occurred because hy1f1p in Cephes hyperg.c, which sums the power
series of Kummer's confluent hypergeometric function 1F1, estimated its
error as zero when the series failed to converge within 200 terms. With
this commit, hy1f1p reports 100% error when the series does not converge,
so the caller (which picks the result with the smaller estimated error)
switches to the asymptotic expansion instead.

The Cephes hyp1f1 is not exposed by scipy.special (scipy.special uses
Specfun's implementation instead), but it is used internally by Cephes.

Improved tests for Iv are included.

Modified: trunk/scipy/special/cephes/hyperg.c
===================================================================
--- trunk/scipy/special/cephes/hyperg.c	2009-01-17 11:10:10 UTC (rev 5476)
+++ trunk/scipy/special/cephes/hyperg.c	2009-01-17 19:45:37 UTC (rev 5477)
@@ -104,7 +104,6 @@
 
 asum = hy1f1a( a, b, x, &acanc );
 
-
 /* Pick the result with less estimated error */
 
 if( acanc < pcanc )
@@ -157,7 +156,11 @@
 	if( an == 0 )			/* a singularity		*/
 		return( sum );
 	if( n > 200 )
-		goto pdone;
+                {
+                /* did not converge: estimate 100% error */
+                *err = 1.0;
+                return sum;
+                }
 	u = x * ( an / (bn * n) );
 
 	/* check for blowup */

Modified: trunk/scipy/special/tests/test_basic.py
===================================================================
--- trunk/scipy/special/tests/test_basic.py	2009-01-17 11:10:10 UTC (rev 5476)
+++ trunk/scipy/special/tests/test_basic.py	2009-01-17 19:45:37 UTC (rev 5477)
@@ -26,7 +26,16 @@
 from numpy.testing import *
 from scipy.special import *
 import scipy.special._cephes as cephes
+import numpy as np
 
+def assert_tol_equal(a, b, rtol=1e-7, atol=0, err_msg='', verbose=True):
+    """Assert that `a` and `b` are equal to tolerance ``atol + rtol*abs(b)``"""
+    def compare(x, y):
+        return allclose(x, y, rtol=rtol, atol=atol)
+    a, b = asanyarray(a), asanyarray(b)
+    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
+    np.testing.utils.assert_array_compare(compare, a, b, err_msg=str(err_msg),
+                                          verbose=verbose, header=header)
 
 class TestCephes(TestCase):
     def test_airy(self):
@@ -1326,51 +1335,6 @@
         assert_array_almost_equal(hypu,hprl,12)
 
 class TestBessel(TestCase):
-    def test_i0(self):
-        values = [[0.0, 1.0],
-                  [1e-10, 1.0],
-                  [0.1, 0.9071009258],
-                  [0.5, 0.6450352706],
-                  [1.0, 0.4657596077],
-                  [2.5, 0.2700464416],
-                  [5.0, 0.1835408126],
-                  [20.0, 0.0897803119],
-                 ]
-        for i, (x, v) in enumerate(values):
-            cv = i0(x) * exp(-x)
-            assert_almost_equal(cv, v, 8, err_msg='test #%d' % i)
-
-    def test_i0e(self):
-        oize = i0e(.1)
-        oizer = ive(0,.1)
-        assert_almost_equal(oize,oizer,8)
-
-    def test_i1(self):
-        values = [[0.0, 0.0],
-                  [1e-10, 0.4999999999500000e-10],
-                  [0.1, 0.0452984468],
-                  [0.5, 0.1564208032],
-                  [1.0, 0.2079104154],
-                  [5.0, 0.1639722669],
-                  [20.0, 0.0875062222],
-                 ]
-        for i, (x, v) in enumerate(values):
-            cv = i1(x) * exp(-x)
-            assert_almost_equal(cv, v, 8, err_msg='test #%d' % i)
-
-    def test_i1e(self):
-        oi1e = i1e(.1)
-        oi1er = ive(1,.1)
-        assert_almost_equal(oi1e,oi1er,8)
-
-    def test_iti0k0(self):
-        iti0 = array(iti0k0(5))
-        assert_array_almost_equal(iti0,array([31.848667776169801, 1.5673873907283657]),5)
-
-    def test_it2i0k0(self):
-        it2k = it2i0k0(.1)
-        assert_array_almost_equal(it2k,array([0.0012503906973464409, 3.3309450354686687]),6)
-
     def test_itj0y0(self):
         it0 = array(itj0y0(.2))
         assert_array_almost_equal(it0,array([0.19933433254006822, -0.34570883800412566]),8)
@@ -1382,26 +1346,6 @@
     def test_negv(self):
         assert_equal(iv(3,2), iv(-3,2))
 
-    def test_iv(self):
-        iv1 = iv(0,.1)*exp(-.1)
-        assert_almost_equal(iv1,0.90710092578230106,10)
-
-    def test_negv(self):
-        assert_equal(ive(3,2), ive(-3,2))
-
-    def test_ive(self):
-        ive1 = ive(0,.1)
-        iv1 = iv(0,.1)*exp(-.1)
-        assert_almost_equal(ive1,iv1,10)
-
-    def test_ivp0(self):
-        assert_almost_equal(iv(1,2), ivp(0,2), 10)
-
-    def test_ivp(self):
-        y=(iv(0,2)+iv(2,2))/2
-        x = ivp(1,2)
-        assert_almost_equal(x,y,10)
-
     def test_j0(self):
         oz = j0(.1)
         ozr = jn(0,.1)
@@ -1645,7 +1589,98 @@
         yvp1 = yvp(2,.2)
         assert_array_almost_equal(yvp1,yvpr,10)
 
+class TestBesselI(object):
 
+    def _series(self, v, z, n=200):
+        k = arange(0, n).astype(float_)
+        r = (v+2*k)*log(.5*z) - log(gamma(k+1)) - log(gamma(v+k+1))
+        r[isnan(r)] = inf
+        r = exp(r)
+        err = abs(r).max() * finfo(float_).eps * n + abs(r[-1])*10
+        return r.sum(), err
+
+    def test_i0_series(self):
+        for z in [1., 10., 200.5]:
+            value, err = self._series(0, z)
+            assert_tol_equal(i0(z), value, atol=err, err_msg=z)
+
+    def test_i1_series(self):
+            for z in [1., 10., 200.5]:
+                value, err = self._series(1, z)
+                assert_tol_equal(i1(z), value, atol=err, err_msg=z)
+
+    def test_iv_series(self):
+        for v in [-20., -10., -1., 0., 1., 12.49, 120.]:
+            for z in [1., 10., 200.5, -1+2j]:
+                value, err = self._series(v, z)
+                assert_tol_equal(iv(v, z), value, atol=err, err_msg=(v, z))
+
+    def test_i0(self):
+        values = [[0.0, 1.0],
+                  [1e-10, 1.0],
+                  [0.1, 0.9071009258],
+                  [0.5, 0.6450352706],
+                  [1.0, 0.4657596077],
+                  [2.5, 0.2700464416],
+                  [5.0, 0.1835408126],
+                  [20.0, 0.0897803119],
+                 ]
+        for i, (x, v) in enumerate(values):
+            cv = i0(x) * exp(-x)
+            assert_almost_equal(cv, v, 8, err_msg='test #%d' % i)
+
+    def test_i0e(self):
+        oize = i0e(.1)
+        oizer = ive(0,.1)
+        assert_almost_equal(oize,oizer,8)
+
+    def test_i1(self):
+        values = [[0.0, 0.0],
+                  [1e-10, 0.4999999999500000e-10],
+                  [0.1, 0.0452984468],
+                  [0.5, 0.1564208032],
+                  [1.0, 0.2079104154],
+                  [5.0, 0.1639722669],
+                  [20.0, 0.0875062222],
+                 ]
+        for i, (x, v) in enumerate(values):
+            cv = i1(x) * exp(-x)
+            assert_almost_equal(cv, v, 8, err_msg='test #%d' % i)
+
+    def test_i1e(self):
+        oi1e = i1e(.1)
+        oi1er = ive(1,.1)
+        assert_almost_equal(oi1e,oi1er,8)
+
+    def test_iti0k0(self):
+        iti0 = array(iti0k0(5))
+        assert_array_almost_equal(iti0,array([31.848667776169801, 1.5673873907283657]),5)
+
+    def test_it2i0k0(self):
+        it2k = it2i0k0(.1)
+        assert_array_almost_equal(it2k,array([0.0012503906973464409, 3.3309450354686687]),6)
+
+    def test_iv(self):
+        iv1 = iv(0,.1)*exp(-.1)
+        assert_almost_equal(iv1,0.90710092578230106,10)
+
+    def test_negv(self):
+        assert_equal(ive(3,2), ive(-3,2))
+
+    def test_ive(self):
+        ive1 = ive(0,.1)
+        iv1 = iv(0,.1)*exp(-.1)
+        assert_almost_equal(ive1,iv1,10)
+
+    def test_ivp0(self):
+        assert_almost_equal(iv(1,2), ivp(0,2), 10)
+
+    def test_ivp(self):
+        y=(iv(0,2)+iv(2,2))/2
+        x = ivp(1,2)
+        assert_almost_equal(x,y,10)
+    
+
 class TestLaguerre(TestCase):
     def test_laguerre(self):
         lag0 = laguerre(0)
@@ -2017,13 +2052,13 @@
         for v in [-20, -10, -7.99, -3.4, -1, 0, 1, 3.4, 12.49, 16]:
             for z in [1, 10, 19, 21, 30]:
                 value, err = self._series(v, z)
-                assert allclose(struve(v, z), value, atol=err), (v, z)
+                assert_tol_equal(struve(v, z), value, rtol=0, atol=err), (v, z)
 
     def test_some_values(self):
-        assert_almost_equal(struve(-7.99, 21), 0.0467547614113, decimal=8)
-        assert_almost_equal(struve(-8.01, 21), 0.0398716951023, decimal=9)
-        assert_almost_equal(struve(-3.0, 200), 0.0142134427432, decimal=13)
-        assert_almost_equal(struve(-8.0, -41), 0.0192469727846, decimal=9)
+        assert_tol_equal(struve(-7.99, 21), 0.0467547614113, rtol=1e-7)
+        assert_tol_equal(struve(-8.01, 21), 0.0398716951023, rtol=1e-8)
+        assert_tol_equal(struve(-3.0, 200), 0.0142134427432, rtol=1e-12)
+        assert_tol_equal(struve(-8.0, -41), 0.0192469727846, rtol=1e-11)
         assert_equal(struve(-12, -41), -struve(-12, 41))
         assert_equal(struve(+12, -41), -struve(+12, 41))
         assert_equal(struve(-11, -41), +struve(-11, 41))
@@ -2034,9 +2069,9 @@
 
     def test_regression_679(self):
         """Regression test for #679"""
-        assert_almost_equal(struve(-1.0, 20 - 1e-8), struve(-1.0, 20 + 1e-8))
-        assert_almost_equal(struve(-2.0, 20 - 1e-8), struve(-2.0, 20 + 1e-8))
-        assert_almost_equal(struve(-4.3, 20 - 1e-8), struve(-4.3, 20 + 1e-8))
+        assert_tol_equal(struve(-1.0, 20 - 1e-8), struve(-1.0, 20 + 1e-8))
+        assert_tol_equal(struve(-2.0, 20 - 1e-8), struve(-2.0, 20 + 1e-8))
+        assert_tol_equal(struve(-4.3, 20 - 1e-8), struve(-4.3, 20 + 1e-8))
 
 if __name__ == "__main__":
     run_module_suite()




More information about the Scipy-svn mailing list