我正在尝试从matplotlib中的纸张中重现一个数字,如下所示。基本上,每个细胞都有一个百分比,百分比越高,细胞背景越暗:
下面的代码产生类似的东西,但每个单元格都是一个正方形像素,我希望它们是扁平的矩形而不是正方形,如上图所示。我怎样才能用matplotlib实现这个目标?
import numpy as np
import matplotlib.pyplot as plt
import itertools
table = np.random.uniform(low=0.0, high=1.0, size=(10,5))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1)
plt.yticks(np.arange(10), class_names)
for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
plt.text(j, i, format(table[i,j], '.2f'),
horizontalalignment="center",
color="white" if table[i,j] > 0.5 else "black")
plt.show()
我怀疑更改aspect
和extent
imshow
可能会对我有所帮助。我不完全理解这是如何工作的,但这是我尝试过的:
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])
我意识到我还需要在单元格之间添加边框,删除刻度线,并将值更改为百分比而不是小数,我相信我自己可以做到这一点,但如果你我也想帮助我,请随意!
答案 0 :(得分:4)
在aspect="auto"
电话中使用imshow
时,您将获得非方形像素:
import numpy as np
import matplotlib.pyplot as plt
import itertools
table = np.random.uniform(low=0.0, high=1.0, size=(10,5))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect="auto")
plt.yticks(np.arange(10), class_names)
for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
plt.text(j, i, format(table[i,j], '.2f'),
ha="center", va="center",
color="white" if table[i,j] > 0.5 else "black")
plt.show()
答案 1 :(得分:2)
经过大量的实验,我弄清楚了范围的工作原理以及它如何影响以后文本的坐标。我还添加了边框等,这段代码生成了原始风格的非常好的复制品!
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import itertools
table = np.random.uniform(low=0.0, high=1.0, size=(10,5))
class_names = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])
plt.yticks(np.arange(10)+0.5, class_names)
plt.xticks(np.arange(5)*2.8 + 1.4, ['1', '2', '3', '4', '5'])
ax = plt.axes()
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_ticks_position('none')
matplotlib.rcParams.update({'font.size': 14})
ax = plt.gca()
# Minor ticks
ax.set_xticks(np.arange(1, 5) * 2.8, minor=True);
ax.set_yticks(np.arange(1, 10, 1), minor=True);
# Gridlines based on minor ticks
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
plt.text(j*2.8+1.5, i+0.6, format(table[i,j], '.2f'),
horizontalalignment="center",
color="white" if table[i,j] > 0.5 else "black")
plt.show()
感谢DavidG和Spezi94帮助他们发表评论!