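I have the following input file and script for adding a regression line to a plot, but the code fails with this error: ValueError: x and y must have same first dimension. I cannot figure out what causes it.
How can I add a regression line and the regression line equation to the plot?
Input file: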
-5.06 -4.27
-6.69 -7.28
-3.80 -3.51
-3.88 -2.79
-0.90 -0.81
2.10 2.59
-1.08 0.28
-5.00 -3.39
2.67 2.92
2.48 2.85
-5.10 -3.49
2.88 3.09
2.30 2.67
-3.47 -2.20
-0.90 -0.79
Script:
#!/usr/bin/python
import numpy as np
import pylab as plot
import matplotlib.pyplot as plt
import numpy, scipy, pylab, random
from matplotlib.ticker import MultipleLocator
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
from scipy import stats
with open("input.txt", "r") as f:
    x=[]
    y=[]
    for line in f:
        if not line.strip() or line.startswith('@') or line.startswith('#'): continue
        row = line.split()
        x.append(float(row[0]))
        y.append(float(row[1]))
fig = plt.figure(figsize=(2.2,2.2), dpi=300)
ax = plt.subplot(111)
plt.xlim(4, -8)
plt.ylim(4, -8)
ax.xaxis.set_major_locator(MaxNLocator(6))
ax.yaxis.set_major_locator(MaxNLocator(6))
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.yaxis.set_minor_locator(MultipleLocator(1))
#regression part
slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
line = slope*x+intercept
plt.plot(x, line, 'r', label='fitted line')
#end
plt.scatter(x,y,color=['black','black','black','black','black','black','black','black','black','black','black','black','black','black','black'], s=3.5)
plt.savefig("output.png", dpi=300)
Answer 0 (score: 4)
You cannot multiply a Python list by a float. Instead, create a numpy array from the input list x:
line = slope*np.array(x)+intercept
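To make the difference concrete, here is a minimal sketch. The values of slope, intercept and x are made up for illustration and are not taken from the question's data; slope is a plain Python float here.

```python
import numpy as np

slope, intercept = 0.5, 1.0   # hypothetical fit parameters, for illustration only
x = [1.0, 2.0, 3.0]           # plain Python list, as built with x.append(...)

# Multiplying a list by a Python float is not defined and raises TypeError.
try:
    slope * x
    list_multiply_failed = False
except TypeError:
    list_multiply_failed = True

# Converting to a numpy array first gives element-wise arithmetic.
line = slope * np.asarray(x) + intercept
print(list_multiply_failed, line)
```

After the conversion, line has one value per entry of x, which is what plt.plot(x, line) needs.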
Ideally, you would read the data with np.genfromtxt, which returns numpy arrays directly:
x,y = np.genfromtxt("input.txt", unpack=True)
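The question's own parser also skips lines starting with '@', which the call above would not treat as comments. One possible variant, sketched here with np.loadtxt swapped in because its comments argument accepts several markers, reads from an in-memory string instead of input.txt. The header lines in the sample text are invented for illustration.

```python
import io
import numpy as np

# Hypothetical file content: two header lines, a blank line, three data rows.
text = """@ title (made-up header line)
# another comment
-5.06 -4.27

-6.69 -7.28
2.10 2.59
"""

# comments may be a sequence of markers; unpack=True returns one array per column.
x, y = np.loadtxt(io.StringIO(text), comments=("#", "@"), unpack=True)
print(x, y)
```

With a real file, the io.StringIO object would simply be replaced by the file name.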
Complete example:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from matplotlib.ticker import MaxNLocator
from scipy import stats
x,y = np.genfromtxt("input.txt", unpack=True)
fig = plt.figure(figsize=(2.2,2.2), dpi=300)
ax = plt.subplot(111)
plt.xlim(4, -8)
plt.ylim(4, -8)
ax.xaxis.set_major_locator(MaxNLocator(6))
ax.yaxis.set_major_locator(MaxNLocator(6))
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.yaxis.set_minor_locator(MultipleLocator(1))
#regression part
slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
line = slope*x+intercept
plt.plot(x, line, 'r', label='y={:.2f}x+{:.2f}'.format(slope,intercept))
#end
plt.scatter(x,y, color="k", s=3.5)
plt.legend(fontsize=9)
plt.show()
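As an alternative to putting the equation in the legend, it could also be written directly onto the axes. The following is only a sketch under some assumptions: the text position, the font size and the extra R^2 line are choices made here, not part of the answer, and the Agg backend is selected so that it runs without a display. The data are the 15 pairs from the question's input file.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# The 15 (x, y) pairs from the question's input file.
x = np.array([-5.06, -6.69, -3.80, -3.88, -0.90, 2.10, -1.08, -5.00,
              2.67, 2.48, -5.10, 2.88, 2.30, -3.47, -0.90])
y = np.array([-4.27, -7.28, -3.51, -2.79, -0.81, 2.59, 0.28, -3.39,
              2.92, 2.85, -3.49, 3.09, 2.67, -2.20, -0.79])

fit = stats.linregress(x, y)

# A straight line is fully described by its two end points.
x_line = np.array([x.min(), x.max()])
y_line = fit.slope * x_line + fit.intercept

# "{:+.2f}" always prints an explicit sign in front of the intercept.
equation = "y = {:.2f}x {:+.2f}\nR^2 = {:.3f}".format(
    fit.slope, fit.intercept, fit.rvalue ** 2)

fig, ax = plt.subplots(figsize=(2.2, 2.2), dpi=300)
ax.scatter(x, y, color="k", s=3.5)
ax.plot(x_line, y_line, "r")
# transform=ax.transAxes places the text in axes coordinates (0..1),
# so its position does not depend on the data limits.
ax.text(0.05, 0.95, equation, transform=ax.transAxes, va="top", fontsize=6)
```

Calling fig.savefig("output.png") afterwards would write the figure to disk, as in the question's script.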
Answer 1 (score: 0)