Does numpy override the == operator? I can't understand the following Python code

Date: 2016-10-15 07:15:29

Tags: python numpy

import numpy as np

image_size = 28
num_labels = 10

def reformat(dataset, labels):
  # Flatten each image_size x image_size image into one row of float32 values
  dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
  # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
  labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
  return dataset, labels

train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)

What does this line mean?

labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)

The code is from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb

3 answers:

Answer 0 (score: 3)

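In numpy, the == operator means something different when it compares two numpy arrays (as it does in that line), so yes, in that sense it is overloaded. It compares the two arrays element by element and returns a boolean numpy array of the same size as the two inputs. The same holds for other comparisons such as >= and <.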

For example:

import numpy as np
print(np.array([5,8,2]) == np.array([5,3,2]))
# [ True False  True]
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32))
# [1. 0. 1.]
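
In the line from the question the two operands do not have the same shape, so numpy broadcasts them before comparing. The following is a minimal sketch of that step only; the label values and the smaller num_labels are made up for illustration and are not taken from the notebook.

```python
import numpy as np

num_labels = 4                   # hypothetical; the question uses 10
labels = np.array([2, 0, 3])     # hypothetical class indices

classes = np.arange(num_labels)  # shape (4,): [0 1 2 3]
column = labels[:, None]         # shape (3, 1): one label per row

# Broadcasting stretches (4,) against (3, 1) into (3, 4):
# entry [i, j] is True exactly when labels[i] == j.
mask = classes == column
print(mask.shape)   # (3, 4)
print(mask)
```

Each row of mask contains a single True, in the column whose index equals that row's label.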

Answer 1 (score: 1)

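For numpy arrays, the == operator is an element-wise operation that returns a boolean array. The astype method then converts the boolean True to 1.0 and False to 0.0, as described in the code comment.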
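
As a rough end-to-end check of that conversion, here is a small sketch that applies the whole expression to toy data. The three labels and num_labels = 3 are assumptions chosen to keep the output short.

```python
import numpy as np

num_labels = 3                 # hypothetical; the question uses 10
labels = np.array([1, 0, 2])   # hypothetical class indices

one_hot = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
print(one_hot)
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]

# Each row holds a single 1.0, and its position recovers the original label.
assert np.array_equal(one_hot.sum(axis=1), np.ones(len(labels)))
assert np.array_equal(one_hot.argmax(axis=1), labels)
```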

Answer 2 (score: 0)

https://docs.python.org/3/reference/expressions.html#value-comparisons describes value comparisons such as ==. The default comparison is identity (x is y), but Python first checks whether either operand implements an __eq__ method. Numbers, lists, and dictionaries implement their own versions, and so does numpy.

What's unique about the numpy __eq__ is that it performs, where possible, an element-by-element comparison and returns a boolean array of the same shape (the broadcast shape, if the operands' shapes differ).

In [426]: [1,2,3]==[1,2,3]
Out[426]: True
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3])
In [428]: z1==z2
Out[428]: array([ True,  True,  True], dtype=bool)
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4])
In [433]: z1==z2
Out[433]: array([ True,  True, False], dtype=bool)
In [434]: (z1==z2).astype(float)     # change bool to float
Out[434]: array([ 1.,  1.,  0.])
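
To illustrate the __eq__ hook itself, here is a toy class that makes == return one result per element. The class name Elementwise is hypothetical and this is only a sketch of the mechanism, not how numpy is actually implemented.

```python
class Elementwise:
    """Hypothetical toy container; not numpy's implementation."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        # Return a list of per-element results instead of a single bool.
        return [a == b for a, b in zip(self.values, other.values)]


result = Elementwise([1, 2, 3]) == Elementwise([1, 2, 4])
print(result)   # [True, True, False]
```

Python places no requirement on __eq__ to return a bool, which is what lets numpy return a whole array from ==.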

A common SO question is 'why do I get this ValueError?'

In [435]: if z1==z2: print('yes')
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

That's because the comparison produces an array with more than one True/False value, and the if statement needs a single truth value.
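
One way to avoid the error is to reduce the boolean array to a single value first, as the message suggests. A short sketch, reusing the z1 and z2 values from the session above; the verdict strings are only for illustration:

```python
import numpy as np

z1 = np.array([1, 2, 3])
z2 = np.array([1, 2, 4])

same = z1 == z2   # boolean array, one entry per element

if same.all():
    verdict = "all elements match"
elif same.any():
    verdict = "some elements match"
else:
    verdict = "no elements match"
print(verdict)   # some elements match

# np.array_equal also gives a single bool, and False if the shapes differ.
print(np.array_equal(z1, z2))   # False
```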

Comparison of floats is also a common problem. Check out isclose and allclose if that issue comes up.
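
A brief sketch of the float issue, using the default tolerances of np.isclose and np.allclose; the input values are made up for illustration:

```python
import numpy as np

a = np.array([0.1 + 0.2, 1.0])
b = np.array([0.3, 1.0])

# 0.1 + 0.2 is not exactly 0.3 in binary floating point.
print(a == b)              # [False  True]
print(np.isclose(a, b))    # [ True  True]
print(np.allclose(a, b))   # True
```

np.isclose returns one result per element, like ==, while np.allclose reduces them to a single bool that is safe to use in an if statement.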