我有以下2D高斯函数的定义:
# Bivariate Gaussian model adapted from astroML, for use with curve_fit.
def mult_gaussFun_Fit(xy, *m):
    """Evaluate a bivariate Gaussian on a grid and return it flattened.

    Designed as the model function for ``scipy.optimize.curve_fit``, which
    calls it as ``f(xdata, *params)``.

    Parameters
    ----------
    xy : tuple (x, y)
        1-D coordinate arrays.  The grid is built with ``np.meshgrid(x, y)``
        (default ``'xy'`` indexing), so ``Z[j, i]`` is evaluated at
        ``(x[i], y[j])``.
    *m : 7 numbers
        ``A, x0, y0, varx, vary, rho, alpha``: amplitude, centre, variances
        along x and y, and correlation coefficient.  ``alpha`` is accepted
        for signature compatibility but is NOT used.  The tilt of the
        ellipse comes entirely from ``rho``.  Because the output does not
        depend on ``alpha``, curve_fit will likely be unable to estimate its
        variance.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``len(x) * len(y)``: the ``(len(y), len(x))``
        grid raveled in C order.  ``np.histogram2d(x, y)`` returns counts
        indexed ``(x-bin, y-bin)``, so compare against ``H.T.ravel()``.

    Raises
    ------
    ValueError
        If ``rho`` is +1 or -1, where ``1 / (2 * (1 - rho**2))`` divides by
        zero (degenerate covariance).
    """
    # Unpack here: tuple parameters in a def are a SyntaxError in Python 3
    # (PEP 3113).
    x, y = xy
    A, x0, y0, varx, vary, rho, alpha = m
    # Both +1 and -1 make the covariance singular.  An explicit raise also
    # survives ``python -O``, unlike assert.
    if rho ** 2 == 1:
        raise ValueError("rho must not be +/-1 (degenerate covariance)")
    X, Y = np.meshgrid(x, y)
    a = 1 / (2 * (1 - rho ** 2))
    dx = X - x0
    dy = Y - y0
    Z = A * np.exp(-a * (dx ** 2 / varx + dy ** 2 / vary
                         - (2 * rho / np.sqrt(varx * vary)) * dx * dy))
    return Z.ravel()
我使用以下代码对一组数据做 curve_fit 拟合。这些数据从二元高斯分布中抽样,再转换为二维直方图。运行时出现了广播错误,但我不确定原因。
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import gauss
import plotutils
# Draw N points in x-y from a single bivariate normal distribution.
mean = [0, 0]
cov = [[3, 0], [0, 1]]
N = 3000
x, y = np.random.multivariate_normal(mean, cov, N).T

# Prepare histogram bin edges spanning +/- 2.5 * (sigma_x + sigma_y).
bin_size = 0.2
max_edge = 2.5 * (np.sqrt(cov[0][0]) + np.sqrt(cov[1][1]))
min_edge = -max_edge
# np.linspace needs an integer sample count (a float raises TypeError on
# NumPy >= 1.18).  int() truncates, which is what older NumPy did implicitly.
bin_num = int((max_edge - min_edge) / bin_size)
bin_numPlus1 = bin_num + 1
bins = np.linspace(min_edge, max_edge, bin_numPlus1)

# Produce the 2D histogram of raw counts.  The ``normed`` keyword was
# removed from np.histogram2d in NumPy 1.24; ``density=False`` is the
# equivalent.
H, xedges, yedges = np.histogram2d(x, y, bins, density=False)
bin_centers_x = (xedges[:-1] + xedges[1:]) / 2.0
bin_centers_y = (yedges[:-1] + yedges[1:]) / 2.0

# Initial guess, in model order: A, x0, y0, varx, vary, rho, alpha.
p0 = (H.max(), mean[0], mean[1], cov[0][0], cov[1][1], 0.5, np.pi / 4)

# curve_fit needs ydata shaped like the model output, which is 1-D.
# H[i, j] is indexed (x-bin, y-bin).  The model (presumably the
# mult_gaussFun_Fit in this post) builds its grid with np.meshgrid(x, y),
# whose rows are y.  So transpose before raveling: plain H.ravel() fits
# without error but swaps the x and y parameters.
coeff, var_matrix = curve_fit(gauss.mult_gaussFun_Fit,
                              (bin_centers_x, bin_centers_y),
                              H.T.ravel(), p0=p0)
错误是:
Traceback (most recent call last):
File "/home/luis/Documents/SRC2014/galsim_work/2D_Gaussian_Estimate.py", line 44, in <module>
coeff, var_matrix = curve_fit(gauss.mult_gaussFun_Fit,(bin_centers_x,bin_centers_y),H,p0=p0)
File "/usr/local/lib/python2.7/dist-packages/scipy/optimize/minpack.py", line 555, in curve_fit
res = leastsq(func, p0, args=args, full_output=1, **kw)
File "/usr/local/lib/python2.7/dist-packages/scipy/optimize/minpack.py", line 369, in leastsq
shape, dtype = _check_func('leastsq', 'func', func, x0, args, n)
File "/usr/local/lib/python2.7/dist-packages/scipy/optimize/minpack.py", line 20, in _check_func
res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
File "/usr/local/lib/python2.7/dist-packages/scipy/optimize/minpack.py", line 445, in _general_function
return function(xdata, *params) - ydata
ValueError: operands could not be broadcast together with shapes (4624) (68,68)
答案 0 :(得分:0)
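我只需要执行
H = H.ravel()
就解决了这个报错。原因是模型函数返回的是展平后的一维数组(长度 4624 = 68×68),而作为 ydata 传入的 H 仍是 (68, 68) 的二维数组,二者无法广播相减。
需要注意的是:np.histogram2d 返回的 H 以 x 为第一维(H[i, j] 对应第 i 个 x bin、第 j 个 y bin),而 np.meshgrid(x, y) 默认的 'xy' 索引使模型网格以 y 为第一维。因此更准确的写法是 H = H.T.ravel()(或者在模型中改用 np.meshgrid(x, y, indexing='ij'))。否则程序虽然不再报错,拟合得到的 x、y 参数却会互换。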