我正在尝试从MNIST创建一个小批量,其所有数字的范围为0到9。(10个元素)
我想避免循环遍历标签扇区中的所有元素以一一检查数字。
最简单的方法是什么?
我想我可以创建一个从0到9的标签数组“ all_digits”,然后将其与我的mnist_labels“ train_labels”列表进行比较。 (一维数组-n个元素)
我尝试获取所有对均等检查的矩阵。 (n x 10) 但是我不能直接使用==,也没有numpy.equal()的广播版本。
之后我也不清楚如何处理矩阵。
import numpy as np
train_labels = np.random.randint(0,10,100)
all_digits = np.arange(10)
# doing a difference for now
train_labels.reshape((-1,1)) - all_digits