我有一个代码,可以从多个FITS文件中提取(2D图像)数据,对其进行过采样,然后尝试使用scipy的曲线拟合模块将其拟合到2D高斯函数。我首先仅在一个文件上使用了此适合的代码,并且可以正常运行。因此,我尝试添加循环,使其可以在许多文件上使用,但是现在出现错误。我已经阅读了文档,却看不到为什么会收到此特定错误,这使我认为循环可能有问题。下面是我的完整代码和错误消息。
import glob
import math
import numpy as np
from scipy import interpolate, ndimage, spatial, optimize
import matplotlib.pyplot as plt
import astropy
from astropy.nddata import Cutout2D
from astropy.io import fits
image_files = glob.glob('/home/<username>/Desktop/myfolder/*.fits')
# Get data from all FITS images (each band has 27 images)
initial_data = [[], [], [], []]
for image in image_files:
img_data = np.nan_to_num(fits.getdata(image))
if 'w1' in image:
initial_data[0].append(img_data)
elif 'w2' in image:
initial_data[1].append(img_data)
elif 'w3' in image:
initial_data[2].append(img_data)
elif 'w4' in image:
initial_data[3].append(img_data) # list of arrays
# Make small centered cutouts and oversample 5x5 (or 5.5x5.5 for w4)
def oversample(band, N): # pixels -> NxN pixels
postage_stamps, oversampled_data = [], []
size = 25
for i in range(len(initial_data[band-1])):
geom_ctr = (np.shape(initial_data[band-1][i])[0]//2, np.shape(initial_data[band-1][i])[1]//2)
cutout = Cutout2D(initial_data[band-1][i], geom_ctr, size).data
postage_stamps.append(cutout)
for data_set in postage_stamps:
Y, X = np.shape(data_set)
x = np.linspace(0, 0.5, X)
y = np.linspace(0, 0.5, Y)
f = interpolate.interp2d(x, y, data_set, kind='cubic')
Xnew = np.linspace(0, 0.5, X*N)
Ynew = np.linspace(0, 0.5, Y*N)
new_data = f(Xnew, Ynew)
oversampled_data.append(new_data)
return oversampled_data # list of arrays
resampled_data = [oversample(1, 5), oversample(2, 5), oversample(3, 5), oversample(4, 5.5)]
# Fit to 2D Gaussian
def gaussian_func(xy, x0, y0, sigma, amp): # (x_0, y_0) is center
x, y = xy
offset = np.min(data_set)
a = 1/(2 * sigma**2)
c = 1/(2 * sigma**2)
exp_term = a * (x-x0)**2
exp_term += c * (y-y0)**2
return offset + amp * np.exp(-exp_term)
def generate(x0, y0, sigma, amp):
x = np.arange(int(np.min(data_set)), max(x0, y0)*2 + sigma, 1)
y = np.arange(int(np.min(data_set)), max(x0, y0)*2 + sigma, 1)
xx, yy = np.meshgrid(x, y)
z = gaussian_func((xx, yy), x0, y0, sigma, amp)
return xx, yy, z
def fit_to_model(data):
# Guess parameters - I want it to look in the middle of the image
sigma = np.std(data)
x0 = np.shape(data)[1]//2
y0 = np.shape(data)[0]//2 #(y, x) center of image
amp = np.max(data)
guesses = [x0, y0, sigma, amp]
xx, yy, z = generate(x0, y0, sigma, amp)
pred_params, uncert_cov = optimize.curve_fit(gaussian_func, (xx.ravel(), yy.ravel()), z.ravel(), p0=guesses)
return pred_params
params_by_band = [[], [], [], []]
for i in range(4):
for data_set in resampled_data[i]:
params_by_band[i].append(fit_to_model(data_set))
奇怪的是,它运行良好,直到到达resampled_data[2]
中对应于第三波段的第一个数据集为止。我检查了这个数据,它是一个与以前形状相同的2D数组(无论如何,它必须是相同的形状,因为在代码前面它们都被制成了相同大小的切口)。然后我得到这个错误:
Traceback (most recent call last):
File "./gaussian_v3", line 107, in <module>
params_by_band[i].append(fit_to_model(data_set))
File "./gaussian_v3", line 95, in fit_to_model
pred_params, uncert_cov = optimize.curve_fit(gaussian_func, (xx.ravel(), yy.ravel()), z.ravel(), p0=guesses)
File "/usr/local/anaconda3/lib/python3.6/site-packages/scipy/optimize/minpack.py", line 751, in curve_fit
res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
File "/usr/local/anaconda3/lib/python3.6/site-packages/scipy/optimize/minpack.py", line 386, in leastsq
raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))
TypeError: Improper input: N=4 must not exceed M=0
我真的不明白是什么原因造成的。有人可以帮忙吗?