tf.one_hot中的索引是什么格式?可以是张量吗?
我有以下代码:
prediction = tf.argmax(output, axis=1)
pred_hot = tf.one_hot(indices = predictions, depth=2)
如果我运行
sess.run(prediction, feed_dict={x:batch_x, y:batch_y})
# [1 1 1 1 0 0 0 0 0] an array of either zero, one
现在,我希望这是一个双暗阵列
# [ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ]
然而,正在运行
sess.run(pred_hot, feed_dict={x:batch_x, y:batch_y})
给出错误。
所以问题是我输入到tf.one_hot函数的格式是什么,为什么它不能这样工作?
答案 0 :(得分:0)
尝试使用depth = 2,即
tf.one_hot(tensor, depth=2)
答案 1 :(得分:0)
tf.one_hot()
会将[1 1 1 1 0 0 0 0 0]
转换为[ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ]
,因此您无需将输入的输入作为2-dim数组提供。
作为一般示例,假设输入为[1 0 2]
且深度为3,则tf.one_hot()
将其转换为[ [0 1 0] [ 1 0 0] [0 0 1] ]