[Scipy-svn] r6824 - in trunk/scipy/optimize: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Sep 25 22:48:28 EDT 2010


Author: warren.weckesser
Date: 2010-09-25 21:48:28 -0500 (Sat, 25 Sep 2010)
New Revision: 6824

Modified:
   trunk/scipy/optimize/minpack.py
   trunk/scipy/optimize/tests/test_minpack.py
Log:
BUG: optimize: AttributeError in check_func() in minpack.py (ticket #1287)

Modified: trunk/scipy/optimize/minpack.py
===================================================================
--- trunk/scipy/optimize/minpack.py	2010-09-26 00:56:47 UTC (rev 6823)
+++ trunk/scipy/optimize/minpack.py	2010-09-26 02:48:28 UTC (rev 6824)
@@ -9,15 +9,20 @@
 
 __all__ = ['fsolve', 'leastsq', 'fixed_point', 'curve_fit']
 
-def check_func(thefunc, x0, args, numinputs, output_shape=None):
-    res = atleast_1d(thefunc(*((x0[:numinputs],)+args)))
+def _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape=None):
+    res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
     if (output_shape is not None) and (shape(res) != output_shape):
         if (output_shape[0] != 1):
             if len(output_shape) > 1:
                 if output_shape[1] == 1:
                     return shape(res)
-            msg = "There is a mismatch between the input and output " \
-                  "shape of %s." % thefunc.func_name
+            msg = "%s: there is a mismatch between the input and output " \
+                  "shape of the '%s' argument" % (checker, argname)
+            func_name = getattr(thefunc, 'func_name', None)
+            if func_name:
+                msg += " '%s'." % func_name
+            else:
+                msg += "."
             raise TypeError(msg)
     return shape(res)
 
@@ -107,7 +112,8 @@
     x0 = array(x0, ndmin=1)
     n = len(x0)
     if type(args) != type(()): args = (args,)
-    check_func(func, x0, args, n, (n,))
+    _check_func('fsolve', 'func', func, x0, args, n, (n,))
+    #check_func(func, x0, args, n, (n,))
     Dfun = fprime
     if Dfun is None:
         if band is None:
@@ -119,7 +125,8 @@
         retval = _minpack._hybrd(func, x0, args, full_output, xtol,
                 maxfev, ml, mu, epsfcn, factor, diag)
     else:
-        check_func(Dfun,x0,args,n,(n,n))
+        _check_func('fsolve', 'fprime', Dfun, x0, args, n, (n,n))
+        # check_func(Dfun,x0,args,n,(n,n))
         if (maxfev == 0):
             maxfev = 100*(n + 1)
         retval = _minpack._hybrj(func, Dfun, x0, args, full_output,
@@ -253,7 +260,8 @@
     n = len(x0)
     if type(args) != type(()):
         args = (args,)
-    m = check_func(func, x0, args, n)[0]
+    m = _check_func('leastsq', 'func', func, x0, args, n)[0]
+    # m = check_func(func, x0, args, n)[0]
     if n > m:
         raise TypeError('Improper input: N=%s must not exceed M=%s' % (n,m))
     if Dfun is None:
@@ -263,9 +271,11 @@
                 gtol, maxfev, epsfcn, factor, diag)
     else:
         if col_deriv:
-            check_func(Dfun, x0, args, n, (n,m))
+            _check_func('leastsq', 'Dfun', Dfun, x0, args, n, (n,m))
+            # check_func(Dfun, x0, args, n, (n,m))
         else:
-            check_func(Dfun, x0, args, n, (m,n))
+            _check_func('leastsq', 'Dfun', Dfun, x0, args, n, (m,n))
+            # check_func(Dfun, x0, args, n, (m,n))
         if (maxfev == 0):
             maxfev = 100*(n + 1)
         retval = _minpack._lmder(func, Dfun, x0, args, full_output, col_deriv,

Modified: trunk/scipy/optimize/tests/test_minpack.py
===================================================================
--- trunk/scipy/optimize/tests/test_minpack.py	2010-09-26 00:56:47 UTC (rev 6823)
+++ trunk/scipy/optimize/tests/test_minpack.py	2010-09-26 02:48:28 UTC (rev 6824)
@@ -3,7 +3,7 @@
 """
 
 from numpy.testing import assert_, assert_almost_equal, assert_array_equal, \
-        assert_array_almost_equal, TestCase, run_module_suite
+        assert_array_almost_equal, TestCase, run_module_suite, assert_raises
 import numpy as np
 from numpy import array, float64
 
@@ -11,6 +11,25 @@
 from scipy.optimize.minpack import leastsq, curve_fit, fixed_point
 
 
+class ReturnShape(object):
+    """This class exists to create a callable that does not have a 'func_name' attribute.
+    
+    __init__ takes the argument 'shape', which should be a tuple of ints.  When an instance
+    is called with a single argument 'x', it returns numpy.ones(shape).
+    """
+    def __init__(self, shape):
+        self.shape = shape
+
+    def __call__(self, x):
+        return np.ones(self.shape)
+
+def dummy_func(x, shape):
+    """A function that returns an array of ones of the given shape.
+    `x` is ignored.
+    """
+    return np.ones(shape)
+
+
 class TestFSolve(object):
     def pressure_network(self, flow_rates, Qtot, k):
         """Evaluate non-linear equation system representing
@@ -83,7 +102,32 @@
             fprime=self.pressure_network_jacobian)
         assert_array_almost_equal(final_flows, np.ones(4))
 
+    def test_wrong_shape_func_callable(self):
+        """The callable 'func' has no 'func_name' attribute."""
+        func = ReturnShape(1)
+        # x0 is a list of two elements, but func will return an array with
+        # length 1, so this should result in a TypeError.
+        x0 = [1.5, 2.0]
+        assert_raises(TypeError, optimize.fsolve, func, x0)
 
+    def test_wrong_shape_func_function(self):
+        # x0 is a list of two elements, but func will return an array with
+        # length 1, so this should result in a TypeError.
+        x0 = [1.5, 2.0]
+        assert_raises(TypeError, optimize.fsolve, dummy_func, x0, args=((1,),))
+
+    def test_wrong_shape_fprime_callable(self):
+        """The callables 'func' and 'deriv_func' have no 'func_name' attribute."""
+        func = ReturnShape(1)
+        deriv_func = ReturnShape((2,2))
+        assert_raises(TypeError, optimize.fsolve, func, x0=[0,1], fprime=deriv_func)
+
+    def test_wrong_shape_fprime_function(self):
+        func = lambda x: dummy_func(x, (2,))
+        deriv_func = lambda x: dummy_func(x, (3,3))
+        assert_raises(TypeError, optimize.fsolve, func, x0=[0,1], fprime=deriv_func)
+
+
 class TestLeastSq(TestCase):
     def setUp(self):
         x = np.linspace(0, 10, 40)
@@ -125,7 +169,32 @@
         assert_(ier in (1,2,3,4), 'solution not found: %s'%mesg)
         assert_array_equal(p0, p0_copy)
 
+    def test_wrong_shape_func_callable(self):
+        """The callable 'func' has no 'func_name' attribute."""
+        func = ReturnShape(1)
+        # x0 is a list of two elements, but func will return an array with
+        # length 1, so this should result in a TypeError.
+        x0 = [1.5, 2.0]
+        assert_raises(TypeError, optimize.leastsq, func, x0)
 
+    def test_wrong_shape_func_function(self):
+        # x0 is a list of two elements, but func will return an array with
+        # length 1, so this should result in a TypeError.
+        x0 = [1.5, 2.0]
+        assert_raises(TypeError, optimize.leastsq, dummy_func, x0, args=((1,),))
+
+    def test_wrong_shape_Dfun_callable(self):
+        """The callables 'func' and 'deriv_func' have no 'func_name' attribute."""
+        func = ReturnShape(1)
+        deriv_func = ReturnShape((2,2))
+        assert_raises(TypeError, optimize.leastsq, func, x0=[0,1], Dfun=deriv_func)
+
+    def test_wrong_shape_Dfun_function(self):
+        func = lambda x: dummy_func(x, (2,))
+        deriv_func = lambda x: dummy_func(x, (3,3))
+        assert_raises(TypeError, optimize.leastsq, func, x0=[0,1], Dfun=deriv_func)
+
+
 class TestCurveFit(TestCase):
     def setUp(self):
         self.y = array([1.0, 3.2, 9.5, 13.7])




More information about the Scipy-svn mailing list