我有一个问题,我想绘制一个数据分布,其中一些值频繁出现,而另一些则很少出现。总点数约为30.000。将这样的图形渲染为png或(禁止上帝)pdf会花费很多时间,并且pdf太大而无法显示。
所以我只想对图进行数据二次采样。我想实现的是删除许多重叠的点(密度高的点),但保留那些密度低的点的几率几乎为1。
现在,numpy.random.choice
允许您指定一个概率向量,该向量是我根据几项调整根据数据直方图计算得出的。但是我似乎无法选择,这样才能真正保留稀有点。
我已经附加了数据的图像;分布的右尾少了几个数量级,所以我想保留这些。数据是3d,但密度仅来自一维,因此我可以用它来衡量给定位置上有多少个点
答案 0 :(得分:2)
一种可能的方法是使用kernel density estimation (KDE)建立数据的估计概率分布,然后根据每个点的估计概率密度的倒数进行采样(或其他函数,估计概率越大,该函数越小密度是)。有a few tools to compute a (KDE) in Python,一个简单的是scipy.stats.gaussian_kde
。这是一个例子:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
np.random.seed(100)
# Make some random Gaussian data
data = np.random.multivariate_normal([1, 1], [[1, 0], [0, 1]], size=1000)
# Compute KDE
kde = scipy.stats.gaussian_kde(data.T)
# Choice probabilities are computed from inverse probability density in KDE
p = 1 / kde.pdf(data.T)
# Normalize choice probabilities
p /= np.sum(p)
# Make sample using choice probabilities
idx = np.random.choice(np.arange(len(data)), size=100, replace=False, p=p)
sample = data[idx]
# Plot
plt.figure()
plt.scatter(data[:, 0], data[:, 1], label='Data', s=10)
plt.scatter(sample[:, 0], sample[:, 1], label='Sample', s=7)
plt.legend()
输出:
答案 1 :(得分:1)
请考虑以下功能。它将沿轴和
将数据分装在相等的分箱中这可以将原始数据保留在低密度区域中,但可以大大减少要在高密度区域中绘制的数据量。同时,所有特征都保留有足够密集的分箱。
import numpy as np; np.random.seed(42)
def filt(x,y, bins):
d = np.digitize(x, bins)
xfilt = []
yfilt = []
for i in np.unique(d):
xi = x[d == i]
yi = y[d == i]
if len(xi) <= 2:
xfilt.extend(list(xi))
yfilt.extend(list(yi))
else:
xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
yfilt.extend([yi.max(), yi.min()])
# prepend/append first/last point if necessary
if x[0] != xfilt[0]:
xfilt = [x[0]] + xfilt
yfilt = [y[0]] + yfilt
if x[-1] != xfilt[-1]:
xfilt.append(x[-1])
yfilt.append(y[-1])
sort = np.argsort(xfilt)
return np.array(xfilt)[sort], np.array(yfilt)[sort]
为说明这一概念,让我们使用一些玩具数据
x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4, 5, 2,5,3])
bins = np.linspace(0,30,7)
然后调用xf, yf = filt(x,y,bins)
并绘制原始数据和过滤后的数据,得出:
问题的用例包含30000个数据点,将在下面显示。使用提出的技术将允许将绘制的点的数量从30000减少到大约500。当然,该数量将取决于使用中的装箱-这里是300个装箱。在这种情况下,该函数大约需要10毫秒才能计算出来。这不是超级快,但是与绘制所有点相比还是很大的进步。
import matplotlib.pyplot as plt
# Generate some data
x = np.sort(np.random.rayleigh(3, size=30000))
y = np.cumsum(np.random.randn(len(x)))+250
# Decide for a number of bins
bins = np.linspace(x.min(),x.max(),301)
# Filter data
xf, yf = filt(x,y,bins)
# Plot results
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8),
gridspec_kw=dict(height_ratios=[1,2,2]))
ax1.hist(x, bins=bins)
ax1.set_yscale("log")
ax1.set_yticks([1,10,100,1000])
ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))
ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))
for ax in [ax2, ax3]:
ax.legend()
plt.show()