我目前正在处理一个多标签分类问题,该问题试图对水果图像进行分类。一旦我用一种热编码转换了类别,在训练模型并想找回合适的类别后,我将如何解码?
a = np.array(['a', 'b', 'c', 'a', 'b', 'c'])
b = pandas.get_dummies(a)
样品
X_train, X_test, y_train, y_test = train_test_split(df, b, test_size=0.20, random_state=42)
......
模型训练
from keras import models
from keras import layers
model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(6,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='lerelu'))
.....
预测
model.predict(X_test[0]) --->> result ??????
答案 0 :(得分:0)
如果我正确理解了您的问题,不是吗?
test_predictions = model.predict(X_test)