Optimizing the execution time of Seaborn heatmap plotting

Time: 2017-08-15 16:46:59

Tags: numpy optimization time heatmap seaborn

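I have a 4D numpy array in which the fourth dimension can be thought of as "time". Plotting consecutive frames as 2D heatmaps over the first two dimensions gives an "animation". When I measure the execution time, I get 26 seconds for 26 frames, i.e. about one frame per second, which is very slow. How can I speed up the code below? I would prefer to create the heatmaps with Seaborn rather than with matplotlib (even though Seaborn is an extension of matplotlib).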
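The data layout described above can be sketched as follows. This is a minimal sketch that assumes an array of shape (rows, columns, channels, frames) = (256, 128, 4, 26) filled with random complex values in place of the real data.npy; the shape, the channel count and the `to_db` helper are hypothetical. It shows that the number of frames is `data.shape[3]`, that frame `i` of channel 2 is the 2D slice `data[:, :, 2, i]`, and that the `20*log10(|x|)` conversion can be applied to the whole array once instead of once per frame.

```python
import numpy as np

# Hypothetical stand-in for np.load('data.npy'):
# 256 x 128 pixels, 4 channels, 26 time steps, complex-valued samples.
rng = np.random.default_rng(0)
shape = (256, 128, 4, 26)
data = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


def to_db(x):
    """Return 20*log10(|x|) for the whole array at once.

    Zeros map to -inf; np.errstate only silences the divide-by-zero warning.
    """
    with np.errstate(divide='ignore'):
        return 20 * np.log10(np.abs(x))


data_db = to_db(data)          # computed once, same shape as data
n_frames = data_db.shape[3]    # the fourth axis is "time"
frame = data_db[:, :, 2, 5]    # channel 2 at time step 5: one 2D heatmap

print(data_db.shape, n_frames, frame.shape)
```

Precomputing the dB values for the whole array like this is also what the answer below does in its `data = 20*np.log10(abs(data))` line.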

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time

data = np.load('data.npy')

fig = plt.figure()
ax = fig.add_subplot(111)
im = sns.heatmap(np.zeros((256, 128)), cmap = 'viridis', vmin = 0, vmax = 90)
plt.show(block = False)

start = time.time()
for i in range(data.shape[3]):
    # clears the whole figure, so the heatmap, colorbar and ticks are rebuilt every frame
    plt.clf()
    sns.heatmap(20*np.log10(np.abs(data[:, :, 2, i])), cmap = 'viridis', vmin = 0, vmax = 90)
    fig.canvas.draw()
end = time.time()

print(end - start)
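A likely reason the loop above is slow is that it rebuilds everything on every frame: `plt.clf()` discards the axes, and each `sns.heatmap` call creates a new mesh, a new colorbar and new tick labels. If you want to keep seaborn, one option is to call `sns.heatmap` once and afterwards only replace the colour values of the mesh it drew. This is a sketch that relies on an implementation detail rather than a documented seaborn API: seaborn draws the heatmap with matplotlib's `pcolormesh`, so the resulting `QuadMesh` is assumed to be `ax.collections[0]`. The Agg backend and the random frames are assumptions so that the sketch runs without a display or data.npy.

```python
import time

import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; remove it to get a window
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Hypothetical stand-in for the dB frames: 26 frames of 256 x 128 values in 0..90.
rng = np.random.default_rng(0)
frames = rng.uniform(0, 90, size=(26, 256, 128))

fig, ax = plt.subplots()
# Build the heatmap (mesh, colorbar, tick labels) only once.
sns.heatmap(frames[0], cmap='viridis', vmin=0, vmax=90, ax=ax)
mesh = ax.collections[0]  # assumed: the QuadMesh seaborn created via pcolormesh

start = time.perf_counter()
for frame in frames:
    # Replace only the cell values; vmin/vmax stay fixed by the existing norm.
    mesh.set_array(frame.ravel())
    fig.canvas.draw()
print(time.perf_counter() - start)
```

Because this depends on `ax.collections[0]`, it may need adjusting if a future seaborn version draws heatmaps differently.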

1 answer:

Answer 0 (score: 1)

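The code below produces the same plot as the seaborn code (same colormap, color limits and colorbar), but more than 10 times faster (execution time about 2 seconds instead of 26):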

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time

data = np.load('data.npy')
# convert to dB once for the whole array instead of once per frame
data = 20*np.log10(np.abs(data))

fig = plt.figure(figsize = (7, 7))
ax = fig.add_subplot(111)

# create the image once; the loop below only swaps its pixel data
im = ax.imshow(np.zeros(data.shape[:2]), cmap = 'viridis', vmin = 0, vmax = 90, interpolation = 'none', aspect = 'auto')

# hide the spines, keep ticks on the bottom/left only, set the tick spacing
# and rotate the x-axis labels
for side in ('left', 'right', 'top', 'bottom'):
    ax.spines[side].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks(np.arange(0, 128, 5))
ax.yaxis.set_ticks(np.arange(0, 256, 10))
for tick in ax.get_xticklabels():
    tick.set_rotation(90)

# use a divider so the colorbar keeps a fixed size relative to ax
divider = make_axes_locatable(ax)
# colorbar to the right of ax: its width is 5% of ax, the gap (pad) is 0.07 inches
cax = divider.append_axes('right', size = '5%', pad = 0.07)
cb = fig.colorbar(im, cax = cax)
# remove the colorbar frame
cb.outline.set_visible(False)

# show the window without blocking, so the loop below can keep updating it
plt.show(block = False)

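# loop over the time axis, plotting channel 2 as in the question
start = time.time()
for i in range(data.shape[3]):
    time.sleep(0.005)
    im.set_array(data[:, :, 2, i])
    fig.canvas.draw()
    # process pending GUI events so the window actually repaints
    fig.canvas.flush_events()
stop = time.time()
print(stop - start)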
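Instead of the manual loop with `time.sleep` and `fig.canvas.draw()`, the frame updates can also be driven by matplotlib's `FuncAnimation`; with `blit=True`, only the artists returned by the update function are redrawn on backends that support blitting. This is a minimal sketch that assumes random frames in place of data.npy, a hypothetical `update` function, and the Agg backend so that it runs headless; for that reason it writes the animation to a GIF with `PillowWriter` (Pillow is a matplotlib dependency) instead of showing it.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for the precomputed dB data: 26 frames of 256 x 128 pixels.
rng = np.random.default_rng(0)
frames = rng.uniform(0, 90, size=(26, 256, 128))

fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(frames[0], cmap='viridis', vmin=0, vmax=90,
               interpolation='none', aspect='auto')
fig.colorbar(im, ax=ax)


def update(i):
    """Swap in the pixel data of frame i and return the artists that changed."""
    im.set_array(frames[i])
    return (im,)


# blit=True: on each frame only the returned artists are redrawn
# (on interactive backends that support blitting).
anim = animation.FuncAnimation(fig, update, frames=len(frames),
                               interval=50, blit=True)

out_path = os.path.join(tempfile.mkdtemp(), 'heatmap.gif')
anim.save(out_path, writer=animation.PillowWriter(fps=20))
print(out_path, os.path.getsize(out_path))
```

For an interactive window, drop the `matplotlib.use('Agg')` line and call `plt.show()` instead of `anim.save(...)`; keep a reference to `anim`, because the animation stops if the object is garbage-collected.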