我不确定标题是否足够清晰(想不出更好的询问方法),但基本上,我想在形成标准差的两条线的边界之间填充阴影区域一组数据。在以下示例中,我在数据的x轴上只有4个点,但这太粗糙了,无法平滑地填充两条标准偏差线。
似乎我需要编写一些额外的代码集,以根据现有的标准偏差线在更精细的x轴点集上跟踪y轴值(例如,可能有10个额外点x 4数据集?)。
但是我想知道是否有更有效的方法?
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]
AllData = [DataA, DataB, DataC, DataD]
y1 = [np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)]
x1 = np.arange(len(y1))
# Calculate SD for data
SDList = []
for SDCal in range(len(AllData)):
SDList.append(np.std(AllData[SDCal], ddof=1))
plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')
for SDFill in range(len(y1)):
plt.fill_between([x1[SDFill]-0.5,x1[SDFill]+0.5], y1[SDFill]-SDList[SDFill], y1[SDFill]+SDList[SDFill],
alpha=0.15, facecolor='#0042ff')
y1Lower = []
y1Higher = []
for sort in range(len(y1)):
y1Higher.append(y1[sort] + SDList[sort])
y1Lower.append(y1[sort] - SDList[sort])
plt.plot(x1, y1Lower, color='black', linestyle='dashed')
plt.plot(x1, y1Higher, color='black', linestyle='dashed', label = 'Std Dev')
plt.legend(loc='best')
plt.show()
蓝色阴影区域是相应数据集的标准偏差。但是如何使阴影区域适合两条线而不是如图所示的条形?
答案 0 :(得分:1)
问题出在循环内调用fill_between
的方式。如果只用一次调用所有数据来调用fill_between
,就会得到更好的结果。
from scipy import stats
DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]
AllData = [DataA, DataB, DataC, DataD]
y1 = [np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)]
x1 = np.arange(len(y1))
# Calculate SD for data
SDList = []
for SDCal in range(len(AllData)):
SDList.append(np.std(AllData[SDCal], ddof=1))
plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')
y1Lower = []
y1Higher = []
for sort in range(len(y1)):
y1Higher.append(y1[sort] + SDList[sort])
y1Lower.append(y1[sort] - SDList[sort])
plt.fill_between(x1, y1Lower, y1Higher,
alpha=0.15, facecolor='#0042ff')
plt.plot(x1, y1Lower, color='black', linestyle='dashed')
plt.plot(x1, y1Higher, color='black', linestyle='dashed', label = 'Std Dev')
plt.legend(loc='best')
plt.show()
顺便说一句,如果您使用numpy数组而不是列表,则可以大大简化代码:
from scipy import stats
DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]
AllData = [DataA, DataB, DataC, DataD]
y1 = np.array([np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)])
x1 = np.arange(len(y1))
# Calculate SD for data
SDList = np.array([np.std(a) for a in AllData])
plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')
plt.fill_between(x1, y1-SDList, y1+SDList,
alpha=0.15, facecolor='#0042ff')
plt.plot(x1, y1-SDList, color='black', linestyle='dashed')
plt.plot(x1, y1+SDList, color='black', linestyle='dashed', label = 'Std Dev')
plt.legend(loc='best')
plt.show()