从Python MatPlotLib图中提取误差线和点位置

时间:2017-09-18 02:06:15

标签: python matplotlib plot extract seaborn

如果我有一个Python MatPlotLib图(例如, matplotlib.axes._subplots.AxesSubplot 对象),有没有办法从中提取点和误差线的位置?即我想得到包含x,y坐标和y错误的数组。

示例:

import numpy as np
import seaborn as sb
x = np.random.uniform(-2, 2, 10000)
y = np.random.normal(x**2, np.abs(x) + 1)
p = sb.regplot(x=x, y=y, x_bins=10, fit_reg=None)

如何从点和误差线的'p'位置提取?

感谢您的帮助!

2 个答案:

答案 0 :(得分:1)

如果您知道这些点位于误差线的中心(对于此示例,它看起来像是这样),那么应该这样做:

import numpy as np
import seaborn as sb
x = np.random.uniform(-2, 2, 10000)
y = np.random.normal(x**2, np.abs(x) + 1)
p = sb.regplot(x=x, y=y, x_bins=10, fit_reg=None)

def get_data(p):
    x_list = []
    lower_list = []
    upper_list = []
    for line in p.lines:
        x_list.append(line.get_xdata()[0])
        lower_list.append(line.get_ydata()[0])
        upper_list.append(line.get_ydata()[1])
    y = 0.5 * (np.asarray(lower_list) + np.asarray(upper_list))
    y_error = np.asarray(upper_list) - y
    x = np.asarray(x_list)
    return x, y, y_error

get_data(p)

这里返回的y_error将是误差条的大小。

答案 1 :(得分:1)

错误栏数据存储在p.lines中,因为seaborn使用plt.plot绘制它们。

您可以使用line.get_xdata()line.get_ydata()访问其职位。

点数据存储在p.collections中,因为它们是使用plt.scatter在seaborn内部绘制的。

PathCollection对象获取点位置需要一个额外的步骤,如此答案所示:Get positions of points in PathCollection created by scatter():即您必须先设置offset_position,然后才能访问{ {1}}。

这是一个从中获取点数据和错误栏数据的示例  matplotlib offsets对象,Axes

p

此处import numpy as np import seaborn as sb import matplotlib.pyplot as plt x = np.random.uniform(-2, 2, 10000) y = np.random.normal(x**2, np.abs(x) + 1) p = sb.regplot(x=x, y=y, x_bins=10, fit_reg=None) # First, get the positions of the points: coll = p.collections[0] coll.set_offset_position('data') points_xy = coll.get_offsets() print points_xy #[[-1.65295679 3.05723876] # [-1.29981986 1.60258005] # [-0.94417279 0.8999881 ] # [-0.56964819 0.38035406] # [-0.20253243 0.0774201 ] # [ 0.15535504 0.024336 ] # [ 0.5362322 0.30849082] # [ 0.90482003 0.85788122] # [ 1.26136841 1.66294418] # [ 1.63048127 3.02934186]] # Next, get the positions of the errorbars xerr = [] yerr = [] for line in p.lines: xerr.append(line.get_xdata()[0]) yerr.append(line.get_ydata().tolist()) print xerr # [-1.6529567859649865, -1.2998198636006264, -0.94417278886439027, -0.56964818931133276, -0.20253243328132031, 0.15535504153419355, 0.53623219583456194, 0.90482002911787607, 1.2613684083224488, 1.6304812696399549] print yerr # [[2.908807029542707, 3.200571530218434], [1.4449980200239572, 1.751504207194087], [0.7633753040974505, 1.029774999216172], [0.26593411110949544, 0.4753543268237353], [-0.0030674495857816496, 0.15582564460187567], [-0.052610243112427575, 0.09899773706322114], [0.21019700161329888, 0.41120457637300634], [0.7328000635837721, 0.9826379405190817], [1.508513523393156, 1.8184617796582343], [2.885113765027557, 3.1670479251950376]] plt.show() 是点的points_xy坐标列表,(x,y)是错误栏的x坐标(当然,与x坐标相同)在xerr)中,points_xy是y坐标对的列表:每个错误栏的顶部和底部。