matplotlib热图 - 绘制奇数指数

时间:2016-12-08 01:14:59

标签: python numpy matplotlib plot heatmap

我想创建一个热图,用于绘制具有不同内核大小的卷积神经网络的分类精度。在我对CNN的特定实现中,我只使用奇数大小的滤波器,但绘制的热图将绘图元素放在奇数和偶数位置。什么是忽略偶数位置的正确方法,并且只在奇数位置绘制元素?

我首先定义了我感兴趣的内核宽度和高度:

search_height = range(1, 5, 2)  # [1,3]
search_width = range(1, 5, 2)  # [1,3]

预分配数组以保存精度值。我认为这可能是问题的一部分,因为它在偶数索引中存储值?

grid_accuracy = np.empty((len(search_height), len(search_width)))  # 2x2 array

然后,我获得了不同内核大小的网络准确性,并将它们存储在数组中:

for i, h in enumerate(search_height):
    for j, w in enumerate(search_width):
        cur_test_acc = main(batch_size=200, num_epochs=100, k_height=h, k_width=w)
        grid_accuracy[i, j] = cur_test_acc

最后,我使用存储的精度值绘制热图:

plt.imshow(grid_accuracy, cmap = plt.cm.hot, interpolation='none')
plt.grid(True)
plt.xlabel('Kernel Width')
plt.ylabel('Kernel Height')
plt.xticks(search_width)
plt.yticks(search_height)

问题是我最终得到的情节如下:

enter image description here

目前似乎使用grid_accuracy的水平和垂直索引作为元素的位置。

我真正想要的只是一个2x2网格,其中每个单元格的值是相应内核宽度/高度的精度。轴刻度应该是我手动定义的高度和宽度(理想情况下,水平轴在图上方刻度):

enter image description here

1 个答案:

答案 0 :(得分:0)

您需要设置勾选标签而不仅仅是他们的位置:

import numpy as np
from matplotlib import pyplot as plt

plt.imshow(np.random.randn(2, 2), interpolation='none')
plt.xticks([0, 1], [1, 3])
plt.yticks([0, 1], [1, 3])

您可以使用plt.tick_params来获取您正在寻找的格式:

plt.tick_params(length=0, labelbottom=False, labeltop=True)