[SciPy-User] multidimensional least squares fitting
Daniel Mader
danielstefanmader at googlemail.com
Wed Nov 9 12:12:22 EST 2011
Hi everyone,
I'd like to do some rather simple multidimensional curve fitting.
Simple, because usually it's only a plane or a slightly curved
second-order surface.
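Written out, both models are linear in their parameters. Reading "second-order surface" as a full quadratic in x and y is an assumption. However the measured points are arranged (1-D list or 2-D grid), the least-squares objective is a single sum over all N points (x_k, y_k, z_k):

```latex
\begin{aligned}
f_{\mathrm{plane}}(x, y) &= p_0 + p_1 x + p_2 y \\
f_{\mathrm{quad}}(x, y)  &= p_0 + p_1 x + p_2 y + p_3 x^2 + p_4 x y + p_5 y^2 \\
\hat{p} &= \arg\min_{p} \sum_{k=1}^{N} \bigl( f(x_k, y_k; p) - z_k \bigr)^2
\end{aligned}
```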
The same question was asked on Stack Overflow. I tried to follow it,
but I have no idea how to feed the 2D arrays into leastsq():
http://stackoverflow.com/questions/529184/simple-multidimensional-curve-fitting
I'm probably just missing a small piece of information, and I'd be
happy to get a hint :)
Here's some code to demonstrate what I want and to get started.
Thanks in advance,
Daniel
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

##******************************************************************************
##******************************************************************************
## fit model: f = p0 + p1*x + p2*y
##------------------------------------------------------------------------------
def __residual(params, f, x, y):
    '''
    Define fit function;
    return residual error.
    '''
    p0, p1, p2 = params
    return p0 + p1*x + p2*y - f

## load raw data (= create some dummy data):
dataX = np.arange(0, 11, 1)
dataY = dataX/10.
dataZ = 0.5 + 1.1*dataX + 1.5*dataY
dataXX, dataYY = np.meshgrid(dataX, dataY)
dataZZ = 0.5 + 1.1*dataXX + 1.5*dataYY

## plot data
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(dataXX, dataYY, dataZZ)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

## guess initial values for parameters
p0 = [0., 1., 1.]
print(__residual(p0, dataZZ, dataXX, dataYY))

## works, but is not 2D!
## (without full_output=True, leastsq returns the solution and an integer
## status flag, not a covariance matrix)
p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZ, dataX, dataY))
print(p1)

## doesn't work :(
p1, ier = scipy.optimize.leastsq(__residual, x0=p0, args=(dataZZ, dataXX, dataYY))
print(p1)
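A minimal sketch of one way to make the 2D call work, assuming the only problem is the shape of the residual: leastsq (a wrapper around MINPACK) expects the residual function to return a 1-D array of M values, so the grid of residuals can be flattened with numpy.ravel before it is returned. The name plane_residual is illustrative, not part of SciPy.

```python
import numpy as np
from scipy.optimize import leastsq


def plane_residual(params, f, x, y):
    """Residual of the plane model p0 + p1*x + p2*y, flattened to 1-D.

    x, y and f may be 2-D grids; leastsq needs a 1-D vector of residuals,
    so the grid is raveled into an array of length x.size.
    """
    p0, p1, p2 = params
    return np.ravel(p0 + p1 * x + p2 * y - f)


# Same dummy data as in the script: an 11 x 11 grid where x and y vary independently.
dataX = np.arange(0, 11, 1)
dataY = dataX / 10.0
dataXX, dataYY = np.meshgrid(dataX, dataY)
dataZZ = 0.5 + 1.1 * dataXX + 1.5 * dataYY

p_init = [0.0, 1.0, 1.0]
# Without full_output, leastsq returns (solution, integer status flag).
p_fit, ier = leastsq(plane_residual, x0=p_init, args=(dataZZ, dataXX, dataYY))
print(p_fit)  # close to [0.5, 1.1, 1.5]
```

The sum of squared residuals does not depend on the order of the grid points, so flattening does not change the fit, and x and y can stay 2-D inside the function.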
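Because both the plane and the quadratic surface are linear in their parameters, a nonlinear solver is not strictly needed. The sketch below swaps in a different technique, ordinary linear least squares with numpy.linalg.lstsq on a design matrix built from the flattened grid; the helper design_matrix is hypothetical. It also checks the matrix rank for the 1-D test data: there dataY = dataX/10, so the x and y columns are collinear and p1 and p2 cannot be separated, even though the 1-D leastsq call runs.

```python
import numpy as np


def design_matrix(x, y, order=1):
    """Hypothetical helper: columns for a plane (order=1) or a full quadratic (order=2).

    x and y may have any shape; they are flattened so that each data point
    becomes one row of the matrix.
    """
    x = np.ravel(x)
    y = np.ravel(y)
    cols = [np.ones_like(x, dtype=float), x, y]
    if order == 2:
        cols += [x**2, x * y, y**2]
    return np.column_stack(cols)


dataX = np.arange(0, 11, 1)
dataY = dataX / 10.0
dataXX, dataYY = np.meshgrid(dataX, dataY)
dataZZ = 0.5 + 1.1 * dataXX + 1.5 * dataYY

# Plane fit on the 2-D grid: solve A @ p ~= z in the least-squares sense.
A = design_matrix(dataXX, dataYY)
p_plane, _, rank_grid, _ = np.linalg.lstsq(A, np.ravel(dataZZ), rcond=None)
print(p_plane)  # close to [0.5, 1.1, 1.5]

# Same plane model on the 1-D data: y = x/10, so the x and y columns are
# collinear and the matrix has rank 2 instead of 3.
rank_1d = np.linalg.matrix_rank(design_matrix(dataX, dataY))
print(rank_grid, rank_1d)  # 3 2

# Quadratic surface on the same grid (the true curvature terms are zero here).
A2 = design_matrix(dataXX, dataYY, order=2)
p_quad = np.linalg.lstsq(A2, np.ravel(dataZZ), rcond=None)[0]
print(p_quad)
```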