如何在X轴分辨率不同的两条线之间进行fill_between?

时间:2018-07-23 07:54:10

标签: python-2.7 matplotlib

我不确定标题是否足够清晰(想不出更好的询问方法),但基本上,我想在形成标准差的两条线的边界之间填充阴影区域一组数据。在以下示例中,我在数据的x轴上只有4个点,但这太粗糙了,无法平滑地填充两条标准偏差线。

似乎我需要编写一些额外的代码集,以根据现有的标准偏差线在更精细的x轴点集上跟踪y轴值(例如,可能有10个额外点x 4数据集?)。

但是我想知道是否有更有效的方法?

import matplotlib.pyplot as plt
import numpy as np

from scipy import stats

DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]

AllData = [DataA, DataB, DataC, DataD]
y1 = [np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)]

x1 = np.arange(len(y1))

# Calculate SD for data
SDList = []
for SDCal in range(len(AllData)):
    SDList.append(np.std(AllData[SDCal], ddof=1))

plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')

for SDFill in range(len(y1)):
    plt.fill_between([x1[SDFill]-0.5,x1[SDFill]+0.5], y1[SDFill]-SDList[SDFill], y1[SDFill]+SDList[SDFill],
        alpha=0.15, facecolor='#0042ff')

y1Lower = []
y1Higher = []

for sort in range(len(y1)):
    y1Higher.append(y1[sort] + SDList[sort])
    y1Lower.append(y1[sort] - SDList[sort])

plt.plot(x1, y1Lower, color='black', linestyle='dashed')
plt.plot(x1, y1Higher, color='black', linestyle='dashed', label = 'Std Dev')

plt.legend(loc='best')

plt.show()

蓝色阴影区域是相应数据集的标准偏差。但是如何使阴影区域适合两条线而不是如图所示的条形?

enter image description here

1 个答案:

答案 0 :(得分:1)

问题出在循环内调用fill_between的方式。如果只用一次调用所有数据来调用fill_between,就会得到更好的结果。

from scipy import stats

DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]

AllData = [DataA, DataB, DataC, DataD]
y1 = [np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)]

x1 = np.arange(len(y1))

# Calculate SD for data
SDList = []
for SDCal in range(len(AllData)):
    SDList.append(np.std(AllData[SDCal], ddof=1))

plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')

y1Lower = []
y1Higher = []

for sort in range(len(y1)):
    y1Higher.append(y1[sort] + SDList[sort])
    y1Lower.append(y1[sort] - SDList[sort])


plt.fill_between(x1, y1Lower, y1Higher,
    alpha=0.15, facecolor='#0042ff')

plt.plot(x1, y1Lower, color='black', linestyle='dashed')
plt.plot(x1, y1Higher, color='black', linestyle='dashed', label = 'Std Dev')

plt.legend(loc='best')

plt.show()

enter image description here

顺便说一句,如果您使用numpy数组而不是列表,则可以大大简化代码:

from scipy import stats

DataA = [1,3,5,7,9,11]
DataB = [2,6,4,8,9,10]
DataC = [6,3,5,7,9,19]
DataD = [9,10,13,12,11]

AllData = [DataA, DataB, DataC, DataD]
y1 = np.array([np.mean(DataA), np.mean(DataB), np.mean(DataC), np.mean(DataD)])
x1 = np.arange(len(y1))

# Calculate SD for data
SDList = np.array([np.std(a) for a in AllData])

plt.plot(x1, y1, color='red', linewidth=2, label = 'Data')

plt.fill_between(x1, y1-SDList, y1+SDList,
    alpha=0.15, facecolor='#0042ff')

plt.plot(x1, y1-SDList, color='black', linestyle='dashed')
plt.plot(x1, y1+SDList, color='black', linestyle='dashed', label = 'Std Dev')

plt.legend(loc='best')

plt.show()