How can I make a scatter plot colored by density in matplotlib?

Date: 2013-11-20 19:39:31

Tags: python matplotlib

I would like to make a scatter plot where each point is colored by the spatial density of nearby points.

I came across a very similar question that does this in R:

R Scatter Plot: symbol color represents number of overlapping points

What is the best way to accomplish something similar in Python using matplotlib?

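4 answers:

Answer 0 (score: 112)

In addition to hist2d or hexbin as @askewchan suggested, you can use the same method as the accepted answer in the question you linked to.

If you want to do that, you can use something like this: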

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Generate fake data
x = np.random.normal(size=1000)
y = x * 3 + np.random.normal(size=1000)

# Calculate the point density
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)

fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=100, edgecolor='none')  # '' is not a valid color in current matplotlib
plt.show()


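If you want the points plotted in order of density, so that the densest points are always on top (similar to the linked example), just sort them by their z-values. I'd also use a smaller marker size here, since it looks a bit better: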

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Generate fake data
x = np.random.normal(size=1000)
y = x * 3 + np.random.normal(size=1000)

# Calculate the point density
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)

# Sort the points by density, so that the densest points are plotted last
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]

fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=50, edgecolor='none')
plt.show()


Answer 1 (score: 27)

You can make a 2D histogram:

import numpy as np
import matplotlib.pyplot as plt

# fake data:
a = np.random.normal(size=1000)
b = a*3 + np.random.normal(size=1000)

plt.hist2d(a, b, (50, 50), cmap=plt.cm.jet)
plt.colorbar()
plt.show()

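Answer 0 also mentions hexbin, which none of the answers demonstrates. A minimal sketch using matplotlib's Axes.hexbin on the same kind of fake data, counting points per hexagonal cell instead of per rectangular bin; the helper name using_hexbin and gridsize=50 are arbitrary choices, not from the original answers:

```python
import numpy as np
import matplotlib.pyplot as plt


def using_hexbin(ax, x, y, gridsize=50):
    # Count points per hexagonal cell; mincnt=1 leaves empty cells unpainted
    hb = ax.hexbin(x, y, gridsize=gridsize, cmap="viridis", mincnt=1)
    ax.figure.colorbar(hb, ax=ax, label="Number of points per cell")
    return hb


if __name__ == "__main__":
    # Same kind of fake data as in the other answers
    a = np.random.normal(size=1000)
    b = a * 3 + np.random.normal(size=1000)
    fig, ax = plt.subplots()
    using_hexbin(ax, a, b)
    plt.show()
```

Like hist2d, this colors areas rather than individual points, so it shows density well but not the exact point positions.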

Answer 2 (score: 7)

Alternatively, if the number of points makes the KDE computation too slow, the color can be interpolated from np.histogram2d:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interpn

def density_scatter(x, y, ax=None, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by 2d histogram
    """
    if ax is None:
        fig, ax = plt.subplots()
    data, x_e, y_e = np.histogram2d(x, y, bins=bins)
    # Interpolate the bin counts at each point, using the bin centers as the grid
    z = interpn((0.5 * (x_e[1:] + x_e[:-1]), 0.5 * (y_e[1:] + y_e[:-1])),
                data, np.vstack([x, y]).T,
                method="splinef2d", bounds_error=False)
    # Points outside the grid of bin centers come back as NaN; treat them as 0
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    ax.scatter(x, y, c=z, **kwargs)
    return ax


if __name__ == "__main__":

    x = np.random.normal(size=100000)
    y = x * 3 + np.random.normal(size=100000)
    density_scatter(x, y, bins=[30, 30])
    plt.show()
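If smooth interpolation is not needed, a simpler variant (not part of this answer) colors each point by the raw count of the histogram bin it falls into, looked up with np.digitize. A minimal sketch; the names bin_counts and density_scatter_binned are hypothetical:

```python
import numpy as np
import matplotlib.pyplot as plt


def bin_counts(x, y, bins=20):
    """Return, for each point, the count of the 2D histogram bin it lies in."""
    data, x_e, y_e = np.histogram2d(x, y, bins=bins)
    # np.digitize returns 1-based bin indices; subtract 1 for 0-based indexing.
    # clip puts points lying exactly on the right-most edge into the last bin,
    # matching how np.histogram2d counts them.
    ix = np.clip(np.digitize(x, x_e) - 1, 0, data.shape[0] - 1)
    iy = np.clip(np.digitize(y, y_e) - 1, 0, data.shape[1] - 1)
    return data[ix, iy]


def density_scatter_binned(x, y, ax=None, bins=20, **kwargs):
    x, y = np.asarray(x), np.asarray(y)
    if ax is None:
        fig, ax = plt.subplots()
    z = bin_counts(x, y, bins=bins)
    # Plot the densest points last so they end up on top
    idx = z.argsort()
    sc = ax.scatter(x[idx], y[idx], c=z[idx], **kwargs)
    ax.figure.colorbar(sc, ax=ax, label="Points in bin")
    return ax
```

The colors are blockier than with interpn, since all points in one bin share the same value, but there are no NaNs to handle at the edges.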

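Answer 3 (score: 5)

Plotting more than 100k data points?

Using gaussian_kde() as in the accepted answer takes a lot of time: on my machine, 100k rows took about 11 minutes. Here I add two alternative methods (mpl-scatter-density and datashader) and compare them with the other answers on the same dataset.

Below, I use a test dataset of 100k rows: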

import matplotlib.pyplot as plt
import numpy as np

# Fake data for testing
x = np.random.normal(size=100000)
y = x * 3 + np.random.normal(size=100000)

Comparison of output and computation time

Below is a comparison of the different methods.
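The answer does not say how the timings were measured. One possible way, sketched here as an assumption: wall-clock time around the plotting call plus a forced fig.canvas.draw(), so rendering is included. The helper name time_plot is hypothetical, it fits the functions below that take (ax, x, y), and the numbers it gives will depend on your machine:

```python
import time

import numpy as np
import matplotlib.pyplot as plt


def time_plot(plot_fn, x, y):
    """Time plot_fn(ax, x, y) including one full render of the figure."""
    fig, ax = plt.subplots()
    start = time.perf_counter()
    plot_fn(ax, x, y)
    fig.canvas.draw()  # force rendering so drawing cost is counted too
    elapsed = time.perf_counter() - start
    plt.close(fig)
    return elapsed


if __name__ == "__main__":
    x = np.random.normal(size=100000)
    y = x * 3 + np.random.normal(size=100000)

    def hist2d_50(ax, x, y):
        ax.hist2d(x, y, bins=(50, 50))

    print(f"hist2d, bins=(50, 50): {time_plot(hist2d_50, x, y):.3f} s")
```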

1: mpl-scatter-density

Installation

pip install mpl-scatter-density

Example code

import mpl_scatter_density # adds projection='scatter_density'
from matplotlib.colors import LinearSegmentedColormap

# "Viridis-like" colormap with white background
white_viridis = LinearSegmentedColormap.from_list('white_viridis', [
    (0, '#ffffff'),
    (1e-20, '#440053'),
    (0.2, '#404388'),
    (0.4, '#2a788e'),
    (0.6, '#21a784'),
    (0.8, '#78d151'),
    (1, '#fde624'),
], N=256)

def using_mpl_scatter_density(fig, x, y):
    ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
    density = ax.scatter_density(x, y, cmap=white_viridis)
    fig.colorbar(density, label='Number of points per pixel')

fig = plt.figure()
using_mpl_scatter_density(fig, x, y)
plt.show()

Plotting this took 0.05 s: [figure: using mpl-scatter-density]

And it looks very nice when zoomed in: [figure: zoom in mpl-scatter-density]

2: datashader

pip install "git+https://github.com/nvictus/datashader.git@mpl"

Code (the source of dsshow is here):

from functools import partial

import datashader as ds
from datashader.mpl_ext import dsshow
import pandas as pd

dyn = partial(ds.tf.dynspread, max_px=40, threshold=0.5)

def using_datashader(ax, x, y):

    df = pd.DataFrame(dict(x=x, y=y))
    da1 = dsshow(df, ds.Point('x', 'y'), spread_fn=dyn, aspect='auto', ax=ax)
    plt.colorbar(da1)

fig, ax = plt.subplots()
using_datashader(ax, x, y)
plt.show()
  • Plotting this image took 0.83 s.


The zoomed-in image looks great!

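Both mpl-scatter-density and datashader essentially count how many points land in each screen pixel and color that pixel. A rough NumPy-only analogue, assuming a fixed fine grid (500×500 by default, an arbitrary choice) stands in for the screen pixels, is a 2D histogram shown with imshow on a logarithmic color scale. Unlike those libraries, it does not re-aggregate when you zoom in. The helper name using_numpy_raster is hypothetical:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm


def using_numpy_raster(ax, x, y, pixels=500):
    # Count points per cell of a fine, fixed grid (the "pixels")
    counts, x_e, y_e = np.histogram2d(x, y, bins=pixels)
    # Mask empty cells so they show the background instead of a color
    counts = np.ma.masked_equal(counts, 0)
    # histogram2d indexes as [x, y]; imshow expects [row=y, col=x], hence .T
    im = ax.imshow(counts.T, origin="lower", aspect="auto",
                   extent=[x_e[0], x_e[-1], y_e[0], y_e[-1]],
                   norm=LogNorm(), cmap="viridis", interpolation="nearest")
    ax.figure.colorbar(im, ax=ax, label="Number of points per cell")
    return im
```

The log scale plays a role similar to the white-to-viridis colormap above: it keeps sparse outliers visible next to the dense core.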

3: scatter_with_gaussian_kde

import numpy as np
from scipy.stats import gaussian_kde

def scatter_with_gaussian_kde(ax, x, y):
    # https://stackoverflow.com/a/20107592/3015186
    # Answer by Joe Kington

    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)

    ax.scatter(x, y, c=z, s=100, edgecolor='none')
  • Plotting this took 11 minutes: [figure: scatter_with_gaussian_kde]
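Most of that time goes into evaluating a KDE built from all N points at all N points, which scales as N². A common workaround, not part of any answer here and not timed, is to fit gaussian_kde on a random subsample and evaluate it at every point, so the cost scales roughly with n_fit × N. The colors are then an approximation. A minimal sketch; the helper name subsampled_kde_density and the default subsample size of 2000 are assumptions:

```python
import numpy as np
from scipy.stats import gaussian_kde


def subsampled_kde_density(x, y, n_fit=2000, seed=0):
    """Approximate the density at each (x, y) using a KDE fit on a subsample."""
    xy = np.vstack([x, y])
    rng = np.random.default_rng(seed)
    n_fit = min(n_fit, xy.shape[1])
    # Pick n_fit distinct points to build the kernel estimate from
    sample = rng.choice(xy.shape[1], size=n_fit, replace=False)
    kde = gaussian_kde(xy[:, sample])
    # Evaluate at every point: about n_fit * N kernel evaluations instead of N * N
    return kde(xy)
```

The returned z can replace gaussian_kde(xy)(xy) in the code from the accepted answer, including the sorting step.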

4: using_hist2d

import matplotlib.pyplot as plt
def using_hist2d(ax, x, y, bins=(50, 50)):
    # https://stackoverflow.com/a/20105673/3015186
    # Answer by askewchan
    ax.hist2d(x, y, bins, cmap=plt.cm.jet)

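  • Plotting this with bins=(50,50) took 0.021 s: [figure: using_hist2d_50]
  • Plotting this with bins=(1000,1000) took 0.173 s: [figure: using_hist2d_1000]
  • Drawbacks: zoomed in, the data does not look as good as with mpl-scatter-density or datashader. You also have to choose the number of bins yourself.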

[figure: zoomed in hist2d 1000bins]
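To avoid choosing the number of bins by hand, one option (an assumption, not from the answer) is to let NumPy pick the edges for each axis separately with np.histogram_bin_edges (NumPy 1.15+) and one of its estimators such as "auto", then pass both edge arrays to hist2d. These estimators are designed for 1D histograms, so for 2D data this is a heuristic. The helper name using_hist2d_auto is hypothetical:

```python
import numpy as np
import matplotlib.pyplot as plt


def using_hist2d_auto(ax, x, y, method="auto"):
    # Pick 1D bin edges independently for each axis using NumPy's estimators
    x_edges = np.histogram_bin_edges(x, bins=method)
    y_edges = np.histogram_bin_edges(y, bins=method)
    counts, _, _, im = ax.hist2d(x, y, bins=[x_edges, y_edges], cmap="viridis")
    ax.figure.colorbar(im, ax=ax, label="Number of points per bin")
    return counts, x_edges, y_edges
```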

5: density_scatter

  • The code is the same as in Guillaume's answer.
  • Plotting this image with bins=(50,50) took 0.073 s: [figure: density_scatter_50bins]
  • Plotting this image with bins=(1000,1000) took 0.368 s: [figure: density_scatter_1000bins]