Python:如何通过忽略零来绘制2D矩阵的热图?

时间:2016-08-05 13:29:25

标签: python python-2.7 matplotlib heatmap colormap

我有一个大小为500 X 28000的矩阵,其中包含很多零。但让我们考虑一个矩阵A的实例:

A = [[0, 0, 0, 1, 0],
    [1, 0, 0, 2, 3],
    [5, 3, 0, 0, 0],
    [5, 0, 1, 0, 3],
    [6, 0, 0, 9, 0]]

我想绘制上面矩阵的热图,但由于它包含很多零,因此热图包含几乎为空的空间,如下图所示。

如何忽略矩阵中的零并绘制热图?

以下是我尝试过的最小工作示例:

im = plt.matshow(A, cmap=pl.cm.hot, norm=LogNorm(vmin=0.01, vmax=64), aspect='auto') # pl is pylab imported a pl
plt.colorbar(im)
plt.show()

产生:

enter image description here

正如你所看到的那样是因为零空白出现了。

但是我的原始矩阵500X280000包含很多零,这使得我的色彩图几乎是白色!!

3 个答案:

答案 0 :(得分:3)

如果删除LogNorm,则会出现黑色方块而不是白色:

im = plt.matshow(A, cmap=plt.cm.hot, aspect='auto') # pl is pylab imported a pl

enter image description here

修改

在色彩映射中,总是让完整的网格填充值。这就是你实际创建网格的原因:你考虑(比如:插入)所有不完全在网格中的点。这意味着您的数据具有多个零,并且图表通过查看白色(或黑色)正确反映了该数据。如果您没有明确的理由,则忽略这些值会创建一个误导性的图表。

如果您感兴趣的值不是零,那么您需要另一种类型的图表,如norio's comment指出的那样。为此,您可能需要查看this answer

修改2

  

改编自this answer

您可以将值视为一维数组并独立绘制点,而不是使用不需要的值填充网格。

A = [[0, 0, 0, 1, 0],
    [1, 0, 0, 2, 3],
    [5, 3, 0, 0, 0],
    [5, 0, 1, 0, 3],
    [6, 0, 0, 9, 0]]
A = np.array(A)
lenx, leny = A.shape

xx = np.array( [ a for a in range(lenx) for a in range(leny) ] )   # Convert 3D to 3*1D
yy = np.array( [ a for a in range(lenx) for b in range(leny) ] )
zz = np.array( [ A[x][y] for x,y in zip(xx,yy) ] )
#---
xx = xx[zz!=0]    # Drop zeroes
yy = yy[zz!=0]
zz = zz[zz!=0]
#---
zi, yi, xi = np.histogram2d(yy, xx, bins=(10,10), weights=zz, normed=False)
zi = np.ma.masked_equal(zi, 0)

fig, ax = plt.subplots()
ax.pcolormesh(xi, yi, zi, edgecolors='black')
scat = ax.scatter(xx, yy, c=zz, s=200)
fig.colorbar(scat)
ax.margins(0.05)

plt.show()

enter image description here

答案 1 :(得分:1)

这个答案与'编辑2'的方向相同。路易斯'回答。实际上,这是它的简化版本。我发布此信息只是为了纠正我在评论中的误导性陈述。我看到一个警告,我们不应该在评论区域讨论,所以我正在使用这个答案区域。

无论如何,首先让我发布我的代码。请注意,我使用了在脚本中随机生成的较大矩阵,而不是样本矩阵A

#!/usr/bin/python
#
# This script was written by norio 2016-8-5.

import os, re, sys, random
import numpy as np

#from matplotlib.patches import Ellipse
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as img

mpl.rcParams['lines.linewidth'] = 2
mpl.rcParams['lines.markeredgewidth'] = 1.0
mpl.rcParams['axes.formatter.limits'] = (-4,4)
#mpl.rcParams['axes.formatter.limits'] = (-2,2)
mpl.rcParams['axes.labelsize'] = 'large'
mpl.rcParams['xtick.labelsize'] = 'large'
mpl.rcParams['ytick.labelsize'] = 'large'
mpl.rcParams['xtick.direction'] = 'out'
mpl.rcParams['ytick.direction'] = 'out'


############################################
#numrow=500
#numcol=280000
numrow=50
numcol=28000
# .. for testing
numelm=numrow*numcol
eps=1.0e-9
#
#numnz=int(1.0e-7*numelm)
numnz=int(1.0e-5*numelm)
# .. for testing
vmin=1.0e-6
vmax=1.0
outfigname='stackoverflow38790536.png'
############################################

### data matrix
# I am generating a data matrix here artificially.
print 'generating pseudo-data..'
random.seed('20160805')
matA=np.zeros((numrow, numcol))
for je in range(numnz):
    jr = random.uniform(0,numrow)
    jc = random.uniform(0,numcol)
    matA[jr,jc] = random.uniform(vmin,vmax)


### Actual processing for a given data will start from here
print 'processing..'

idxrow=[]
idxcol=[]
val=[]
for ii in range(numrow):
    for jj in range(numcol):
        if np.abs(matA[ii,jj])>eps:
            idxrow.append(ii)
            idxcol.append(jj)
            val.append( np.abs(matA[ii,jj]) )

print 'len(idxrow)=', len(idxrow)    
print 'len(idxcol)=', len(idxcol)    
print 'len(val)=',    len(val)    


############################################
# canvas setting for line plots 
############################################

f_size   = (8,5)

a1_left   = 0.15
a1_bottom  = 0.15
a1_width  = 0.65
a1_height = 0.80
#
hspace=0.02
#
ac_left   = a1_left+a1_width+hspace
ac_bottom = a1_bottom
ac_width  = 0.03
ac_height = a1_height

############################################
# plot 
############################################
print 'plotting..'

fig1=plt.figure(figsize=f_size)
ax1 =plt.axes([a1_left, a1_bottom, a1_width, a1_height], axisbg='w')

pc1=plt.scatter(idxcol, idxrow, s=20, c=val, cmap=mpl.cm.gist_heat_r)
# cf.
# http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter
plt.xlabel('Column Index', fontsize=18)
plt.ylabel('Row Index', fontsize=18)
ax1.set_xlim([0, numcol-1])
ax1.set_ylim([0, numrow-1])

axc =plt.axes([ac_left, ac_bottom, ac_width, ac_height], axisbg='w')
mpl.colorbar.Colorbar(axc,pc1, ticks=np.arange(0.0, 1.5, 0.1) )

plt.savefig(outfigname)
plt.close()

此脚本输出一个数字' stackoverflow38790536.png',如下所示。 scatter plot of non-zero elements

正如您在我的代码中所看到的,我使用scatter代替plot。我意识到plot命令不适合这里的任务。

我需要纠正的另一个词是row_index不需要多达140,000,000(= 500 * 280000)个元素。它只需要具有非零元素的行索引。更正确的是,列表, 在上面的代码中输入idxrow命令的idxcolvalscatter的长度等于非零元素的数量。

请注意,这两点都已在Luis'中得到妥善处理。答案。

答案 2 :(得分:0)

虽然norio的答案是正确的。我认为只需几行代码就可以给出更多快速答案:

import numpy as np
import matplotlib.pyplot as plt
A = np.asarray(A)
x,y = A.nonzero() #get the notzero indices
plt.scatter(x,y,c=A[x,y],s=100,cmap='hot',marker='s') #adjust the size to your needs
plt.colorbar()
plt.show()

enter image description here

请注意,轴是反转的。你可以通过以下方式反转它们:

ax=plt.gca()
ax.invert_xaxis()
ax.invert_yaxis()

另请注意,您现在拥有更多灵活性:

  • 您可以选择
  • 设置标记大小和标记类型以及透明度
  • 此过程更快,因为零未解析为matplotlib。