Tensorflow One-Hot

时间:2017-01-23 00:01:02

标签: tensorflow

我是Tensorflow(和神经网络)的新手,我正在研究一个简单的分类问题。想问两个问题。

假设我有120个[1,2,3,4,5]的排列标签。在将它送入我的图表之前,我真的有必要进行One-Hot编码吗?如果是,我应该在进入tensorflow之前进行编码吗?

如果我进行单热编码,softmax预测将给出[0.001 0.202 0.321 ...... 0.002 0.0003 0.0004]。运行arg_max将生成正确的索引。我如何获得tensorflow来返回正确的标签而不是一个热门的结果?

谢谢。

1 个答案:

答案 0 :(得分:3)

所以你的输入是{1,2,3,4,5}中的120个标签(每个标签可以是从1到5的数字)?

# Your input, a 1D tensor of 120 elements from 1-5.
# Better shift your label space to 0-4 instead.
labels = labels - 1

# Now convert to a 2D tensor of 120 x 5 onehot labels.
onehot_labels = tf.one_hot(labels, 5)

# Now some computations.
....

# You end up with some onehot_output
# of the same shape as your labels (120x5).
# As you said, arg_max will give you the index of the result,
# which is a 1D index label of 120 elements.
output = tf.argmax(onehot_output, axis=1).

# You might want to shift back to {1,2,3,4,5}.
output = output + 1