多标签分类:解码一个热向量

时间:2018-10-27 02:16:42

标签: numpy machine-learning scikit-learn keras one-hot-encoding

我目前正在处理一个多标签分类问题,该问题试图对水果图像进行分类。一旦我用一种热编码转换了类别,在训练模型并想找回合适的类别后,我将如何解码?

a = np.array(['a', 'b', 'c', 'a', 'b', 'c'])
b = pandas.get_dummies(a)

样品

X_train, X_test, y_train, y_test = train_test_split(df, b, test_size=0.20, random_state=42)

......

模型训练

from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(6,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='lerelu'))

.....

预测

model.predict(X_test[0]) --->> result  ??????

1 个答案:

答案 0 :(得分:0)

如果我正确理解了您的问题,不是吗?

test_predictions = model.predict(X_test)