我想显示0到9的单个数字,每行显示5个示例。这张照片就是一个例子:example以我为例,它产生10行5列。
我设法显示了前50张图像(很抱歉,该图像,我无法在stackoverflow上格式化代码):code
如何将带有标签的每个数字打印成一行? 我已经尝试了几个小时,但除了使用numpy.take()外,我一无所知,但我不知道如何。我已经用谷歌搜索了很多,没有可用的结果。
谢谢!
答案 0 :(得分:7)
首先,您需要数据集:
dataset = keras.datasets.mnist.load_data()
然后将其拆分:
X_train = dataset[0][0]
y_train = dataset[0][1]
X_test = dataset[1][0]
y_test = dataset[1][1]
然后您可以创建一个数字索引字典
这里我使用测试数据集。您可以根据需要使用火车,只需替换y_train:
digits = {}
for i in range(10):
digits[i] = np.where(y_test==i)[0][:5]
digits
此字典将如下所示:
{0: array([ 3, 10, 13, 25, 28], dtype=int64),
1: array([ 2, 5, 14, 29, 31], dtype=int64),
2: array([ 1, 35, 38, 43, 47], dtype=int64),
3: array([18, 30, 32, 44, 51], dtype=int64),
4: array([ 4, 6, 19, 24, 27], dtype=int64),
5: array([ 8, 15, 23, 45, 52], dtype=int64),
6: array([11, 21, 22, 50, 54], dtype=int64),
7: array([ 0, 17, 26, 34, 36], dtype=int64),
8: array([ 61, 84, 110, 128, 134], dtype=int64),
9: array([ 7, 9, 12, 16, 20], dtype=int64)}
最后,您将创建一个图形和子图,如下所示:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(10, 5, sharex='col', sharey='row')
for i in range(10):
for j in range(5):
ax[i, j].imshow(X_test[digits[i][j]])