理解==应用于NumPy数组

时间:2016-04-10 05:06:54

标签: python python-2.7 numpy

我是Python的新手,我正在学习 TensorFlow 。在使用 notMNIST数据集的教程中,他们给出了将标签矩阵转换为n个编码数组的示例代码。

目标是获取一个由标签整数0 ... 9组成的数组,并返回一个矩阵,其中每个整数已转换为一个n编码数组,如下所示:

0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
...

他们提供的代码是:

# Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)

但是,我根本不了解这段代码是如何做到的。看起来它只是生成0到9范围内的整数数组,然后将其与标签矩阵进行比较,并将结果转换为浮点数。 ==运算符如何生成一个n编码矩阵

2 个答案:

答案 0 :(得分:28)

这里有一些事情:numpy的矢量操作,添加单例轴和广播。

首先,你应该能够看到==如何发挥魔力。

我们假设我们从一个简单的标签数组开始。 ==以矢量化方式运行,这意味着我们可以将整个数组与标量进行比较,并获得由每个元素比较的值组成的数组。例如:

>>> labels = np.array([1,2,0,0,2])
>>> labels == 0
array([False, False,  True,  True, False], dtype=bool)
>>> (labels == 0).astype(np.float32)
array([ 0.,  0.,  1.,  1.,  0.], dtype=float32)

首先我们得到一个布尔数组,然后我们强制浮点数:Python中的False == 0,True == 1。所以我们得到一个0的数组,labels不等于0和1。

但是与0相比没什么特别的,我们可以比较1或2或3而不是类似的结果:

>>> (labels == 2).astype(np.float32)
array([ 0.,  1.,  0.,  0.,  1.], dtype=float32)

事实上,我们可以遍历每个可能的标签并生成这个数组。我们可以使用listcomp:

>>> np.array([(labels == i).astype(np.float32) for i in np.arange(3)])
array([[ 0.,  0.,  1.,  1.,  0.],
       [ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  1.]], dtype=float32)

但这并没有真正利用numpy。我们想要做的是将每个可能的标签与每个元素进行比较,IOW进行比较

>>> np.arange(3)
array([0, 1, 2])

>>> labels
array([1, 2, 0, 0, 2])

这里是numpy广播的神奇之处。现在,labels是一个形状的一维物体(5,)。如果我们使它成为一个二维形状对象(5,1),那么操作将"广播"在最后一个轴上,我们得到一个形状(5,3)的输出,其结果是将范围中的每个条目与标签的每个元素进行比较。

首先我们可以添加"额外"使用labels(或None)转换为np.newaxis的轴,改变其形状:

>>> labels[:,None]
array([[1],
       [2],
       [0],
       [0],
       [2]])
>>> labels[:,None].shape
(5, 1)

然后我们可以进行比较(这是我们之前看到的安排的转置,但这并不重要)。

>>> np.arange(3) == labels[:,None]
array([[False,  True, False],
       [False, False,  True],
       [ True, False, False],
       [ True, False, False],
       [False, False,  True]], dtype=bool)
>>> (np.arange(3) == labels[:,None]).astype(np.float32)
array([[ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  0.,  1.]], dtype=float32)

numpy中的广播非常强大,非常值得一读。

答案 1 :(得分:0)

简而言之,==应用于numpy数组意味着将元素方式==应用于数组。结果是一系列布尔值。这是一个例子:

>>> b = np.array([1,0,0,1,1,0])
>>> b == 1
array([ True, False, False,  True,  True, False], dtype=bool)

要计算b中有多少1,你不需要将数组转换为浮点数,即.astype(np.float32)可以保存,因为在python中boolean是一个子类在int 3和Python 3中你有True == 1 False == 0。所以这里是你如何计算b中有多少个:

>>> np.sum((b == 1))
3

或者:

>>> np.count_nonzero(b == 1)
3