我有一个NumPy标签数组:
labels = np.ndarray(10000, dtype=np.float32)
数组中的元素如下所示:
print(labels[1:5])
Output: [ 9. 9. 4. 1.]
我想将它们转换为一个热编码标签,我使用了以下代码:
one_hot_labels = np.eye(10)[labels]
我收到以下错误:
IndexError Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
1
----> 2 s=np.eye(10)[labels]
IndexError: arrays used as indices must be of integer (or boolean) type
我该如何解决这个问题?
答案 0 :(得分:3)
您已将标签定义为np.float32
。如果要将它们用作数组或矩阵的索引,则它们必须是整数。转换np.float32
使用.astype(int)
:
one_hot_labels=np.eye(10)[labels.astype(int)]
或直接将标签定义为int:
labels=np.ndarray(10000,dtype=int)
答案 1 :(得分:1)
如果labels
为float
并且您不希望更改其dtype
,则可以使用MultiLabelBinarizer
。这段代码应该完成工作:
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
one_hot_labels = mlb.fit_transform(labels[:, None])