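I've been puzzled by this one-hot encoding problem. I'm sure it's a simple procedure, but I've been looking at it for a while and can't see my mistake.
I have a train_label array of shape (1080, 1) containing 6 integer classes. I'm trying to turn it into a one-hot encoding with the following: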
def convert_to_one_hot(train_labels_conv, classes):
    Y_train = np.eye(classes)[train_labels_conv.reshape(-1)].T
    return Y_train
Y_train = np.arange(6)
print(Y_train)
Y_train_hot = convert_to_one_hot(Y_train, len(Y))
print(Y_train_hot)
As a result I simply get
[0 1 2 3 4 5]
[[1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0. 1.]]
Shouldn't I get a full one-hot matrix for my training labels? I'd appreciate any pointers in the right direction, as I'm not yet used to Python.
Answer 0 (score: 0)
If the labels are strings, you can use this function:
import numpy as np

target = np.array(['dog', 'dog', 'cat', 'cat', 'cat', 'dog', 'dog',
                   'cat', 'cat', 'hamster', 'hamster'])

def one_hot(array):
    # unique: sorted distinct labels; inverse: index of each element within unique
    unique, inverse = np.unique(array, return_inverse=True)
    # Pick one row of the identity matrix per element
    onehot = np.eye(unique.shape[0])[inverse]
    return onehot

print(one_hot(target))
Out[9]:
array([[0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.]])
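
For integer labels like those in the question, the original function should already work as long as it is given the real label array rather than np.arange(6). Below is a minimal sketch, assuming the labels are integers in the range 0 to 5. The randomly generated `train_labels` is a hypothetical stand-in for the asker's data, and `num_classes` is just a renamed parameter.

```python
import numpy as np

def convert_to_one_hot(labels, num_classes):
    """Return a (num_classes, n_samples) one-hot matrix for integer labels."""
    # Flatten (n, 1) -> (n,), then select one row of the identity matrix per label.
    return np.eye(num_classes)[labels.reshape(-1)].T

# Hypothetical stand-in for the real train_labels: 1080 integers in [0, 6).
rng = np.random.default_rng(0)
train_labels = rng.integers(0, 6, size=(1080, 1))

Y_train_hot = convert_to_one_hot(train_labels, 6)
print(Y_train_hot.shape)  # (6, 1080)

# Passing np.arange(6) selects each row of the identity matrix exactly once,
# which is why the question's output is the 6x6 identity matrix.
print(np.array_equal(convert_to_one_hot(np.arange(6), 6), np.eye(6)))  # True
```

Dropping the trailing .T would give shape (1080, 6) instead, with one row per sample, which is the layout the string version above returns.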