如何从mnist中从一系列选定的标签中创建自定义迷你批次?

时间:2019-04-18 13:25:39

标签: python mnist numpy-broadcasting

我正在尝试从MNIST创建一个小批量,其所有数字的范围为0到9。(10个元素)

我想避免循环遍历标签扇区中的所有元素以一一检查数字。

最简单的方法是什么?

我想我可以创建一个从0到9的标签数组“ all_digits”,然后将其与我的mnist_labels“ train_labels”列表进行比较。 (一维数组-n个元素)

我尝试获取所有对均等检查的矩阵。 (n x 10) 但是我不能直接使用==,也没有numpy.equal()的广播版本。

之后我也不清楚如何处理矩阵。

import numpy as np
train_labels = np.random.randint(0,10,100)
all_digits = np.arange(10)
# doing a difference for now
train_labels.reshape((-1,1)) - all_digits

0 个答案:

没有答案