tensorflow one_hot输入维度

时间:2017-07-12 22:03:47

标签: tensorflow neural-network

tf.one_hot中的索引是什么格式?可以是张量吗?

我有以下代码:

prediction = tf.argmax(output, axis=1)
pred_hot = tf.one_hot(indices = predictions, depth=2)

如果我运行

sess.run(prediction, feed_dict={x:batch_x, y:batch_y})
# [1 1 1 1 0 0 0 0 0]  an array of either zero, one

现在,我希望这是一个双暗阵列

# [ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ]

然而,正在运行

sess.run(pred_hot, feed_dict={x:batch_x, y:batch_y})

给出错误。

所以问题是我输入到tf.one_hot函数的格式是什么,为什么它不能这样工作?

2 个答案:

答案 0 :(得分:0)

尝试使用depth = 2,即

tf.one_hot(tensor, depth=2) 

答案 1 :(得分:0)

tf.one_hot()会将[1 1 1 1 0 0 0 0 0]转换为[ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ],因此您无需将输入的输入作为2-dim数组提供。

作为一般示例,假设输入为[1 0 2]且深度为3,则tf.one_hot()将其转换为[ [0 1 0] [ 1 0 0] [0 0 1] ]