IndexError使用1D数组(NumPy)索引2D数组

时间:2017-07-19 19:30:57

标签: python numpy one-hot-encoding

我有一个NumPy标签数组:

labels = np.ndarray(10000, dtype=np.float32)

数组中的元素如下所示:

print(labels[1:5])
Output: [ 9.  9.  4.  1.]

我想将它们转换为一个热编码标签,我使用了以下代码:

one_hot_labels = np.eye(10)[labels]

我收到以下错误:

IndexError     Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
  1 
----> 2 s=np.eye(10)[labels]

IndexError: arrays used as indices must be of integer (or boolean) type

我该如何解决这个问题?

2 个答案:

答案 0 :(得分:3)

您已将标签定义为np.float32。如果要将它们用作数组或矩阵的索引,则它们必须是整数。转换np.float32使用.astype(int)

 one_hot_labels=np.eye(10)[labels.astype(int)]

或直接将标签定义为int:

labels=np.ndarray(10000,dtype=int)

答案 1 :(得分:1)

如果labelsfloat并且您不希望更改其dtype,则可以使用MultiLabelBinarizer。这段代码应该完成工作:

from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
one_hot_labels = mlb.fit_transform(labels[:, None])