One-hot encoding from image labels using numpy

Time: 2019-11-27 03:32:00

Tags: python numpy one-hot-encoding

I've been stuck on this one-hot encoding problem. I'm sure it's a simple process, but I've been looking at it for a while and can't see my mistake.

I have a set of train_labels with shape (1080, 1) and 6 integer classes. I'm trying to convert it to a one-hot encoding with the following:

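import numpy as np

def convert_to_one_hot(train_labels_conv, classes):
    # Pick one row of the identity matrix per label, then transpose to (classes, n_samples)
    Y_train = np.eye(classes)[train_labels_conv.reshape(-1)].T
    return Y_train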
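To see what convert_to_one_hot does, here is a minimal sketch on a small, made-up label array (`labels` is hypothetical, not from the question). Indexing np.eye(classes) with an integer array selects one identity row per label, and .T then flips the result from (n_samples, classes) to (classes, n_samples).

```python
import numpy as np

def convert_to_one_hot(train_labels_conv, classes):
    # Row i of np.eye(classes) is the one-hot vector for class i, so
    # integer-array indexing with the flattened labels selects one row per sample.
    # .T transposes (n_samples, classes) -> (classes, n_samples).
    return np.eye(classes)[train_labels_conv.reshape(-1)].T

# Hypothetical labels for four samples, stored as a column like train_labels.
labels = np.array([[2], [0], [5], [2]])

rows = np.eye(6)[labels.reshape(-1)]   # one row per sample
print(rows.shape)                      # (4, 6)

Y = convert_to_one_hot(labels, 6)      # one column per sample
print(Y.shape)                         # (6, 4)
print(Y[:, 0])                         # [0. 0. 1. 0. 0. 0.]
```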

Y_train = np.arange(6)
print(Y_train)
Y_train_hot = convert_to_one_hot(Y_train, 6)  # 6 classes
print(Y_train_hot)
As a result I simply get
[0 1 2 3 4 5]
[[1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]]

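Shouldn't I be getting the full one-hot matrix for my training labels? I'd appreciate any pointers in the right direction, since I'm not used to Python yet.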
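A minimal sketch of the full-size case, assuming the labels are integers 0 to 5 stored in a (1080, 1) array. The random labels below are stand-in data, not the question's dataset. np.arange(6) contains each class exactly once and in order, so its one-hot matrix is the 6x6 identity; with 1080 labels the same function returns a 6x1080 matrix with one column per sample.

```python
import numpy as np

def convert_to_one_hot(train_labels_conv, classes):
    return np.eye(classes)[train_labels_conv.reshape(-1)].T

# Stand-in for train_labels: 1080 random integer labels in [0, 6), shape (1080, 1).
rng = np.random.default_rng(0)
train_labels = rng.integers(0, 6, size=(1080, 1))

Y_train_hot = convert_to_one_hot(train_labels, 6)
print(Y_train_hot.shape)  # (6, 1080)

# Each column holds exactly one 1, at the row given by that sample's label.
print((Y_train_hot.sum(axis=0) == 1).all())                            # True
print((Y_train_hot.argmax(axis=0) == train_labels.reshape(-1)).all())  # True

# The identity matrix only appears when the labels are exactly 0, 1, ..., 5 in order.
print(np.array_equal(convert_to_one_hot(np.arange(6), 6), np.eye(6)))  # True
```

If one row per sample (shape (1080, 6)) is preferred, dropping the .T in convert_to_one_hot gives that layout.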

1 answer:

Answer 0 (score: 0)

If the labels are strings, you can use this function:

import numpy as np

target = np.array(['dog', 'dog', 'cat', 'cat', 'cat', 'dog', 'dog', 
    'cat', 'cat', 'hamster', 'hamster'])

def one_hot(array):
    unique, inverse = np.unique(array, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]
    return onehot

print(one_hot(target))
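To show what np.unique(..., return_inverse=True) contributes, here is a short sketch on the same target array: `unique` holds the sorted distinct labels, and `inverse` gives each element's index into `unique`, which is exactly the integer class that np.eye(...)[inverse] turns into a one-hot row.

```python
import numpy as np

target = np.array(['dog', 'dog', 'cat', 'cat', 'cat', 'dog', 'dog',
                   'cat', 'cat', 'hamster', 'hamster'])

# unique is sorted; inverse maps each element of target to its position in unique.
unique, inverse = np.unique(target, return_inverse=True)
print(unique)   # ['cat' 'dog' 'hamster']
print(inverse)  # [1 1 0 0 0 1 1 0 0 2 2]

# Rebuilding the original array from unique and inverse confirms the mapping.
print((unique[inverse] == target).all())  # True

onehot = np.eye(unique.shape[0])[inverse]
print(onehot.shape)  # (11, 3)
```

The same function also accepts integer labels, e.g. one_hot(train_labels.reshape(-1)) for a (1080, 1) array, and returns one row per sample (shape (1080, 6) when all six classes occur) rather than the question's (6, 1080). One caveat: the number of columns equals the number of distinct labels actually present, so a class that never appears in the data gets no column.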
  

Output:
[[0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]