我前段时间写了一些代码,用高斯kde制作简单的密度散点图。然而,对于大于约100,000点的数据集,它只是“永远”运行(几天后我将其杀死)。一位朋友在R
给了我一些代码,可以在几秒钟内创建这样的密度图(plot_fun.R),看起来matplotlib应该可以做同样的事情。
我认为正确的地方是2d直方图,但我很难让密度变得“正确”。我修改了我在this question找到的代码来完成此任务,但是没有显示密度,看起来只有密码的可能点得到任何颜色。
这是我正在使用的代码:
# initial data
x = -np.log10(np.random.random_sample(10000))
y = -np.log10(np.random.random_sample(10000))
#histogram definition
bins = [1000, 1000] # number of bins
thresh = 3 #density threshold
#data definition
mn = min(x.min(), y.min())
mx = max(x.max(), y.max())
mn = mn-(mn*.1)
mx = mx+(mx*.1)
xyrange = [[mn, mx], [mn, mx]]
# histogram the data
hh, locx, locy = np.histogram2d(x, y, range=xyrange, bins=bins)
posx = np.digitize(x, locx)
posy = np.digitize(y, locy)
#select points within the histogram
ind = (posx > 0) & (posx <= bins[0]) & (posy > 0) & (posy <= bins[1])
hhsub = hh[posx[ind] - 1, posy[ind] - 1] # values of the histogram where the points are
xdat1 = x[ind][hhsub < thresh] # low density points
ydat1 = y[ind][hhsub < thresh]
hh[hh < thresh] = np.nan # fill the areas with low density by NaNs
f, a = plt.subplots(figsize=(12,12))
c = a.imshow(
np.flipud(hh.T), cmap='jet',
extent=np.array(xyrange).flatten(), interpolation='none',
origin='upper'
)
f.colorbar(c, ax=ax, orientation='vertical', shrink=0.75, pad=0.05)
s = a.scatter(
xdat1, ydat1, color='darkblue', edgecolor='', label=None,
picker=True, zorder=2
)
产生这个情节:
KDE代码在这里:
f, a = plt.subplots(figsize=(12,12))
xy = np.vstack([x, y])
z = sts.gaussian_kde(xy)(xy)
# Sort the points by density, so that the densest points are
# plotted last
idx = z.argsort()
x2, y2, z = x[idx], y[idx], z[idx]
s = a.scatter(
x2, y2, c=z, s=50, cmap='jet',
edgecolor='', label=None, picker=True, zorder=2
)
产生这个情节:
问题当然是这个代码在大型数据集上无法使用。
我的问题是:如何使用二维直方图来生成这样的散点图? ax.hist2d
没有产生有用的输出,因为它为整个绘图着色,而我所有努力使上面的2d直方图数据实际上正确地为绘图的密集区域着色都失败了,我总是最终得不到着色或一小部分最密集的点被着色。很明显,我只是不太了解代码。
答案 0 :(得分:0)
您的直方图代码指定了一种独特的颜色(color='darkblue'
),那么您期待什么?
我认为你也过于复杂化了。这个更简单的代码工作得很好:
import numpy as np
import matplotlib.pyplot as plt
x, y = -np.log10(np.random.random_sample((2,10**6)))
#histogram definition
bins = [1000, 1000] # number of bins
# histogram the data
hh, locx, locy = np.histogram2d(x, y, bins=bins)
# Sort the points by density, so that the densest points are plotted last
z = np.array([hh[np.argmax(a<=locx[1:]),np.argmax(b<=locy[1:])] for a,b in zip(x,y)])
idx = z.argsort()
x2, y2, z2 = x[idx], y[idx], z[idx]
plt.figure(1,figsize=(8,8)).clf()
s = plt.scatter(x2, y2, c=z2, cmap='jet', marker='.')