在python中拟合多元curve_fit

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/20769340/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-18 21:10:46  来源:igfitidea点击:

fitting multivariate curve_fit in python

pythonscipycurve-fitting

提问by user3133865

I'm trying to fit a simple function to two arrays of independent data in python. I understand that I need to bunch the data for my independent variables into one array, but something still seems to be wrong with the way I'm passing variables when I try to do the fit. (There are a couple previous posts related to this one, but they haven't been much help.)

我正在尝试将一个简单的函数拟合到 python 中的两个独立数据数组。我知道我需要将自变量的数据集中到一个数组中,但是当我尝试进行拟合时,我传递变量的方式似乎仍然有问题。(之前有几篇与此相关的帖子,但它们并没有太大帮助。)

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x_3d, a, b, c, d):
    return a + b*x_3d[0,:] + c*x_3d[1,:] + d*x_3d[0,:]*x_3d[1,:]

x_3d = np.array([[1,2,3],[4,5,6]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d[:2,:], x_3d[2,:], p0)
print ' fit coefficients:\n', fitParams

The error I get reads,

我读到的错误,

raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m)) 
TypeError: Improper input: N=4 must not exceed M=3

What is Mthe length of? Is Nthe length of p0? What am I doing wrong here?

什么是M长度?是N的长度p0?我在这里做错了什么?

采纳答案by chthonicdaemon

N and M are defined in the helpfor the function. N is the number of data points and M is the number of parameters. Your error therefore basically means you need at least as many data points as you have parameters, which makes perfect sense.

N 和 M在函数的帮助中定义。N 是数据点的数量,M 是参数的数量。因此,您的错误基本上意味着您至少需要与参数一样多的数据点,这是完全有道理的。

This code works for me:

这段代码对我有用:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x, a, b, c, d):
    return a + b*x[0] + c*x[1] + d*x[0]*x[1]

x_3d = np.array([[1,2,3,4,6],[4,5,6,7,8]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d, x_3d[1,:], p0)
print ' fit coefficients:\n', fitParams

I have included more data. I have also changed fitFuncto be written in a form that scans as only being a function of a single x - the fitter will handle calling this for all the data points. The code as you posted also referenced x_3d[2,:], which was causing an error.

我已经包含了更多的数据。我也更改fitFunc为以扫描为仅作为单个 x 的函数的形式编写 - 钳工将处理对所有数据点的调用。您发布的代码也引用了x_3d[2,:],这导致了错误。

回答by Chris Young

The default curve_fitmethod needs you to have fewer parameters for the fitted function fitFuncthan data points. I had the same problem fitting a function that took 15 parameters in total and I had only 13 data points. The solution is to use another method (e.g. dogboxor trf).

默认curve_fit方法需要拟合函数的参数fitFunc少于数据点。我在拟合一个总共有 15 个参数的函数时遇到了同样的问题,我只有 13 个数据点。解决方案是使用另一种方法(例如dogboxtrf)。