[Scipy-svn] r6823 - in trunk/scipy/optimize: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Sep 25 20:56:47 EDT 2010


Author: warren.weckesser
Date: 2010-09-25 19:56:47 -0500 (Sat, 25 Sep 2010)
New Revision: 6823

Modified:
   trunk/scipy/optimize/minpack.py
   trunk/scipy/optimize/tests/test_minpack.py
Log:
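BUG: optimize: the stopping condition in fixed_point tested 'relerr < xtol' without taking 'abs' of relerr, so any negative relative error satisfied the test and the iteration stopped prematurely.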

Modified: trunk/scipy/optimize/minpack.py
===================================================================
--- trunk/scipy/optimize/minpack.py	2010-09-24 07:48:32 UTC (rev 6822)
+++ trunk/scipy/optimize/minpack.py	2010-09-26 00:56:47 UTC (rev 6823)
@@ -3,7 +3,7 @@
 
 from numpy import atleast_1d, dot, take, triu, shape, eye, \
                   transpose, zeros, product, greater, array, \
-                  all, where, isscalar, asarray, inf
+                  all, where, isscalar, asarray, inf, abs
 
 error = _minpack.error
 
@@ -480,7 +480,7 @@
             d = p2 - 2.0 * p1 + p0
             p = where(d == 0, p2, p0 - (p1 - p0)*(p1 - p0) / d)
             relerr = where(p0 == 0, p, (p-p0)/p0)
-            if all(relerr < xtol):
+            if all(abs(relerr) < xtol):
                 return p
             p0 = p
     else:
@@ -497,7 +497,7 @@
                 relerr = p
             else:
                 relerr = (p - p0)/p0
-            if relerr < xtol:
+            if abs(relerr) < xtol:
                 return p
             p0 = p
     msg = "Failed to converge after %d iterations, value is %s" % (maxiter, p)

Modified: trunk/scipy/optimize/tests/test_minpack.py
===================================================================
--- trunk/scipy/optimize/tests/test_minpack.py	2010-09-24 07:48:32 UTC (rev 6822)
+++ trunk/scipy/optimize/tests/test_minpack.py	2010-09-26 00:56:47 UTC (rev 6823)
@@ -8,8 +8,9 @@
 from numpy import array, float64
 
 from scipy import optimize
-from scipy.optimize.minpack import leastsq, curve_fit
+from scipy.optimize.minpack import leastsq, curve_fit, fixed_point
 
+
 class TestFSolve(object):
     def pressure_network(self, flow_rates, Qtot, k):
         """Evaluate non-linear equation system representing
@@ -82,6 +83,7 @@
             fprime=self.pressure_network_jacobian)
         assert_array_almost_equal(final_flows, np.ones(4))
 
+
 class TestLeastSq(TestCase):
     def setUp(self):
         x = np.linspace(0, 10, 40)
@@ -123,6 +125,7 @@
         assert_(ier in (1,2,3,4), 'solution not found: %s'%mesg)
         assert_array_equal(p0, p0_copy)
 
+
 class TestCurveFit(TestCase):
     def setUp(self):
         self.y = array([1.0, 3.2, 9.5, 13.7])
@@ -147,7 +150,57 @@
         assert_array_almost_equal(pcov, [[0.0852, -0.1260],[-0.1260, 0.1912]], decimal=4)
 
 
+class TestFixedPoint(TestCase):
 
+    def test_scalar_trivial(self):
+        """f(x) = 2x; fixed point should be x=0"""
+        def func(x):
+            return 2.0*x
+        x0 = 1.0
+        x = fixed_point(func, x0)
+        assert_almost_equal(x, 0.0)
 
+    def test_scalar_basic1(self):
+        """f(x) = x**2; x0=1.05; fixed point should be x=1"""
+        def func(x):
+            return x**2
+        x0 = 1.05
+        x = fixed_point(func, x0)
+        assert_almost_equal(x, 1.0)
+
+    def test_scalar_basic2(self):
+        """f(x) = x**0.5; x0=1.05; fixed point should be x=1"""
+        def func(x):
+            return x**0.5
+        x0 = 1.05
+        x = fixed_point(func, x0)
+        assert_almost_equal(x, 1.0)
+
+    def test_array_trivial(self):
+        def func(x):
+            return 2.0*x
+        x0 = [0.3, 0.15]
+        x = fixed_point(func, x0)
+        assert_almost_equal(x, [0.0, 0.0])
+
+    def test_array_basic1(self):
+        """f(x) = c * x**2; fixed point should be x=1/c"""
+        def func(x, c):
+            return c * x**2
+        c = array([0.75, 1.0, 1.25])
+        x0 = [1.1, 1.15, 0.9]
+        x = fixed_point(func, x0, args=(c,))
+        assert_almost_equal(x, 1.0/c)
+
+    def test_array_basic2(self):
+        """f(x) = c * x**0.5; fixed point should be x=c**2"""
+        def func(x, c):
+            return c * x**0.5
+        c = array([0.75, 1.0, 1.25])
+        x0 = [0.8, 1.1, 1.1]
+        x = fixed_point(func, x0, args=(c,))
+        assert_almost_equal(x, c**2)
+
+
 if __name__ == "__main__":
     run_module_suite()




More information about the Scipy-svn mailing list