[Scipy-svn] r6824 - in trunk/scipy/optimize: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Sep 25 22:48:28 EDT 2010
Author: warren.weckesser
Date: 2010-09-25 21:48:28 -0500 (Sat, 25 Sep 2010)
New Revision: 6824
Modified:
trunk/scipy/optimize/minpack.py
trunk/scipy/optimize/tests/test_minpack.py
Log:
BUG: optimize: fix AttributeError in check_func() in minpack.py when the callable has no 'func_name' attribute (ticket #1287)
Modified: trunk/scipy/optimize/minpack.py
===================================================================
--- trunk/scipy/optimize/minpack.py 2010-09-26 00:56:47 UTC (rev 6823)
+++ trunk/scipy/optimize/minpack.py 2010-09-26 02:48:28 UTC (rev 6824)
@@ -9,15 +9,20 @@
__all__ = ['fsolve', 'leastsq', 'fixed_point', 'curve_fit']
-def check_func(thefunc, x0, args, numinputs, output_shape=None):
- res = atleast_1d(thefunc(*((x0[:numinputs],)+args)))
+def _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape=None):
+ res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
if (output_shape is not None) and (shape(res) != output_shape):
if (output_shape[0] != 1):
if len(output_shape) > 1:
if output_shape[1] == 1:
return shape(res)
- msg = "There is a mismatch between the input and output " \
- "shape of %s." % thefunc.func_name
+ msg = "%s: there is a mismatch between the input and output " \
+ "shape of the '%s' argument" % (checker, argname)
+ # __name__ (unlike func_name) also exists on builtins and Python 3 functions.
+ func_name = getattr(thefunc, '__name__', None)
+ if func_name:
+ msg += " '%s'" % func_name
+ msg += ". Shape should be %s but it is %s." % (output_shape, shape(res))
raise TypeError(msg)
return shape(res)
@@ -107,7 +112,7 @@
x0 = array(x0, ndmin=1)
n = len(x0)
if type(args) != type(()): args = (args,)
- check_func(func, x0, args, n, (n,))
+ _check_func('fsolve', 'func', func, x0, args, n, (n,))
Dfun = fprime
if Dfun is None:
if band is None:
@@ -119,7 +124,7 @@
retval = _minpack._hybrd(func, x0, args, full_output, xtol,
maxfev, ml, mu, epsfcn, factor, diag)
else:
- check_func(Dfun,x0,args,n,(n,n))
+ _check_func('fsolve', 'fprime', Dfun, x0, args, n, (n,n))
if (maxfev == 0):
maxfev = 100*(n + 1)
retval = _minpack._hybrj(func, Dfun, x0, args, full_output,
@@ -253,7 +258,7 @@
n = len(x0)
if type(args) != type(()):
args = (args,)
- m = check_func(func, x0, args, n)[0]
+ m = _check_func('leastsq', 'func', func, x0, args, n)[0]
if n > m:
raise TypeError('Improper input: N=%s must not exceed M=%s' % (n,m))
if Dfun is None:
@@ -263,9 +268,9 @@
gtol, maxfev, epsfcn, factor, diag)
else:
if col_deriv:
- check_func(Dfun, x0, args, n, (n,m))
+ _check_func('leastsq', 'Dfun', Dfun, x0, args, n, (n,m))
else:
- check_func(Dfun, x0, args, n, (m,n))
+ _check_func('leastsq', 'Dfun', Dfun, x0, args, n, (m,n))
if (maxfev == 0):
maxfev = 100*(n + 1)
retval = _minpack._lmder(func, Dfun, x0, args, full_output, col_deriv,
Modified: trunk/scipy/optimize/tests/test_minpack.py
===================================================================
--- trunk/scipy/optimize/tests/test_minpack.py 2010-09-26 00:56:47 UTC (rev 6823)
+++ trunk/scipy/optimize/tests/test_minpack.py 2010-09-26 02:48:28 UTC (rev 6824)
@@ -3,7 +3,7 @@
"""
from numpy.testing import assert_, assert_almost_equal, assert_array_equal, \
- assert_array_almost_equal, TestCase, run_module_suite
+ assert_array_almost_equal, TestCase, run_module_suite, assert_raises
import numpy as np
from numpy import array, float64
@@ -11,6 +11,25 @@
from scipy.optimize.minpack import leastsq, curve_fit, fixed_point
+class ReturnShape(object):
+ """A callable with neither a 'func_name' nor a '__name__' attribute.
+
+ __init__ takes the argument 'shape', an int or a tuple of ints. When an instance
+ is called with a single argument 'x', it returns numpy.ones(shape).
+ """
+ def __init__(self, shape):
+ self.shape = shape
+
+ def __call__(self, x):
+ return np.ones(self.shape)
+
+def dummy_func(x, shape):
+ """A function that returns an array of ones of the given shape.
+ `x` is ignored.
+ """
+ return np.ones(shape)
+
+
class TestFSolve(object):
def pressure_network(self, flow_rates, Qtot, k):
"""Evaluate non-linear equation system representing
@@ -83,7 +102,32 @@
fprime=self.pressure_network_jacobian)
assert_array_almost_equal(final_flows, np.ones(4))
+ def test_wrong_shape_func_callable(self):
+ """The callable 'func' has no 'func_name' attribute."""
+ func = ReturnShape(1)
+ # x0 is a list of two elements, but func will return an array with
+ # length 1, so this should result in a TypeError.
+ x0 = [1.5, 2.0]
+ assert_raises(TypeError, optimize.fsolve, func, x0)
+ def test_wrong_shape_func_function(self):
+ # x0 is a list of two elements, but func will return an array with
+ # length 1, so this should result in a TypeError.
+ x0 = [1.5, 2.0]
+ assert_raises(TypeError, optimize.fsolve, dummy_func, x0, args=((1,),))
+
+ def test_wrong_shape_fprime_callable(self):
+ """The callables 'func' and 'deriv_func' have no 'func_name' attribute."""
+ func = ReturnShape(2)  # correct shape, so only the fprime check can fail
+ deriv_func = ReturnShape((3,3))  # fsolve expects (n,n) == (2,2) here
+ assert_raises(TypeError, optimize.fsolve, func, x0=[0,1], fprime=deriv_func)
+
+ def test_wrong_shape_fprime_function(self):
+ func = lambda x: dummy_func(x, (2,))
+ deriv_func = lambda x: dummy_func(x, (3,3))
+ assert_raises(TypeError, optimize.fsolve, func, x0=[0,1], fprime=deriv_func)
+
+
class TestLeastSq(TestCase):
def setUp(self):
x = np.linspace(0, 10, 40)
@@ -125,7 +169,32 @@
assert_(ier in (1,2,3,4), 'solution not found: %s'%mesg)
assert_array_equal(p0, p0_copy)
+ def test_wrong_shape_func_callable(self):
+ """The callable 'func' has no 'func_name' attribute."""
+ func = ReturnShape(1)
+ # x0 is a list of two elements, but func will return an array with
+ # length 1 (M=1 < N=2), so this should result in a TypeError.
+ x0 = [1.5, 2.0]
+ assert_raises(TypeError, optimize.leastsq, func, x0)
+ def test_wrong_shape_func_function(self):
+ # x0 is a list of two elements, but func will return an array with
+ # length 1 (M=1 < N=2), so this should result in a TypeError.
+ x0 = [1.5, 2.0]
+ assert_raises(TypeError, optimize.leastsq, dummy_func, x0, args=((1,),))
+
+ def test_wrong_shape_Dfun_callable(self):
+ """The callables 'func' and 'deriv_func' have no 'func_name' attribute."""
+ func = ReturnShape(2)  # M == N == 2, so only the Dfun check can fail
+ deriv_func = ReturnShape((3,3))  # leastsq expects (m,n) == (2,2) here
+ assert_raises(TypeError, optimize.leastsq, func, x0=[0,1], Dfun=deriv_func)
+
+ def test_wrong_shape_Dfun_function(self):
+ func = lambda x: dummy_func(x, (2,))
+ deriv_func = lambda x: dummy_func(x, (3,3))
+ assert_raises(TypeError, optimize.leastsq, func, x0=[0,1], Dfun=deriv_func)
+
+
class TestCurveFit(TestCase):
def setUp(self):
self.y = array([1.0, 3.2, 9.5, 13.7])
More information about the Scipy-svn mailing list