Tensorflow:一个热门编码

时间:2015-12-14 01:59:30

标签: eval tensorflow one-hot-encoding

以下代码工作正常,但使用eval(),我觉得效率很低。有没有更好的方法来实现同样的目标?

import tensorflow as tf
import numpy as np
sess = tf.Session()
t = tf.constant([[4,5.1,6.3,5,6.5,7.2,9.3,7,1,1.4],[4,5.1,9.3,5,6.5,7.2,1.3,7,1,1.4],[4,3.1,6.3,5,6.5,3.2,5.3,7,1,1.4]])
print t
a = tf.argmax(t,1).eval(session=sess)
z = [ k==np.arange(14) for k in a]
z1 = tf.convert_to_tensor(np.asarray(z).astype('int32'))
print z1
print sess.run(z1)

输出

Tensor("Const_25:0", shape=TensorShape([Dimension(3), Dimension(10)]), dtype=float32)
Tensor("Const_26:0", shape=TensorShape([Dimension(3), Dimension(14)]), dtype=int32)
[[0 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0]]

1 个答案:

答案 0 :(得分:3)

实现它的一种方法是计算每行的最大值,然后将每个元素与该值进行比较。我没有在这台机器上安装张量流,因此无法为您提供确切的代码,但它将遵循以下内容:

z1 = tf.equal(t, tf.reduce_max(t, reduction_indices=[1], keep_dims=True))