import numpy as np

image_size = 28
num_labels = 10

def reformat(dataset, labels):
    # Flatten each image_size x image_size image into one float32 row vector
    dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
    # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
    labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
    return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)
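The first half of reformat (the flattening step) is not discussed in the answers below. This is a minimal sketch of what it does, assuming a small hypothetical stack of 4 images in place of the real train/valid/test arrays, which are not shown here:

```python
import numpy as np

image_size = 28

# Hypothetical stand-in for an image stack: 4 images of 28x28 pixels (float64)
images = np.arange(4 * image_size * image_size, dtype=np.float64).reshape(4, image_size, image_size)

# -1 tells reshape to infer that dimension: 4*28*28 / 784 = 4 rows
flat = images.reshape((-1, image_size * image_size)).astype(np.float32)
print(flat.shape, flat.dtype)                    # (4, 784) float32

# Row i of the result is image i read out row by row
print(np.array_equal(flat[2], images[2].ravel()))  # True
```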
What does this line mean?
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
Answer 0 (score: 3)
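In numpy, the == operator means something different when it compares two numpy arrays (as it does in that line), so yes, it is overloaded in that sense. It compares the two arrays element by element and returns a boolean numpy array of the same size as the inputs (or, when the shapes differ but are compatible, of their broadcast shape, which is what happens with labels[:,None] in that line). The same applies to the other comparisons, such as >= and <.
E.g.: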
import numpy as np
print(np.array([5,8,2]) == np.array([5,3,2]))
# [True False True]
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32))
# [1. 0. 1.]
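The example above compares two arrays of the same shape. In the line from the question, the shapes differ, so numpy broadcasting decides the result's shape. A small sketch, using a hypothetical num_labels of 4 and three made-up labels so the output stays readable:

```python
import numpy as np

num_labels = 4                     # small value so the output is readable
labels = np.array([2, 0, 3])       # hypothetical integer labels

col = labels[:, None]              # shape (3, 1): one label per row
row = np.arange(num_labels)        # shape (4,):  [0 1 2 3]

# Broadcasting stretches (3, 1) and (4,) to a common (3, 4) shape,
# so entry [i, j] is True exactly when labels[i] == j
mask = row == col
print(mask.shape)                  # (3, 4)
print(mask.astype(np.float32))
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 0. 0. 1.]]
```

Each row ends up with a single 1.0 at the position of its label, which is exactly the one-hot encoding described by the comment in the question's code.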
Answer 1 (score: 1)
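For NumPy arrays, the == operator is an element-wise operation that returns a boolean array. The astype function then converts the boolean True to 1.0 and False to 0.0, as described in the comment.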
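To check the bool-to-float step, here is a sketch comparing the question's expression with indexing into an identity matrix via np.eye. The np.eye version is a common alternative idiom, not something from the original code, and the labels are made up:

```python
import numpy as np

num_labels = 10
labels = np.array([5, 0, 9])       # hypothetical labels

# Approach used in the question: broadcast comparison, then bool -> float32
one_hot = (np.arange(num_labels) == labels[:, None]).astype(np.float32)

# Alternative idiom (not from the original code): pick rows of an identity matrix
one_hot_eye = np.eye(num_labels, dtype=np.float32)[labels]

print(np.array_equal(one_hot, one_hot_eye))        # True
print(np.array([True, False]).astype(np.float32))  # [1. 0.]
```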
Answer 2 (score: 0)
https://docs.python.org/3/reference/expressions.html#value-comparisons describes value comparisons like ==. The default comparison is an identity check (x is y), but Python first checks whether either operand implements an __eq__ method. Numbers, lists, and dictionaries implement their own version, and so does numpy.
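As a toy illustration of that dispatch, here is a hypothetical Vec class (not part of numpy or the question) whose __eq__ returns a list of per-element results, loosely mimicking what numpy does:

```python
class Vec:
    """Hypothetical minimal class whose == compares element by element."""

    def __init__(self, items):
        self.items = list(items)

    def __eq__(self, other):
        if not isinstance(other, Vec):
            # Let Python try the other operand, then fall back to identity
            return NotImplemented
        # Toy version: zip stops at the shorter list, no shape checks
        return [a == b for a, b in zip(self.items, other.items)]

print(Vec([1, 2, 3]) == Vec([1, 2, 4]))  # [True, True, False]
print(Vec([1, 2]) == "x")                # False (identity fallback)
```

Returning NotImplemented for unknown types is what lets Python move on to the other operand and, failing that, to the default identity comparison.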
What's unique about the numpy __eq__ is that, where possible, it does an element-by-element comparison and returns a boolean array of the same size.
In [426]: [1,2,3]==[1,2,3]
Out[426]: True
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3])
In [428]: z1==z2
Out[428]: array([ True, True, True], dtype=bool)
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4])
In [433]: z1==z2
Out[433]: array([ True, True, False], dtype=bool)
In [434]: (z1==z2).astype(float) # change bool to float
Out[434]: array([ 1., 1., 0.])
A common SO question is 'why do I get this ValueError?'
In [435]: if z1==z2: print('yes')
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
That's because the comparison produces this array which has more than one True/False value.
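The usual fix is to reduce the boolean array to a single value first, as the error message suggests. A short sketch with the same z1 and z2 as above:

```python
import numpy as np

z1 = np.array([1, 2, 3])
z2 = np.array([1, 2, 4])

# Reduce the boolean array to a single bool before using it in `if`
if (z1 == z2).all():
    print('all equal')             # skipped: one element differs
if (z1 == z2).any():
    print('some equal')            # this branch runs

# np.array_equal also requires the shapes to match
print(np.array_equal(z1, z2))      # False
```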
Comparison of floats is also a common problem. Check out isclose and allclose if that issue comes up.
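A minimal sketch of that float issue, relying on the default tolerances of np.isclose and np.allclose:

```python
import numpy as np

a = np.array([0.1 + 0.2, 1.0])
b = np.array([0.3, 1.0])

print(a == b)               # [False  True]  exact comparison trips on rounding
print(np.isclose(a, b))     # [ True  True]  element-wise, within tolerance
print(np.allclose(a, b))    # True           one bool for the whole array
```

isclose gives a boolean array like ==, while allclose already reduces it to a single bool, so it is safe to use directly in an if.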