scipy.optimize.curve_fit()没有正确传递2d数据

时间:2018-02-21 06:11:21

标签: python scipy curve-fitting

我正在尝试使用curve_fit()将2d函数拟合到数组中,但我在使用ydata参数时遇到了一些问题。 (基本上我的功能是尝试将两个2d斜率拟合到数据中,一个覆盖整个数据集,另一个斜面平坦到div1, div2定义的点。)

这是我的代码:

def two_slope_regression((X1,X2), data, div1, div2):
    #Set up parameters
    m = 1016
    print 'data', data
    print 'data type', type(data)
    div1 = int(div1)
    div2 = int(div2)
    print 'div1, div2', div1, div2
    X1div = np.zeros_like(X1)
    X2div = np.zeros_like(X2)
    X1div[div1:] = X1[div1:]
    X2div[div2:] = X2[div2:]
    Y = np.array(data)

    #Regression
    X = np.hstack((np.reshape(X1, (m*m, 1)), np.reshape(X2, (m*m, 1)),
               np.reshape(X1div, (m*m, 1)), np.reshape(X2div, (m*m, 1))) 
             )
    X = np.hstack((np.ones((m*m, 1)), X))
    YY = np.reshape(Y, (m*m, 1))
    theta = np.dot(np.dot(np.linalg.pinv(np.dot(X.transpose(), X)), X.transpose()), YY)
    print 'theta', theta

    y = theta[0] + theta[1]*X1 + theta[2]*X2 + theta[3]*X1div + theta[4]*X2div
    return y.ravel()

def fit_slope(data):
    m=1016
    print 'data type', type(data)
    X1, X2 = np.mgrid[:m, :m]
    print 'data shape', data.shape
    #init_guess = (508, 508)
    popt, pcov = opt.curve_fit(two_slope_regression, xdata=(X1,X2), ydata=data)#, p0=(508,508))
    print 'popt, pcov', popt, pcov
    return

if __name__ == '__main__':
    theta_samp = np.random.rand(5)
    x1,x2 = np.mgrid[:1016, :1016]
    x1div = np.zeros_like(x1)
    x1div[508:] = x1[508:]
    x2div = np.zeros_like(x2)
    x2div[508:] = x2[508:]
    sample_data = np.random.randn(1016, 1016)*theta_samp[0] + theta_samp[1]*x1 + theta_samp[2]*x2 + theta_samp[3]*x1div + theta_samp[4]*x2div    
    fit_slope(sample_data)

如您所见,我已经包含print语句来检查不同阶段的数据数组的类型和形状。第一次是fit_slope() - 类型为<type 'numpy.ndarray'>,形状为(1016,1016)。到现在为止还挺好。 第二次是curve_fit()调用two_slope_regression()并在错误data 1.0之前输出data type <type 'numpy.float64'>ValueError: cannot reshape array of size 1 into shape (1032256,1)

所以某种方式ydata参数没有正确传递给fit函数。出了什么问题?

我不知道这是否相关,但我之前遇到了p0参数(现已注释掉)的另一个问题。当我尝试传递fit参数时,我收到以下错误:

TypeError: two_slope_regression() takes exactly 4 arguments (3 given)

据我所知,我正在使用curve_fit()作为文档要求。我缺少什么?

0 个答案:

没有答案